tf.Variable assign 方法打破了 tf.GradientTape

tf.Variable assign method breaks the tf.GradientTape

当我使用tf.Variable的赋值方法改变一个变量的值时,它制动了tf.Gradient,e。例如,请参阅下面的示例代码:

(注意:我只对 TensorFlow 2 感兴趣。)

x = tf.Variable([[2.0,3.0,4.0], [1.,10.,100.]])
patch = tf.Variable([[0., 1.], [2., 3.]])
with tf.GradientTape() as g:
    g.watch(patch)
    x[:2,:2].assign(patch)
    y = tf.tensordot(x, tf.transpose(x), axes=1)
    o = tf.reduce_mean(y)
do_dpatch = g.gradient(o, patch)

然后它给了我 None 作为 do_dpatch

请注意,如果我执行以下操作,它会完全正常工作:

x = tf.Variable([[2.0,3.0,4.0], [1.,10.,100.]])
patch = tf.Variable([[0., 1.], [2., 3.]])
with tf.GradientTape() as g:
    g.watch(patch)
    x[:2,:2].assign(patch)
    y = tf.tensordot(x, tf.transpose(x), axes=1)
    o = tf.reduce_mean(y)
do_dx = g.gradient(o, x)

并给我:

>>>do_dx 
<tf.Tensor: id=106, shape=(2, 3), dtype=float32, numpy=
array([[ 1.,  2., 52.],
       [ 1.,  2., 52.]], dtype=float32)>

这种行为确实有道理。让我们以你的第一个例子为例

x = tf.Variable([[2.0,3.0,4.0], [1.,10.,100.]])
patch = tf.Variable([[1., 1.], [1., 1.]])
with tf.GradientTape() as g:
    g.watch(patch)
    x[:2,:2].assign(patch)
    y = tf.tensordot(x, tf.transpose(x), axes=1)
dy_dx = g.gradient(y, patch)

您正在计算 dy/d(补丁)。但是你的 y 只依赖 x 而不是 patch。是的,您确实从 patchx 赋值。但是此操作不包含对 patch 变量的引用。它只是复制值。

简而言之,您正在尝试获得一个梯度 w.r.t 它不依赖的东西。所以你会得到None.

让我们看看第二个示例及其工作原理。

x = tf.Variable([[2.0,3.0,4.0], [1.,10.,100.]])
with tf.GradientTape() as g:
    g.watch(x)
    x[:2,:2].assign([[1., 1.], [1., 1.]])
  y = tf.tensordot(x, tf.transpose(x), axes=1)
dy_dx = g.gradient(y, x)

这个例子非常好。 Y 取决于 x,您正在计算 dy/dx。所以你会在这个例子中得到实际的渐变。

如解释的那样HERE(请参阅下面来自 alextp 的引述)tf.assign 不支持渐变。

"There is no plan to add a gradient to tf.assign because it's not possible in general to connect the uses of the assigned variable with the graph which assigned it."

所以,上面的问题可以通过下面的代码解决:

x= tf.Variable([[0.0,0.0,4.0], [0.,0.,100.]])
patch = tf.Variable([[0., 1.], [2., 3.]])
with tf.GradientTape() as g:
    g.watch(patch)
    padding = tf.constant([[0, 0], [0, 1]])
    padde_patch = tf.pad(patch, padding, mode='CONSTANT', constant_values=0)
    revised_x = x+ padde_patch
    y = tf.tensordot(revised_x, tf.transpose(revised_x), axes=1)
    o = tf.reduce_mean(y)
do_dpatch = g.gradient(o, patch)

这导致

do_dpatch

<tf.Tensor: id=65, shape=(2, 2), dtype=float32, numpy=
array([[1., 2.],
       [1., 2.]], dtype=float32)>