根据索引张量从张量中选择值

Selecting values from tensor based on an index tensor

我有两个矩阵。矩阵 A 包含一些值,矩阵 B 包含索引。矩阵A和B的shape分别是(batch, values)和(batch, indices)

我的目标是 select 矩阵 A 的值基于矩阵 B 沿批量维度的索引。

例如:

# Matrix A
<tf.Tensor: shape=(2, 5), dtype=float32, numpy=
array([[0., 1., 2., 3., 4.],
       [5., 6., 7., 8., 9.]], dtype=float32)>

# Matrix B
<tf.Tensor: shape=(2, 2), dtype=int32, numpy=
array([[0, 1],
       [1, 2]], dtype=int32)>

# Expected Result
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[0., 1.],
       [6., 7.]], dtype=int32)>

如何在 Tensorflow 中实现这一点?

非常感谢!

您可以使用 tf.gather 函数实现此目的。

mat_a = tf.constant([[0., 1., 2., 3., 4.],
                     [5., 6., 7., 8., 9.]])
mat_b = tf.constant([[0, 1], [1, 2]])

out = tf.gather(mat_a, mat_b, batch_dims=1)
out.numpy()
array([[0., 1.],
       [6., 7.]], dtype=float32)