Tensorflow:如何手动分片数据集

Tensorflow: how to manually shard a dataset

我正在使用 MirroredStrategy 执行多 GPU 训练,它似乎没有正确分片数据。您如何手动分片数据?

我知道我可以对 tf.data 数据集使用 shard 方法,但为此我需要访问工作人员 ID,但我不知道如何获取它。我如何访问工作人员 ID?

MirroredStrategy 在单个 worker 上运行(对于多个 worker 有 MultiWorkerMirroredStrategy)。因为它只在一个 worker 上运行,所以 MirroredStrategy 运行单个 Dataset 管道,没有任何数据分片。在每个步骤中,MirroredStrategy 为每个工作人员请求一个数据集元素。