If-else in @tf.function

I want to define a custom LearningRateSchedule, but AutoGraph doesn't seem to be able to convert it. The code below works fine without @tf.function, but raises an error when used with @tf.function:
    import tensorflow as tf
    from tensorflow.keras import layers

    def linear_interpolation(l, r, alpha):
        return l + alpha * (r - l)

    class TFPiecewiseSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
        # This class currently cannot be used in @tf.function because of tf.cond;
        # see the following link for details
        def __init__(self, endpoints, end_learning_rate=None, name=None):
            """Piecewise schedule.
            endpoints: [(int, float)]
                list of pairs `(time, value)` meaning that the schedule should output
                `value` when `t == time`. All the values for time must be sorted in
                increasing order. When t is between two times, e.g. `(time_a, value_a)`
                and `(time_b, value_b)`, such that `time_a <= t < time_b`, the output is
                `interpolation(value_a, value_b, alpha)`, where alpha is the fraction of
                time passed between `time_a` and `time_b` at time `t`.
            end_learning_rate: float
                value returned when `t` lies past all the intervals specified in
                `endpoints`. If None, the value of the last endpoint is used.
            """
            super().__init__()
            idxes = [e[0] for e in endpoints]
            assert idxes == sorted(idxes)
            self.end_learning_rate = end_learning_rate or endpoints[-1][1]
            self.endpoints = endpoints
            self.name = name

        def __call__(self, step):
            if step < self.endpoints[0][0]:
                return self.endpoints[0][1]
            else:
                for (l_t, l), (r_t, r) in zip(self.endpoints[:-1], self.endpoints[1:]):
                    if l_t <= step < r_t:
                        alpha = float(step - l_t) / (r_t - l_t)
                        return linear_interpolation(l, r, alpha)

            # step does not belong to any of the pieces, so fall back to the end value
            assert self.end_learning_rate is not None
            return self.end_learning_rate

        def get_config(self):
            return dict(
                endpoints=self.endpoints,
                end_learning_rate=self.end_learning_rate,
                name=self.name,
            )

    lr = TFPiecewiseSchedule([[10, 1e-3], [20, 1e-4]])

    @tf.function
    def f(x):
        l = layers.Dense(10)
        with tf.GradientTape() as tape:
            y = l(x)
            loss = tf.reduce_mean(y**2)

        grads = tape.gradient(loss, l.trainable_variables)
        opt = tf.keras.optimizers.Adam(lr)
        opt.apply_gradients(zip(grads, l.trainable_variables))

    f(tf.random.normal((2, 3)))

The error message says:

    :10 f  *
        opt.apply_gradients(zip(grads, l.trainable_variables))
    /Users/aptx4869/anaconda3/envs/drl/lib/python3.7/site-packages/tensorflow_core/python/keras/optimizer_v2/optimizer_v2.py:437 apply_gradients
        apply_state = self._prepare(var_list)
    /Users/aptx4869/anaconda3/envs/drl/lib/python3.7/site-packages/tensorflow_core/python/keras/optimizer_v2/optimizer_v2.py:614 _prepare
        self._prepare_local(var_device, var_dtype, apply_state)
    /Users/aptx4869/anaconda3/envs/drl/lib/python3.7/site-packages/tensorflow_core/python/keras/optimizer_v2/adam.py:154 _prepare_local
        super(Adam, self)._prepare_local(var_device, var_dtype, apply_state)
    /Users/aptx4869/anaconda3/envs/drl/lib/python3.7/site-packages/tensorflow_core/python/keras/optimizer_v2/optimizer_v2.py:620 _prepare_local
        lr_t = array_ops.identity(self._decayed_lr(var_dtype))
    /Users/aptx4869/anaconda3/envs/drl/lib/python3.7/site-packages/tensorflow_core/python/keras/optimizer_v2/optimizer_v2.py:672 _decayed_lr
        lr_t = math_ops.cast(lr_t(local_step), var_dtype)
    :32 __call__
        if step < self.endpoints[0][0]:
    /Users/aptx4869/anaconda3/envs/drl/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py:765 __bool__
        self._disallow_bool_casting()
    /Users/aptx4869/anaconda3/envs/drl/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py:531 _disallow_bool_casting
        "using a tf.Tensor as a Python bool")
    /Users/aptx4869/anaconda3/envs/drl/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py:518 _disallow_when_autograph_enabled
        " decorating it directly with @tf.function.".format(task))

    OperatorNotAllowedInGraphError: using a tf.Tensor as a Python bool is not allowed: AutoGraph did not convert this function. Try decorating it directly with @tf.function.

I thought the error was caused by the `if` statement, so I replaced the body of `__call__` with the code below. But I got almost the same error.

        def compute_lr(step):
            for (l_t, l), (r_t, r) in zip(self.endpoints[:-1], self.endpoints[1:]):
                if l_t <= step < r_t:
                    alpha = float(step - l_t) / (r_t - l_t)
                    return linear_interpolation(l, r, alpha)

            # t does not belong to any of the pieces, so doom.
            assert self.end_learning_rate is not None
            return self.end_learning_rate
        return tf.cond(tf.less(step, self.endpoints[0][0]), lambda: self.endpoints[0][1], lambda: compute_lr(step))

What should I do to make the code work the way I want?

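The error message got garbled by the markdown formatter, but it looks like the `__call__` function itself was not processed by AutoGraph. In the error message, converted functions are marked with an asterisk. This is a bug in the Adam optimizer. Either way, you can decorate `__call__` directly with tf.function and it will be picked up: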

    @tf.function
    def __call__(self, step):
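To see that fix in isolation, here is a minimal sketch of a schedule whose `__call__` is decorated with `@tf.function`, so AutoGraph converts its tensor-dependent `if` into `tf.cond`. `WarmupSchedule` and its parameters are hypothetical (not part of the question's code), and the sketch assumes TF 2.x:

```python
import tensorflow as tf

class WarmupSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Hypothetical two-phase schedule (linear warm-up, then constant),
    used only to show a tensor-dependent `if` in a decorated __call__."""

    def __init__(self, warmup_steps, learning_rate):
        super().__init__()
        self.warmup_steps = float(warmup_steps)
        self.learning_rate = float(learning_rate)

    @tf.function  # makes AutoGraph convert the body, including the `if`
    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        # `step` is a tensor, so AutoGraph turns this `if` into tf.cond;
        # both branches produce a float32 scalar
        if step < self.warmup_steps:
            lr = self.learning_rate * step / self.warmup_steps
        else:
            lr = tf.constant(self.learning_rate, dtype=tf.float32)
        return lr

    def get_config(self):
        return {"warmup_steps": self.warmup_steps,
                "learning_rate": self.learning_rate}
```

An instance can then be passed to `tf.keras.optimizers.Adam(...)` like any other schedule.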

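That said, AutoGraph doesn't like a few things in your code: `zip`, returning from inside a loop, and chained inequalities. It's safer to stick to basic constructs where possible. Sadly, the errors you still get are somewhat confusing. Rewriting it like this should work: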

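    @tf.function
    def __call__(self, step):
        # Work in float32 so `step` can be compared with the endpoint tensors
        step = tf.cast(step, tf.float32)
        # Can't return from a loop (or from only one branch of a tensor `if`),
        # so the result is kept in `lr`, defined before any branching
        lr = tf.constant(self.end_learning_rate, dtype=tf.float32)
        if step < self.endpoints[0][0]:
            lr = tf.constant(self.endpoints[0][1], dtype=tf.float32)
        else:
            # Since it needs to break based on the value of a tensor, the loop
            # needs to be a tf.while_loop, so iterate over a tensor, not a list
            for pair in tf.stack([self.endpoints[:-1], self.endpoints[1:]], axis=1):
                left, right = tf.unstack(pair)
                l_t, l = tf.unstack(left)
                r_t, r = tf.unstack(right)
                # Chained inequalities not supported yet
                if l_t <= step and step < r_t:
                    alpha = (step - l_t) / (r_t - l_t)
                    lr = linear_interpolation(l, r, alpha)
                    break

        return lr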
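If you'd rather avoid AutoGraph control flow altogether, another option (a swapped-in technique, not the approach above) is to compute the interpolation with plain tensor ops: `tf.searchsorted` finds the segment containing `step`, and `tf.where` handles steps before the first and at or after the last endpoint. `VectorizedPiecewiseSchedule` is a hypothetical name; the sketch assumes at least two endpoints and aims to mirror `TFPiecewiseSchedule`'s outputs:

```python
import tensorflow as tf

class VectorizedPiecewiseSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Hypothetical loop-free piecewise-linear schedule (needs >= 2 endpoints)."""

    def __init__(self, endpoints, end_learning_rate=None):
        super().__init__()
        times = [e[0] for e in endpoints]
        assert len(endpoints) >= 2 and times == sorted(times)
        self.endpoints = endpoints
        self.end_learning_rate = end_learning_rate or endpoints[-1][1]

    def __call__(self, step):
        times = tf.constant([e[0] for e in self.endpoints], dtype=tf.float32)
        values = tf.constant([e[1] for e in self.endpoints], dtype=tf.float32)
        n = len(self.endpoints)
        step = tf.cast(step, tf.float32)
        # Index of the first endpoint whose time is strictly greater than step,
        # i.e. the right end of the segment that contains step
        idx = tf.searchsorted(times, tf.reshape(step, [1]), side="right")[0]
        # Clamp so that idx - 1 and idx are always valid indices
        idx = tf.clip_by_value(idx, 1, n - 1)
        l_t, r_t = times[idx - 1], times[idx]
        l, r = values[idx - 1], values[idx]
        alpha = (step - l_t) / (r_t - l_t)
        lr = l + alpha * (r - l)
        # Outside the endpoints: first value before the first time,
        # end_learning_rate at or after the last time
        lr = tf.where(step < times[0], values[0], lr)
        end_lr = tf.constant(self.end_learning_rate, dtype=tf.float32)
        return tf.where(step >= times[-1], end_lr, lr)

    def get_config(self):
        return {"endpoints": self.endpoints,
                "end_learning_rate": self.end_learning_rate}
```

Because there is no Python `if` or `for` on tensors, the same code runs eagerly and inside `tf.function` without relying on AutoGraph conversion.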

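One last issue: tf.function doesn't like creating variables inside the function (it only allows it on the first call), so you need to move the creation of the layer and the optimizer outside: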

    lr = TFPiecewiseSchedule([[10, 1e-3], [20, 1e-4]])
    l = layers.Dense(10)
    opt = tf.keras.optimizers.Adam(lr)

    @tf.function
    def f(x):
      ...
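A self-contained sketch of that pattern, with the layer and optimizer created once outside the `tf.function`. It uses the built-in `tf.keras.optimizers.schedules.PolynomialDecay` (with `power=1.0`, i.e. a linear decay) purely as a stand-in so the example runs on its own; you would swap in `TFPiecewiseSchedule` or any other `LearningRateSchedule`. The name `train_step` is hypothetical:

```python
import tensorflow as tf
from tensorflow.keras import layers

# Stand-in schedule (an assumption, not the question's class): linear decay
# from 1e-3 to 1e-4 over 10 steps; any LearningRateSchedule works here
lr = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=1e-3, decay_steps=10,
    end_learning_rate=1e-4, power=1.0)

# Create the objects that own variables once, outside the tf.function
layer = layers.Dense(10)
layer.build((None, 3))  # create the kernel and bias up front
opt = tf.keras.optimizers.Adam(lr)

@tf.function
def train_step(x):
    with tf.GradientTape() as tape:
        y = layer(x)
        loss = tf.reduce_mean(y ** 2)
    grads = tape.gradient(loss, layer.trainable_variables)
    # Adam creates its own slot variables on the first (traced) call only
    opt.apply_gradients(zip(grads, layer.trainable_variables))
    return loss

for _ in range(3):
    loss = train_step(tf.random.normal((2, 3)))
```

Since every call reuses the same `layer` and `opt`, no new variables are created after the first trace.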

Hope this helps!