如何从仅输出数组列表的生成器开发 tf.data 对象?
How to develop a tf.data object from a generator that only outputs a list of arrays?
我正在尝试开发一个生成数组列表的 tf.data 对象,但出现数据不匹配错误。这是我的尝试
def labelGen():
yield tf.constant([1, 0], dtype=tf.int64), tf.constant([1, 0], dtype=tf.int64), tf.constant([0, 1], dtype=tf.int64), tf.constant([0, 1], dtype=tf.int64)
Labeldataset = tf.data.Dataset.from_generator(
labelGen, (tf.int64, tf.int64, tf.int64, tf.int64, tf.int64), ([], [], [], [], []) )
list(Labeldataset.take(1))
这是我得到的错误
InvalidArgumentError: TypeError: generator
yielded an element that did not match the expected structure. The expected structure was (tf.int64, tf.int64, tf.int64, tf.int64, tf.int64), but the yielded element was (, , , ).
Traceback (most recent call last):
首先,.from_generator代码中的项目数量不匹配。
其次,调用生成器应该不带()。
这是在 TF 2.1 中测试的工作代码。
def labelGen():
yield tf.constant([1, 0], dtype=tf.int64), tf.constant([1, 0], dtype=tf.int64), tf.constant([0, 1], dtype=tf.int64), tf.constant([0, 1], dtype=tf.int64)
Labeldataset = tf.data.Dataset.from_generator(
labelGen, # without ()
(tf.int64, tf.int64, tf.int64, tf.int64), # should match number of items
(tf.TensorShape([2]), tf.TensorShape([2]), tf.TensorShape([2]), tf.TensorShape([2]))) # use tf.TensorShape
list(Labeldataset.take(1))
结果:
[(<tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 0])>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 0])>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([0, 1])>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([0, 1])>)]
我正在尝试开发一个生成数组列表的 tf.data 对象,但出现数据不匹配错误。这是我的尝试
def labelGen():
yield tf.constant([1, 0], dtype=tf.int64), tf.constant([1, 0], dtype=tf.int64), tf.constant([0, 1], dtype=tf.int64), tf.constant([0, 1], dtype=tf.int64)
Labeldataset = tf.data.Dataset.from_generator(
labelGen, (tf.int64, tf.int64, tf.int64, tf.int64, tf.int64), ([], [], [], [], []) )
list(Labeldataset.take(1))
这是我得到的错误
InvalidArgumentError: TypeError:
generator
yielded an element that did not match the expected structure. The expected structure was (tf.int64, tf.int64, tf.int64, tf.int64, tf.int64), but the yielded element was (, , , ). Traceback (most recent call last):
首先,.from_generator代码中的项目数量不匹配。 其次,调用生成器应该不带()。 这是在 TF 2.1 中测试的工作代码。
def labelGen():
yield tf.constant([1, 0], dtype=tf.int64), tf.constant([1, 0], dtype=tf.int64), tf.constant([0, 1], dtype=tf.int64), tf.constant([0, 1], dtype=tf.int64)
Labeldataset = tf.data.Dataset.from_generator(
labelGen, # without ()
(tf.int64, tf.int64, tf.int64, tf.int64), # should match number of items
(tf.TensorShape([2]), tf.TensorShape([2]), tf.TensorShape([2]), tf.TensorShape([2]))) # use tf.TensorShape
list(Labeldataset.take(1))
结果:
[(<tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 0])>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 0])>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([0, 1])>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([0, 1])>)]