Tensorflow 数据集交错 from_generator 抛出 InvalidArgumentError
Tensorflow dataset interleave from_generator throws InvalidArgumentError
我有一个发电机,我正在尝试交错:
def hello(i):
for j in tf.range(i):
yield j
ds = tf.data.Dataset.range(10).interleave(
lambda ind: tf.data.Dataset.from_generator(lambda: hello(ind), output_types=(tf.int32,)))
for x in ds.take(1):
print(x)
但是我得到这个错误:
TypeError: An op outside of the function building code is being passed
a "Graph" tensor. It is possible to have Graph tensors
leak out of the function building context by including a
tf.init_scope in your function building code.
For example, the following function will fail:
@tf.function
def has_init_scope():
my_constant = tf.constant(1.)
with tf.init_scope():
added = my_constant * 2
The graph tensor has name: args_0:0
[[{{node PyFunc}}]]
张量流版本:2.3.2
问题在于您构建生成器函数的方式。您应该使用 args
关键字参数来指定传递给生成器函数的参数,而不是使用 lambda
。
ds = tf.data.Dataset.range(10).interleave(
lambda ind: tf.data.Dataset.from_generator(
hello, args=(ind,), output_types=tf.int32
)
)
对于 TF2.4,请注意您应该使用 output_signature
而不是 output_types
,因为后者已被弃用。 (在那种情况下 output_signature=tf.TensorSpec(shape=(), dtype=tf.int32,)
)。
我有一个发电机,我正在尝试交错:
def hello(i):
for j in tf.range(i):
yield j
ds = tf.data.Dataset.range(10).interleave(
lambda ind: tf.data.Dataset.from_generator(lambda: hello(ind), output_types=(tf.int32,)))
for x in ds.take(1):
print(x)
但是我得到这个错误:
TypeError: An op outside of the function building code is being passed
a "Graph" tensor. It is possible to have Graph tensors
leak out of the function building context by including a
tf.init_scope in your function building code.
For example, the following function will fail:
@tf.function
def has_init_scope():
my_constant = tf.constant(1.)
with tf.init_scope():
added = my_constant * 2
The graph tensor has name: args_0:0
[[{{node PyFunc}}]]
张量流版本:2.3.2
问题在于您构建生成器函数的方式。您应该使用 args
关键字参数来指定传递给生成器函数的参数,而不是使用 lambda
。
ds = tf.data.Dataset.range(10).interleave(
lambda ind: tf.data.Dataset.from_generator(
hello, args=(ind,), output_types=tf.int32
)
)
对于 TF2.4,请注意您应该使用 output_signature
而不是 output_types
,因为后者已被弃用。 (在那种情况下 output_signature=tf.TensorSpec(shape=(), dtype=tf.int32,)
)。