如何更新 2d tf.Variable 的单列?

How to update a single column of a 2d tf.Variable?

假设我有一个 MxN 形 tf.Variable,它存储我的自定义层的一些状态:

import tensorflow as tf
m, n = 3, 4  # just for example
v = tf.Variable(tf.zeros([m, n]), trainable=False)

# v = <tf.Variable 'Variable:0' shape=(3, 4) dtype=float32, numpy=
#     array([[0., 0., 0., 0.],
#            [0., 0., 0., 0.],
#            [0., 0., 0., 0.]], dtype=float32)>

我知道我可以用 v.assign(...) 更新这个变量的值,但是我怎样才能只更新这个变量的一个子部分?例如,我想在给定的列中插入一个给定的向量。

x = tf.ones([m,1])
c = tf.Variable(2)
# update v by inserting x at column c

...这样下面就是 v 的新值:

# v = <tf.Variable 'Variable:0' shape=(3, 4) dtype=float32, numpy=
#     array([[0., 0., 1., 0.],
#            [0., 0., 1., 0.],
#            [0., 0., 1., 0.]], dtype=float32)>

使用 TF 2.2

m, n = 3, 4  # just for example
v = tf.Variable(tf.zeros([m, n]), trainable=False)
x = tf.ones(m)
c = 2

change_v = v[:,c].assign(x)