有没有办法在 tf.data 管道中使用 tf.keras.model.predict?
Is there a way to use tf.keras.model.predict within a tf.data pipeline?
我有一个经过训练的模型,我想在 tf.data
管道中为第二个模型使用它。当我尝试执行此操作时,我收到一条 ValueError: Unknown graph. Aborting.
我不知道如何处理此错误消息。
我的代码看起来像这样:
def load_data(..., model):
# code to load an image
files = tf.data.Dataset.from_tensor_slices(file_list)
images = files.map(load_image_from_file)
def pass_image_through_model(img):
return model.predict(img, steps=1)
dataset = images.map(pass_image_through_model)
return dataset
这有什么问题?我得到的错误是:
/home/.../code/dataloader.py:236 pass_image_through_model *
return model.predict(img, steps=1)
/home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py:1013 predict
use_multiprocessing=use_multiprocessing)
/home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_arrays.py:728 predict
callbacks=callbacks)
/home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_arrays.py:189 model_iteration
f = _make_execution_function(model, mode)
/home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_arrays.py:571 _make_execution_function
return model._make_execution_function(mode)
/home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py:2131 _make_execution_function
self._make_predict_function()
/home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py:2121 _make_predict_function
**kwargs)
/home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/backend.py:3760 function
return EagerExecutionFunction(inputs, outputs, updates=updates, name=name)
/home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/backend.py:3644 __init__
raise ValueError('Unknown graph. Aborting.')
ValueError: Unknown graph. Aborting.
如果这是您第一次处理 tf.data.Dataset()
对象,您得到的错误可能是无声的。
tf.data.Dataset()
中的所有操作实际上都是以图形模式执行的,您不能使用tf.*
中预定义的函数之外的任何函数。
将任意 Python 代码与 tf.data.Dataset()
混合的唯一方法是使用 tf.py_function()
,否则将抛出错误。
请记住,将 Python 代码与优化的 tf.data.Dataset()
代码混合使用会导致时间性能下降。
唯一的测试方法是检索数据集,使用 as_numpy_iterator()
获取数据并使用模型进行预测,因此在映射过程之外。
解决此问题的最简单方法之一是将输入直接传递给模型,而不是使用 model.predit
方法。这样做的原因是model.predict
returns一个numpy.ndarray
。这会导致错误,因为 tf.data
使用图形执行,这意味着最好在该图形中输入和输出张量的任何操作。
下面是一个快速的工作示例。
import tensorflow as tf
# Create example model
inputs = tf.keras.Input((1,))
out = tf.keras.layers.Dense(1)(inputs)
model = tf.keras.Model(inputs, out)
def map_fn(row):
return model(row)
# Create some input data
a = tf.constant([1, 2])
# Create the dataset
ds = tf.data.Dataset.from_tensor_slices(a).batch(1)
model_mapped_ds = ds.map(lambda x: map_fn(x))
for el in model_mapped_ds:
print(el)
最后,下面是您的使用情况。
def pass_image_through_model(img):
return model(img) # this returns a tensor
@tf.function
def load_data(..., model):
# code to load an image
files = tf.data.Dataset.from_tensor_slices(file_list).batch(1) # Don't forget batch size!
images = files.map(load_image_from_file)
dataset = images.map(pass_image_through_model)
return dataset
我有一个经过训练的模型,我想在 tf.data
管道中为第二个模型使用它。当我尝试执行此操作时,我收到一条 ValueError: Unknown graph. Aborting.
我不知道如何处理此错误消息。
我的代码看起来像这样:
def load_data(..., model):
# code to load an image
files = tf.data.Dataset.from_tensor_slices(file_list)
images = files.map(load_image_from_file)
def pass_image_through_model(img):
return model.predict(img, steps=1)
dataset = images.map(pass_image_through_model)
return dataset
这有什么问题?我得到的错误是:
/home/.../code/dataloader.py:236 pass_image_through_model *
return model.predict(img, steps=1)
/home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py:1013 predict
use_multiprocessing=use_multiprocessing)
/home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_arrays.py:728 predict
callbacks=callbacks)
/home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_arrays.py:189 model_iteration
f = _make_execution_function(model, mode)
/home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_arrays.py:571 _make_execution_function
return model._make_execution_function(mode)
/home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py:2131 _make_execution_function
self._make_predict_function()
/home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py:2121 _make_predict_function
**kwargs)
/home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/backend.py:3760 function
return EagerExecutionFunction(inputs, outputs, updates=updates, name=name)
/home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/backend.py:3644 __init__
raise ValueError('Unknown graph. Aborting.')
ValueError: Unknown graph. Aborting.
如果这是您第一次处理 tf.data.Dataset()
对象,您得到的错误可能是无声的。
tf.data.Dataset()
中的所有操作实际上都是以图形模式执行的,您不能使用tf.*
中预定义的函数之外的任何函数。
将任意 Python 代码与 tf.data.Dataset()
混合的唯一方法是使用 tf.py_function()
,否则将抛出错误。
请记住,将 Python 代码与优化的 tf.data.Dataset()
代码混合使用会导致时间性能下降。
唯一的测试方法是检索数据集,使用 as_numpy_iterator()
获取数据并使用模型进行预测,因此在映射过程之外。
解决此问题的最简单方法之一是将输入直接传递给模型,而不是使用 model.predit
方法。这样做的原因是model.predict
returns一个numpy.ndarray
。这会导致错误,因为 tf.data
使用图形执行,这意味着最好在该图形中输入和输出张量的任何操作。
下面是一个快速的工作示例。
import tensorflow as tf
# Create example model
inputs = tf.keras.Input((1,))
out = tf.keras.layers.Dense(1)(inputs)
model = tf.keras.Model(inputs, out)
def map_fn(row):
return model(row)
# Create some input data
a = tf.constant([1, 2])
# Create the dataset
ds = tf.data.Dataset.from_tensor_slices(a).batch(1)
model_mapped_ds = ds.map(lambda x: map_fn(x))
for el in model_mapped_ds:
print(el)
最后,下面是您的使用情况。
def pass_image_through_model(img):
return model(img) # this returns a tensor
@tf.function
def load_data(..., model):
# code to load an image
files = tf.data.Dataset.from_tensor_slices(file_list).batch(1) # Don't forget batch size!
images = files.map(load_image_from_file)
dataset = images.map(pass_image_through_model)
return dataset