如何从仅输出数组列表的生成器开发 tf.data 对象?

How to develop a tf.data object from a generator that only outputs a list of arrays?

我正在尝试开发一个生成数组列表的 tf.data 对象,但出现数据不匹配错误。这是我的尝试

def labelGen():
    yield tf.constant([1, 0], dtype=tf.int64), tf.constant([1, 0], dtype=tf.int64), tf.constant([0, 1], dtype=tf.int64), tf.constant([0, 1], dtype=tf.int64)

Labeldataset = tf.data.Dataset.from_generator(
     labelGen, (tf.int64, tf.int64, tf.int64, tf.int64, tf.int64), ([], [], [], [], []) )

list(Labeldataset.take(1))

这是我得到的错误

InvalidArgumentError: TypeError: generator yielded an element that did not match the expected structure. The expected structure was (tf.int64, tf.int64, tf.int64, tf.int64, tf.int64), but the yielded element was (, , , ). Traceback (most recent call last):

首先,.from_generator代码中的项目数量不匹配。 其次,调用生成器应该不带()。 这是在 TF 2.1 中测试的工作代码。

def labelGen():
    yield tf.constant([1, 0], dtype=tf.int64), tf.constant([1, 0], dtype=tf.int64), tf.constant([0, 1], dtype=tf.int64), tf.constant([0, 1], dtype=tf.int64)

Labeldataset = tf.data.Dataset.from_generator(
    labelGen, # without ()
    (tf.int64, tf.int64, tf.int64, tf.int64), # should match number of items
    (tf.TensorShape([2]), tf.TensorShape([2]), tf.TensorShape([2]), tf.TensorShape([2]))) # use tf.TensorShape

list(Labeldataset.take(1))

结果:

[(<tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 0])>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 0])>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([0, 1])>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([0, 1])>)]