TF2 - tf.function 和 class 变量破坏

TF2 - tf.function and class variables breaking

我正在使用 tensorflow 2.3,如果我将值存储在 class 而不是返回它

例如

import tensorflow as tf

class Test:
    def __init__(self):
        self.count = tf.convert_to_tensor(0)
        
    @tf.function
    def incr(self):
        self.count += 1
        return self.count
        
t = Test()
count = t.incr()
count == t.count

产生以下错误

TypeError: An op outside of the function building code is being passed
a "Graph" tensor. It is possible to have Graph tensors
leak out of the function building context by including a
tf.init_scope in your function building code.
For example, the following function will fail:
  @tf.function
  def has_init_scope():
    my_constant = tf.constant(1.)
    with tf.init_scope():
      added = my_constant * 2
The graph tensor has name: add:0

当我检查 count 与 t.count 的值时,我看到了

如何解决此问题,以便 t.count 存储相同的值?

您可以使用 tf.Variable .

class Test:
        def __init__(self):
           # self.count = tf.convert_to_tensor(0)
             self.count = tf.Variable(0)
        @tf.function
        def incr(self):
            self.count.assign_add(1)
            return self.count
     
    t = Test()
count = t.incr()
print(count) #tf.Tensor(1, shape=(), dtype=int32)
count = t.count
print(count)#<tf.Variable 'Variable:0' shape=() dtype=int32, numpy=1>

张量是不可变的。在图内部使用时(tf.operation和tf.tensor的结构),张量不会急切执行;所以你不能得到它的值,并且在图中被称为添加操作的一部分:t.count: .

count = tf.constant(0)
print(count)# count is eager_tensor : tf.Tensor(0, shape=(), dtype=int32)
@tf.function
def incr():
        global count # we use count tensor inside graph contect
        count+=1
        return count
        
c = incr()
print(count)# now count is graph_tensor :Tensor("add:0", shape=(),dtype=int32)