在 `predict_step` 上的 Keras 模型中禁用了 Eager Execution

Eager execution disabled in Keras model on `predict_step`

为什么 tensorflow 在 tf.keras.Modelpredict_step 函数中禁用急切执行?也许我弄错了什么,但这里有一个例子:

from __future__ import annotations
from functools import wraps
import tensorflow as tf

def print_execution(func):
    @wraps(func)
    def wrapper(self: SimpleModel, data):
        print(tf.executing_eagerly())  # Prints False
        return func(self, data)
    return wrapper

class SimpleModel(tf.keras.Model):
    def __init__(self):
        super().__init__()

    def call(self, inputs, training=None, mask=None):
        return inputs

    @print_execution
    def predict_step(self, data):
        return super().predict_step(data)

if __name__ == "__main__":
    x = tf.random.uniform((2, 2))
    print(tf.executing_eagerly())  # Prints True
    model = SimpleModel()
    pred = model.predict(x)

这是预期的行为吗?有没有办法在急切模式下强制 predict_step 到 运行?

如果你想运行 eager模式下的predict_step函数,你可以按如下方式进行。请注意,它会将所有内容设置为急切模式。

import tensorflow as tf
tf.config.run_functions_eagerly(True)

通常 tf.function 处于 Graph 模式。使用上面的语句,它们也可以设置为 Eager 模式, src.

根据您的评论,AFAIK,如果您在编译模型时设置 run_eagerly 应该没有任何区别。这里是来自官方的说法,src - model.compile.

run_eagerly: Bool. Defaults to False. If True, this Model's logic will not be wrapped in a tf. function. Recommended to leave this as None unless your Model cannot be run inside a tf. function.


关于您的第一个查询,为什么 TensorFlowtf.keras.Modelpredict_step 函数中禁用急切执行?

主要原因之一是提供模型的最佳性能。不仅是 predict_step,还有 train_steptest_step。基本上 tf. keras 模型被​​编译成静态图。为了使 运行 它们处于急切模式,需要完成上述方法。但请注意,在这种情况下使用急切模式可能会减慢您的训练速度。为了集体利益,tf. keras 模型以图形模式编译。

你也可以在编译的时候设置run_eagerly = True,这样也会得到预期的结果。

model = SimpleModel()
model.compile(run_eagerly = True)
pred = model.predict(x)

结果:

True
True