如何连接两个 Tensorflow 数据集?

How to concatenate two Tensorflow DataSets?

我正在尝试加载然后扩充一些图像 (160 x 160 x 3) 数据集,其中图像存储在文件夹中,文件夹名称是我的标签。正在应用多个转换来生成数据副本,它们需要 concatenated (or stacked may be) 才能合并数据并将它们存储回磁盘。

下面是我能写的最简单的可重现片段,但我无法append/concatenate/stack这两个数据集。

def some_transformation(image, label):
    # do something like rotation, clipping, noise add etc.
    return image, label

userA = tf.data.Dataset.from_tensors(tf.constant(np.zeros((3, 160, 160, 3))))
userA_label = tf.data.Dataset.from_tensors(tf.constant(np.zeros((3, 2))))
userA_with_labels = tf.data.Dataset.zip((userA, userA_label))
transformed_userA_w_label = userA_with_labels.map(some_transformation)

userB = tf.data.Dataset.from_tensors(tf.constant(np.zeros((3, 160, 160, 3))))
userB_label = tf.data.Dataset.from_tensors(tf.constant(np.zeros((3, 2))))
userB_with_labels = tf.data.Dataset.zip((userB, userB_label))
transformed_userB_w_label = userB_with_labels.map(some_transformation)

print('User A - {}'.format(transformed_userA_w_label))
print('User B - {}'.format(transformed_userB_w_label))
transformed_userA_w_label.concatenate(transformed_userB_w_label)

打印语句输出如下:

User A - <MapDataset shapes: ((3, 160, 160, 3), (3, 2)), types: (tf.float64, tf.float64)>
User B - <MapDataset shapes: ((3, 160, 160, 3), (3, 2)), types: (tf.float64, tf.float64)>
Output ds - <ConcatenateDataset shapes: ((3, 160, 160, 3), (3, 2)), types: (tf.float64, tf.float64)>

预期:6 图片

Output ds - <ConcatenateDataset shapes: ((6, 160, 160, 3), (6, 2)), types: (tf.float64, tf.float64)>

这里的关键问题是 tf.data.Dataset.from_tensorstf.data.Dataset.from_tensor_slices 的使用。

  • tf.data.Dataset.from_tensors([t1,t2,t3]) - 创建一个数据集,其中列表的每个元素都作为数据点给出
  • tf.data.Dataset.from_tensor_slices(t) - 创建一个数据集,其中一个元素是在第一个轴上索引的一个项目

根据您拥有的数据(即 3 张尺寸为 160x160x3 的图像,即 3x160x160x3 ),您需要使用第二种方法。否则,您的所有 3 张图像都将作为一个数据点(这可能不是您想要的)。

转到第二个问题,你显示的输出,

User A - <MapDataset shapes: ((3, 160, 160, 3), (3, 2)), types: (tf.float64, tf.float64)>
User B - <MapDataset shapes: ((3, 160, 160, 3), (3, 2)), types: (tf.float64, tf.float64)>
Output ds - <ConcatenateDataset shapes: ((3, 160, 160, 3), (3, 2)), types: (tf.float64, tf.float64)>

它只是显示单个元素的外观。因此,即使代码正确,您也不会看到 6。要查看您必须迭代数据集的元素数量。在您的情况下,您会看到 2(因为此数据集将所有 3 个图像视为单个数据点)。

因此,要修复您的代码,请执行此操作,

def some_transformation(image, label):
    # do something like rotation, clipping, noise add etc.
    return image, label

userA = tf.data.Dataset.from_tensor_slices(tf.constant(np.zeros((3, 160, 160, 3))))
userA_label = tf.data.Dataset.from_tensors(tf.constant(np.zeros((3, 2))))
userA_with_labels = tf.data.Dataset.zip((userA, userA_label))
transformed_userA_w_label = userA_with_labels.map(some_transformation)

userB = tf.data.Dataset.from_tensor_slices(tf.constant(np.zeros((3, 160, 160, 3))))
userB_label = tf.data.Dataset.from_tensors(tf.constant(np.zeros((3, 2))))
userB_with_labels = tf.data.Dataset.zip((userB, userB_label))
transformed_userB_w_label = userB_with_labels.map(some_transformation)

print('User A - {}'.format(transformed_userA_w_label))
print('User B - {}'.format(transformed_userB_w_label))
concat_ds = transformed_userA_w_label.concatenate(transformed_userB_w_label)
print(concat_ds)

for i,ii in enumerate(concat_ds):
  print(i)

您将看到 i 的值被打印 6 次。这就是你需要的。