有什么方法可以连接 3 个或更多 tf.data.Dataset

Is there any ways to concatenate 3 or more tf.data.Dataset

我想在 TensorFlow 中连接 3 个或更多数据集。 要连接 2 个数据集,

dataset1 = tf.data.Dataset.range(1, 4)
dataset2 = tf.data.Dataset.range(4, 8)
dataset1.concatenate(dataset2)

但是,这样的话,3个或更多的数据集是不能拼接的。 所以我想这样做

dataset1 = tf.data.Dataset.range(1, 4)
dataset2 = tf.data.Dataset.range(4, 8)
dataset3 = tf.data.Dataset.range(8, 12)
concatenate(dataset1,dataset2,dataset3)

有什么办法吗?

在这个具体的例子中你可以这样做

concat_dataset = dataset1.concatenate(dataset2).concatenate(dataset3)

请注意,您必须将 concatenate 的结果分配给一个新变量!它不能就地运行。

当然,如果你有很多数据集,这并不能很好地扩展,但这应该可行:

datasets = [dataset1, dataset2, dataset3]  # can be more than 3 of course

concat_dataset = datasets[0]
for dset in datasets[1:]:
    concat_dataset = concat_dataset.concatenate(dset)
import tensorflow as tf

dataset1 = tf.data.Dataset.range(1, 4)

dataset2 = tf.data.Dataset.range(4, 8)

dataset3 = tf.data.Dataset.range(8, 12)

 
def func(*datasets):

    out = {}

    for dataset in datasets:

        for key in dataset:

            if key in out:

                _value = out[key]

                out[key] = tf.concat([_value, dataset[key]], axis=-1)

            else:

                out[key] = dataset[key]

    return out


tf.data.Dataset.zip((dataset1, dataset2, dataset3)).map(func)