如何连接两个 Tensorflow 数据集?
How to concatenate two Tensorflow DataSets?
我正在尝试加载然后扩充一些图像 (160 x 160 x 3
) 数据集,其中图像存储在文件夹中,文件夹名称是我的标签。正在应用多个转换来生成数据副本,它们需要 concatenated (or stacked may be)
才能合并数据并将它们存储回磁盘。
下面是我能写的最简单的可重现片段,但我无法append/concatenate/stack
这两个数据集。
def some_transformation(image, label):
# do something like rotation, clipping, noise add etc.
return image, label
userA = tf.data.Dataset.from_tensors(tf.constant(np.zeros((3, 160, 160, 3))))
userA_label = tf.data.Dataset.from_tensors(tf.constant(np.zeros((3, 2))))
userA_with_labels = tf.data.Dataset.zip((userA, userA_label))
transformed_userA_w_label = userA_with_labels.map(some_transformation)
userB = tf.data.Dataset.from_tensors(tf.constant(np.zeros((3, 160, 160, 3))))
userB_label = tf.data.Dataset.from_tensors(tf.constant(np.zeros((3, 2))))
userB_with_labels = tf.data.Dataset.zip((userB, userB_label))
transformed_userB_w_label = userB_with_labels.map(some_transformation)
print('User A - {}'.format(transformed_userA_w_label))
print('User B - {}'.format(transformed_userB_w_label))
transformed_userA_w_label.concatenate(transformed_userB_w_label)
打印语句输出如下:
User A - <MapDataset shapes: ((3, 160, 160, 3), (3, 2)), types: (tf.float64, tf.float64)>
User B - <MapDataset shapes: ((3, 160, 160, 3), (3, 2)), types: (tf.float64, tf.float64)>
Output ds - <ConcatenateDataset shapes: ((3, 160, 160, 3), (3, 2)), types: (tf.float64, tf.float64)>
预期:6
图片
Output ds - <ConcatenateDataset shapes: ((6, 160, 160, 3), (6, 2)), types: (tf.float64, tf.float64)>
这里的关键问题是 tf.data.Dataset.from_tensors
与 tf.data.Dataset.from_tensor_slices
的使用。
tf.data.Dataset.from_tensors([t1,t2,t3])
- 创建一个数据集,其中列表的每个元素都作为数据点给出
tf.data.Dataset.from_tensor_slices(t)
- 创建一个数据集,其中一个元素是在第一个轴上索引的一个项目
根据您拥有的数据(即 3 张尺寸为 160x160x3 的图像,即 3x160x160x3
),您需要使用第二种方法。否则,您的所有 3 张图像都将作为一个数据点(这可能不是您想要的)。
转到第二个问题,你显示的输出,
User A - <MapDataset shapes: ((3, 160, 160, 3), (3, 2)), types: (tf.float64, tf.float64)>
User B - <MapDataset shapes: ((3, 160, 160, 3), (3, 2)), types: (tf.float64, tf.float64)>
Output ds - <ConcatenateDataset shapes: ((3, 160, 160, 3), (3, 2)), types: (tf.float64, tf.float64)>
它只是显示单个元素的外观。因此,即使代码正确,您也不会看到 6
。要查看您必须迭代数据集的元素数量。在您的情况下,您会看到 2
(因为此数据集将所有 3 个图像视为单个数据点)。
因此,要修复您的代码,请执行此操作,
def some_transformation(image, label):
# do something like rotation, clipping, noise add etc.
return image, label
userA = tf.data.Dataset.from_tensor_slices(tf.constant(np.zeros((3, 160, 160, 3))))
userA_label = tf.data.Dataset.from_tensors(tf.constant(np.zeros((3, 2))))
userA_with_labels = tf.data.Dataset.zip((userA, userA_label))
transformed_userA_w_label = userA_with_labels.map(some_transformation)
userB = tf.data.Dataset.from_tensor_slices(tf.constant(np.zeros((3, 160, 160, 3))))
userB_label = tf.data.Dataset.from_tensors(tf.constant(np.zeros((3, 2))))
userB_with_labels = tf.data.Dataset.zip((userB, userB_label))
transformed_userB_w_label = userB_with_labels.map(some_transformation)
print('User A - {}'.format(transformed_userA_w_label))
print('User B - {}'.format(transformed_userB_w_label))
concat_ds = transformed_userA_w_label.concatenate(transformed_userB_w_label)
print(concat_ds)
for i,ii in enumerate(concat_ds):
print(i)
您将看到 i
的值被打印 6 次。这就是你需要的。
我正在尝试加载然后扩充一些图像 (160 x 160 x 3
) 数据集,其中图像存储在文件夹中,文件夹名称是我的标签。正在应用多个转换来生成数据副本,它们需要 concatenated (or stacked may be)
才能合并数据并将它们存储回磁盘。
下面是我能写的最简单的可重现片段,但我无法append/concatenate/stack
这两个数据集。
def some_transformation(image, label):
# do something like rotation, clipping, noise add etc.
return image, label
userA = tf.data.Dataset.from_tensors(tf.constant(np.zeros((3, 160, 160, 3))))
userA_label = tf.data.Dataset.from_tensors(tf.constant(np.zeros((3, 2))))
userA_with_labels = tf.data.Dataset.zip((userA, userA_label))
transformed_userA_w_label = userA_with_labels.map(some_transformation)
userB = tf.data.Dataset.from_tensors(tf.constant(np.zeros((3, 160, 160, 3))))
userB_label = tf.data.Dataset.from_tensors(tf.constant(np.zeros((3, 2))))
userB_with_labels = tf.data.Dataset.zip((userB, userB_label))
transformed_userB_w_label = userB_with_labels.map(some_transformation)
print('User A - {}'.format(transformed_userA_w_label))
print('User B - {}'.format(transformed_userB_w_label))
transformed_userA_w_label.concatenate(transformed_userB_w_label)
打印语句输出如下:
User A - <MapDataset shapes: ((3, 160, 160, 3), (3, 2)), types: (tf.float64, tf.float64)>
User B - <MapDataset shapes: ((3, 160, 160, 3), (3, 2)), types: (tf.float64, tf.float64)>
Output ds - <ConcatenateDataset shapes: ((3, 160, 160, 3), (3, 2)), types: (tf.float64, tf.float64)>
预期:6
图片
Output ds - <ConcatenateDataset shapes: ((6, 160, 160, 3), (6, 2)), types: (tf.float64, tf.float64)>
这里的关键问题是 tf.data.Dataset.from_tensors
与 tf.data.Dataset.from_tensor_slices
的使用。
tf.data.Dataset.from_tensors([t1,t2,t3])
- 创建一个数据集,其中列表的每个元素都作为数据点给出tf.data.Dataset.from_tensor_slices(t)
- 创建一个数据集,其中一个元素是在第一个轴上索引的一个项目
根据您拥有的数据(即 3 张尺寸为 160x160x3 的图像,即 3x160x160x3
),您需要使用第二种方法。否则,您的所有 3 张图像都将作为一个数据点(这可能不是您想要的)。
转到第二个问题,你显示的输出,
User A - <MapDataset shapes: ((3, 160, 160, 3), (3, 2)), types: (tf.float64, tf.float64)>
User B - <MapDataset shapes: ((3, 160, 160, 3), (3, 2)), types: (tf.float64, tf.float64)>
Output ds - <ConcatenateDataset shapes: ((3, 160, 160, 3), (3, 2)), types: (tf.float64, tf.float64)>
它只是显示单个元素的外观。因此,即使代码正确,您也不会看到 6
。要查看您必须迭代数据集的元素数量。在您的情况下,您会看到 2
(因为此数据集将所有 3 个图像视为单个数据点)。
因此,要修复您的代码,请执行此操作,
def some_transformation(image, label):
# do something like rotation, clipping, noise add etc.
return image, label
userA = tf.data.Dataset.from_tensor_slices(tf.constant(np.zeros((3, 160, 160, 3))))
userA_label = tf.data.Dataset.from_tensors(tf.constant(np.zeros((3, 2))))
userA_with_labels = tf.data.Dataset.zip((userA, userA_label))
transformed_userA_w_label = userA_with_labels.map(some_transformation)
userB = tf.data.Dataset.from_tensor_slices(tf.constant(np.zeros((3, 160, 160, 3))))
userB_label = tf.data.Dataset.from_tensors(tf.constant(np.zeros((3, 2))))
userB_with_labels = tf.data.Dataset.zip((userB, userB_label))
transformed_userB_w_label = userB_with_labels.map(some_transformation)
print('User A - {}'.format(transformed_userA_w_label))
print('User B - {}'.format(transformed_userB_w_label))
concat_ds = transformed_userA_w_label.concatenate(transformed_userB_w_label)
print(concat_ds)
for i,ii in enumerate(concat_ds):
print(i)
您将看到 i
的值被打印 6 次。这就是你需要的。