如何使用自定义生成器生成 tf.data.Dataset.from_generator 批次

How to make tf.data.Dataset.from_generator yield batches with a custom generator

我想用tf.dataAPI。我预期的工作流程如下所示:

我使用 tf.data.from_generator 函数创建了一个迭代器。后来我做了一个可初始化的迭代器。

我的代码看起来像这样:

def custom_gen():
   img = np.random.normal((width, height, channels, frames))
   yield(img, img) # I train an autoencoder, so the x == y`

dataset = tf.data.Dataset.batch(batch_size).from_generator(custom_generator)
iter = dataset.make_initializable_iterator()

sess = tf.Session()
sess.run(iter.get_next())

我希望 iter.get_next() 为我提供一个具有批量大小的 5D 张量。但是,我什至尝试在我自己的 custom_generator 中产生批量大小,但它不起作用。当我想用输入形状 (batch_size, width, height, channels, frames).

的占位符初始化数据集时,我遇到了一个错误

该示例中的 Dataset 构造过程格式错误。它应该按照这个顺序完成,正如 Importing Data:

上的官方指南所建立的那样
  1. 应调用基本数据集创建函数或静态方法来建立数据的原始(例如静态方法from_slice_tensorsfrom_generatorlist_files, ...).
  2. 此时,转换可以通过链接适配器方法(例如batch)应用。

因此:

dataset = tf.data.Dataset.from_generator(custom_generator).batch(batch_size)