Slicing in tf.data causes "iterating over `tf.Tensor` is not allowed in Graph execution" error
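I have a dataset created as shown below, where image_train_path is a list of image file paths, e.g. [b'/content/drive/My Drive/data/folder1/im1.png', b'/content/drive/My Drive/data/folder2/im6.png',...]. I need to extract the folder path, e.g. '/content/drive/My Drive/data/folder1', and then perform some other operations. I tried to do this with a preprocessData function, as follows.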
dataset = tf.data.Dataset.from_tensor_slices(image_train_path)
dataset = dataset.map(preprocessData, num_parallel_calls=16)
where preprocessData is:
def preprocessData(images_path):
    folder=tf.strings.split(images_path,'/')
    foldername=tf.strings.join(tf.slice(folder,(0,),(6,)),'/')
    ....
However, the slicing line causes the following error:
OperatorNotAllowedInGraphError: in user code:
<ipython-input-21-2a9827982c16>:4 preprocessData *
foldername=tf.strings.join(tf.slice(folder,(0,),(6,)),'/')
/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/dispatch.py:210 wrapper **
result = dispatch(wrapper, args, kwargs)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/dispatch.py:122 dispatch
result = dispatcher.handle(args, kwargs)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/ragged/ragged_dispatch.py:130 handle
for elt in x:
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py:524 __iter__
self._disallow_iteration()
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py:520 _disallow_iteration
self._disallow_in_graph_mode("iterating over `tf.Tensor`")
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py:500 _disallow_in_graph_mode
" this function with @tf.function.".format(task))
OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.
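I tried this on both TF 2.4 and tf-nightly. I also tried decorating the function with @tf.function and using tf.data.experimental.enable_debug_mode(). The same error is raised every time.
I don't quite understand which part causes the 'iteration', but I guess the problem is the slicing. Is there an alternative way to do this?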
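The function tf.strings.join expects a list of tensors, as stated in the documentation:
Args
inputs: A list of tf.Tensor objects of same size and tf.string dtype.
tf.slice returns a single Tensor, which the join function then tries to iterate over, and that raises the error.
You can feed the function a list instead, using a simple list comprehension: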
def preprocessData(images_path):
    folder=tf.strings.split(images_path,'/')
    foldername=tf.strings.join([folder[i] for i in range(6)],"/")
    return foldername
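For completeness, here is a small end-to-end sketch of the fix above, plus one possible alternative that keeps the slice. It assumes every path has at least six '/'-separated components, as in the question. The function names folder_via_list and folder_via_reduce_join and the two sample paths are made up for illustration. The alternative uses tf.strings.reduce_join, which takes a single tensor rather than a list, so a plain slice can be passed to it directly.

```python
import tensorflow as tf

def folder_via_list(images_path):
    # For a scalar string, tf.strings.split returns a 1-D string tensor.
    parts = tf.strings.split(images_path, "/")
    # tf.strings.join wants a Python list of tensors, so build one explicitly.
    return tf.strings.join([parts[i] for i in range(6)], separator="/")

def folder_via_reduce_join(images_path):
    parts = tf.strings.split(images_path, "/")
    # reduce_join accepts one tensor and joins its elements along an axis.
    return tf.strings.reduce_join(parts[:6], axis=0, separator="/")

# Hypothetical sample paths in the shape described in the question.
image_train_path = [
    "/content/drive/My Drive/data/folder1/im1.png",
    "/content/drive/My Drive/data/folder2/im6.png",
]

dataset = tf.data.Dataset.from_tensor_slices(image_train_path)
folders_a = [f.numpy() for f in dataset.map(folder_via_list)]
folders_b = [f.numpy() for f in dataset.map(folder_via_reduce_join)]
print(folders_a)
print(folders_b)
```

Note that the leading '/' makes the first split component an empty string, which is why six components are needed to reach folder1. The two variants differ on short paths: indexing parts[i] fails at run time if a path has fewer than six components, while the slice parts[:6] simply joins whatever is there.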