如何在 Tensorflow 中做类似 numpy 的条件赋值

How to do numpy like conditional assignment in Tensorflow

下面是它在 Numpy 中的工作方式

import numpy as np

vals_for_fives = [12, 18, 22, 33]
arr = np.array([5, 2, 3, 5, 5, 5])
arr[arr == 5] = vals_for_fives  # It is guaranteed that length of vals_for_fives is equal to the number of fives in arr

# now the value of arr is [12, 2, 3, 18, 22, 33]

对于可广播或常量赋值,我们可以在 Tensorflow 中使用 where()assign()。如何在TF中实现上述场景?

tf.experimental.numpy.wheretensorflow v2.5 中的内容。

但现在您可以这样做:

首先找到5的位置:

arr = np.array([5, 2, 3, 5, 5, 5])
where = tf.where(arr==5)
where = tf.cast(where, tf.int32)
print(where)
# <tf.Tensor: id=91, shape=(4, 1), dtype=int32, numpy=
array([[0],
       [3],
       [4],
       [5]])>

然后使用scatter_nd按索引“替换”元素:

tf.scatter_nd(where, tf.constant([12,18,22,23]), tf.constant([5]))
# <tf.Tensor: id=94, shape=(5,), dtype=int32, numpy=array([12,  0,  0, 18
, 22])>

对不是 5 的条目做类似的事情来找到丢失的张量:

tf.scatter_nd(tf.constant([[1], [2]]), tf.constant([2,3]), tf.constant([5]))
# <tf.Tensor: id=98, shape=(5,), dtype=int32, numpy=array([0, 2, 3, 0, 0])>

然后将两个张量相加得到:

<tf.Tensor: id=113, shape=(5,), dtype=int32, numpy=array([12,  2,  3, 1, 8, 22])>