rewrite function not to trigger ValueError: tf.function-decorated function tried to create variables on non-first call

rewrite function not to trigger ValueError: tf.function-decorated function tried to create variables on non-first call

我在使用@tf.function时遇到问题。

我不确定如何在不创建变量的情况下编写此函数,我认为这将是最好的解决方案,但我非常不确定该怎么做。

我需要创建一个张量对象,在几个特定的​​位置用 0 和 1 填充,我不知道如果不创建 numpy 数组或 tf.Variable.

我正在尝试编写一个函数,以特定方式将我网络的输出张量映射到整数值,有点像 softmax 函数。此函数是从另一个具有 @tf.function 标记的函数调用的,因此不能选择删除此小函数的标记。

@tf.function
def convert_action(self, action_logits_t: tf.Tensor) -> tf.Tensor:
  self.res = tf.Variable(tf.zeros(num_actions, tf.int32))
  if action_logits_t[0][0] > 0.2:
    return self.res[0].assign(1)
  else:
    ccy1 = tf.math.argmax(action_logits_t[0][1:NUM_CCYS + 1])
    self.res = self.res[ccy1 + 1].assign(1) 
    ccy2 = tf.math.argmax(action_logits_t[0][NUM_CCYS + 1:2 * NUM_CCYS + 1])
    self.res = self.res[ccy2 + 1 + NUM_CCYS].assign(1)
    bucket = tf.math.argmax(action_logits_t[0][2 * NUM_CCYS + 1:])
    return self.res[bucket + 1 + 2 * NUM_CCYS].assign(1)

不能使用左侧索引和赋值有点痛苦,但有一些操作可以做你想做的事:tf.scatter_nd

def convert_action(action_logits_t):
    if action_logits_t[0][0] > 0.2:
        return tf.one_hot([0], depth=num_actions)
    else:
        indexes = [
            tf.math.argmax(action_logits_t[0][1 : NUM_CCYS + 1]) + 1,
            tf.math.argmax(action_logits_t[0][NUM_CCYS + 1 : 2 * NUM_CCYS + 1])
            + 1
            + NUM_CCYS,
            tf.math.argmax(action_logits_t[0][2 * NUM_CCYS + 1 :]) + 1 + 2 * NUM_CCYS,
        ]
        return tf.scatter_nd(
            indices=tf.expand_dims(indexes, axis=-1),
            updates=tf.ones_like(indexes),
            shape=[num_actions],
        )

或者另一个,有点创意 tf.one_hot and tf.reduce_sum:

def convert_action(action_logits_t):
    if action_logits_t[0][0] > 0.2:
        return tf.one_hot([0], depth=num_actions)
    else:
        indexes = [
            tf.math.argmax(action_logits_t[0][1:NUM_CCYS + 1]) + 1,
            tf.math.argmax(action_logits_t[0][NUM_CCYS + 1:2 * NUM_CCYS + 1]) + 1 + NUM_CCYS,
            tf.math.argmax(action_logits_t[0][2 * NUM_CCYS + 1:]) + 1 + 2 * NUM_CCYS
        ]
        return tf.reduce_sum(
            tf.one_hot(indexes, depth=num_actions), axis=0
        )