如何设置索引以使用 tf.tensor_scatter_nd_update 获取 3D 张量的最后一维?

How to set indices to get last dimension of 3D tensor with tf.tensor_scatter_nd_update?

我有一个 3D 张量 tensor_a、tensor_a = (510, 1, 6),我想更新 tensor_a[..., 0]tensor_a[...,-1],就像 pytorch 或 numpy 中的 tensor_a[..., 0] = 1.。如何设置 indices 权限以获得与 pytorch 中的 tensor_a[...,0] = 1. 相同的结果?

你可以用 tf.tensor_scatter_nd_update 这样做:

import tensorflow as tf

tensor_a = ...  # Some 3D tensor
idx_to_replace = 0
new_value = 1
s = tf.shape(tensor_a)
i1, i2 = tf.meshgrid(tf.range(s[0]), tf.range(s[1]), indexing='ij')
i3 = idx_to_replace * tf.ones_like(i1)
idx = tf.stack([i1, i2, i3], axis=-1)
updates = new_value * tf.ones_like(i1, dtype=tensor_a.dtype)
result = tf.tensor_scatter_nd_update(tensor_a, idx, updates)

虽然这不适用于负指数,但您需要将其设为正指数,例如:

idx_to_replace = tf.cond(tf.less(idx_to_replace, 0),
                         lambda: idx_to_replace + s[-1],
                         lambda: idx_to_replace)

但是,要用 ones 替换最后一个维度的第一个索引,您可能会发现简单地执行以下操作会更容易、更快捷:

result = tf.concat([tf.ones_like(tensor_a[..., :1]), tensor_a[..., 1:]], axis=-1)

与最后一个维度类似:

result = tf.concat([tensor_a[..., :-1], tf.ones_like(tensor_a[..., -1:])], axis=-1)