如何向量化张量中的修改元素

How to vectorize modifying elements in a tensor

给定一个:

tf.tensor A 形状 N*M, (A = tf.zeros((N, M)))
tf.tensor indices 形状 N*k (k<=M)。每个 ith 行包含张量的一些索引 A
tf.tensor upates 形状为 N*K 。每个 ith 行包含用于更新张量的值 A

目标:更新 A 的元素,其索引存在于 indices,值为 updates

循环使用tf.scatter_nd

result = []
for idx in range(N):
    index = tf.reshape(indices[idx], (-1, 1))
    updates = tf.convert_to_tensor(updates[idx])
    scatter = tf.scatter_nd(index, updates, shape=tf.constant([M]))
    target.append(scatter)
result = tf.stack(result, axis=0)

这个循环显然适用于 N 较小的情况。

问题:如何将其向量化为 运行 更快。

如果第一个张量 A 总是由零组成,您可以通过调用 tf.scatter_nd:

来实现
import tensorflow as tf

indices = ... # shape: (n, k)
updates = ... # shape: (n, k)
s = tf.shape(indices, out_type=indices.dtype)
n = s[0]
k = s[1]
idx_row = tf.tile(tf.expand_dims(tf.range(n), 1), (1, k))
idx_full = tf.stack([idx_row , indices], axis=-1)
result = tf.scatter_nd(idx_full, updates, [n, m])

如果初始 A 包含其他内容,您将执行基本相同的操作,但使用 tf.tensor_scatter_nd_update 代替:

A = ... # shape: (n, m)
result = tf.tensor_scatter_nd_update(A, idx_full, updates)