Tensorflow model.save() 不保存模型的推理函数
Tensorflow model.save() doesn't save the model's inference function
我有一个 class,它扩展了 keras.Model
并实现了一个由多个标准层和自定义层组成的模型。
我没有使用 model.fit()
而是使用 for
循环来迭代数据和 运行 以下训练步骤函数
@tf.function
def train_step(batch):
with tf.GradientTape() as tape:
inputs = batch[0] + [batch[1][-1]]
predictions, _ = model(inputs, training=True)
loss = log_likelihood(batch[1], predictions, batch[2])
regularisation_loss = model.output_layers.losses
gradients = tape.gradient([regularisation_loss, loss], model.trainable_variables)
optimiser.apply_gradients(zip(gradients, model.trainable_variables))
为了节省,我只是调用model.save(model_path)
。
一切顺利,直到我尝试重新加载保存的模型。加载调用抛出(只是底线)
File "..\.conda\envs\tf.2\lib\site-packages\tensorflow\python\saved_model\function_deserialization.py", line 265, in recreate_function
concrete_function_objects.append(concrete_functions[concrete_function_name])
KeyError: '__inference_model_layer_call_fn_37936'
我假设,这指的是我的模型的主要 call
函数,它只是出于某种原因未被序列化和保存。我查看了 concrete_functions
,除了这个函数外几乎所有东西都在里面。我在使用和不使用 tf.function
装饰器的情况下进行了尝试,没有发现任何区别。
我现在有点迷路了,真的很感激不仅是一个答案,而且是一个很好的调试方向。
谢谢
所以在讨论之后
https://github.com/tensorflow/tensorflow/issues/42004
这让我朝着正确的方向前进,我做了以下似乎解决了问题的更改
- 除了“主模型包装器”之外的所有模型实现都是对 Layer 对象的更改。
- 所有自定义实现(图层和模型)都被赋予了一个 get_config 函数。
我不确定哪一个成功了,但我相信第一点。
我有一个 class,它扩展了 keras.Model
并实现了一个由多个标准层和自定义层组成的模型。
我没有使用 model.fit()
而是使用 for
循环来迭代数据和 运行 以下训练步骤函数
@tf.function
def train_step(batch):
with tf.GradientTape() as tape:
inputs = batch[0] + [batch[1][-1]]
predictions, _ = model(inputs, training=True)
loss = log_likelihood(batch[1], predictions, batch[2])
regularisation_loss = model.output_layers.losses
gradients = tape.gradient([regularisation_loss, loss], model.trainable_variables)
optimiser.apply_gradients(zip(gradients, model.trainable_variables))
为了节省,我只是调用model.save(model_path)
。
一切顺利,直到我尝试重新加载保存的模型。加载调用抛出(只是底线)
File "..\.conda\envs\tf.2\lib\site-packages\tensorflow\python\saved_model\function_deserialization.py", line 265, in recreate_function
concrete_function_objects.append(concrete_functions[concrete_function_name])
KeyError: '__inference_model_layer_call_fn_37936'
我假设,这指的是我的模型的主要 call
函数,它只是出于某种原因未被序列化和保存。我查看了 concrete_functions
,除了这个函数外几乎所有东西都在里面。我在使用和不使用 tf.function
装饰器的情况下进行了尝试,没有发现任何区别。
我现在有点迷路了,真的很感激不仅是一个答案,而且是一个很好的调试方向。
谢谢
所以在讨论之后 https://github.com/tensorflow/tensorflow/issues/42004 这让我朝着正确的方向前进,我做了以下似乎解决了问题的更改
- 除了“主模型包装器”之外的所有模型实现都是对 Layer 对象的更改。
- 所有自定义实现(图层和模型)都被赋予了一个 get_config 函数。
我不确定哪一个成功了,但我相信第一点。