Tensorflow 2.0:在多输入场景中构造“tf.data.Dataset”输出的最佳方式

Tensorflow 2.0: Best way for structure the output of `tf.data.Dataset` in multiple inputs scenario

我正在 Tensorflow 上构建一个用于图像去模糊的 GAN,它是 DeblurGANv2 的一个实现。我将 GAN 设置为具有两个输入、一批模糊图像和一批清晰图像的方式。按照这一行,我将输入设计为具有两个键 ['sharp', 'blur'] 的 Python 字典,每个键都有一个形状为 [batch_size, 512, 512, 3] 的张量,这使得将模糊图像批处理到生成器,然后将生成器的输出和锐化图像批量提供给鉴别器。

根据最后的要求,我创建了一个 tf.data.Dataset 来精确输出,一个包含两个张量的字典,每个张量都有它们的批次维度。这与我的 GAN 实现完美互补,一切都运行良好。

所以请记住,我的输入不是张量,而是 python 字典,它没有批次维度,这将与稍后解释我的问题有关。

最近,我决定使用 Tensorflow Distribution Strategies 添加对分布式训练的支持。 Tensorflow 的这个特性允许在多个设备上分发训练,包括在多台机器上。一些实现有一个特性,例如 MirroredStrategy,它获取输入张量,将其分成相等的部分,并将每个切片提供给不同的设备,这意味着,如果您的批处理大小为 16 并且4 个 GPU,每个 GPU 将结束本地批次的 4 个数据点,在此之后有一些魔法可以聚合结果和其他与我的问题无关的东西。

正如您已经注意到的那样,将张量作为输入或至少某种具有外部批量维度的输入对于分布策略至关重要,而我拥有的是一个 Python 字典,其中内部字典张量值中输入的批处理维度。这是个大问题,我目前的实现与分布式训练不兼容。

我一直在寻找解决方法,但我不能很好地解决这个问题,也许只是让输入成为 shape=[batch_size, 2, 512, 512, 3] 的巨大张量并将其切片?不确定这只是我现在想到的,哈哈。无论如何,我认为这非常模棱两可,我无法区分这两个输入,至少不能区分字典键的清晰度。编辑:这个解决方案的问题是使我的数据集转换非常昂贵,因此使数据集吞吐量变慢,考虑到这是一个图像加载管道,这是一个重点。

也许我对分布式策略如何工作的解释不是最严谨的,如果我没有看到任何东西请随时纠正我。

PD:这不是bug问题或代码错误,主要是"System Design Query",希望这不是违法的

代替使用字典作为 GAN 的输入,您可以尝试按以下方式映射函数,

def load_image(fileA,fileB):
    imageA = tf.io.read_file(fileA)
    imageA = tf.image.decode_jpeg(imageA, channels=3)

    imageB = tf.io.read_file(fileB)
    imageB = tf.image.decode_jpeg(imageB)
    return imageA,imageB

trainA = glob.glob('blur/*.jpg')
trainB = glob.glob('sharp/*.jpg')
train_dataset = tf.data.Dataset.from_tensor_slices((trainA,trainB))
train_dataset = train_dataset.map(load_image).batch(batch_size)

#for mirrored strategy

dist_dataset = mirrored_strategy.experimental_distribute_dataset(train_dataset)

您可以通过传递两个图像来迭代数据集并更新网络。
希望对您有所帮助!