有什么方法可以连接 3 个或更多 tf.data.Dataset
Is there any ways to concatenate 3 or more tf.data.Dataset
我想在 TensorFlow 中连接 3 个或更多数据集。
要连接 2 个数据集,
dataset1 = tf.data.Dataset.range(1, 4)
dataset2 = tf.data.Dataset.range(4, 8)
dataset1.concatenate(dataset2)
但是,这样的话,3个或更多的数据集是不能拼接的。
所以我想这样做
dataset1 = tf.data.Dataset.range(1, 4)
dataset2 = tf.data.Dataset.range(4, 8)
dataset3 = tf.data.Dataset.range(8, 12)
concatenate(dataset1,dataset2,dataset3)
有什么办法吗?
在这个具体的例子中你可以这样做
concat_dataset = dataset1.concatenate(dataset2).concatenate(dataset3)
请注意,您必须将 concatenate
的结果分配给一个新变量!它不能就地运行。
当然,如果你有很多数据集,这并不能很好地扩展,但这应该可行:
datasets = [dataset1, dataset2, dataset3] # can be more than 3 of course
concat_dataset = datasets[0]
for dset in datasets[1:]:
concat_dataset = concat_dataset.concatenate(dset)
import tensorflow as tf
dataset1 = tf.data.Dataset.range(1, 4)
dataset2 = tf.data.Dataset.range(4, 8)
dataset3 = tf.data.Dataset.range(8, 12)
def func(*datasets):
out = {}
for dataset in datasets:
for key in dataset:
if key in out:
_value = out[key]
out[key] = tf.concat([_value, dataset[key]], axis=-1)
else:
out[key] = dataset[key]
return out
tf.data.Dataset.zip((dataset1, dataset2, dataset3)).map(func)
我想在 TensorFlow 中连接 3 个或更多数据集。 要连接 2 个数据集,
dataset1 = tf.data.Dataset.range(1, 4)
dataset2 = tf.data.Dataset.range(4, 8)
dataset1.concatenate(dataset2)
但是,这样的话,3个或更多的数据集是不能拼接的。 所以我想这样做
dataset1 = tf.data.Dataset.range(1, 4)
dataset2 = tf.data.Dataset.range(4, 8)
dataset3 = tf.data.Dataset.range(8, 12)
concatenate(dataset1,dataset2,dataset3)
有什么办法吗?
在这个具体的例子中你可以这样做
concat_dataset = dataset1.concatenate(dataset2).concatenate(dataset3)
请注意,您必须将 concatenate
的结果分配给一个新变量!它不能就地运行。
当然,如果你有很多数据集,这并不能很好地扩展,但这应该可行:
datasets = [dataset1, dataset2, dataset3] # can be more than 3 of course
concat_dataset = datasets[0]
for dset in datasets[1:]:
concat_dataset = concat_dataset.concatenate(dset)
import tensorflow as tf
dataset1 = tf.data.Dataset.range(1, 4)
dataset2 = tf.data.Dataset.range(4, 8)
dataset3 = tf.data.Dataset.range(8, 12)
def func(*datasets):
out = {}
for dataset in datasets:
for key in dataset:
if key in out:
_value = out[key]
out[key] = tf.concat([_value, dataset[key]], axis=-1)
else:
out[key] = dataset[key]
return out
tf.data.Dataset.zip((dataset1, dataset2, dataset3)).map(func)