根据索引张量从张量中选择值
Selecting values from tensor based on an index tensor
我有两个矩阵。矩阵 A 包含一些值,矩阵 B 包含索引。矩阵A和B的shape分别是(batch, values)和(batch, indices)
我的目标是 select 矩阵 A 的值基于矩阵 B 沿批量维度的索引。
例如:
# Matrix A
<tf.Tensor: shape=(2, 5), dtype=float32, numpy=
array([[0., 1., 2., 3., 4.],
[5., 6., 7., 8., 9.]], dtype=float32)>
# Matrix B
<tf.Tensor: shape=(2, 2), dtype=int32, numpy=
array([[0, 1],
[1, 2]], dtype=int32)>
# Expected Result
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[0., 1.],
[6., 7.]], dtype=int32)>
如何在 Tensorflow 中实现这一点?
非常感谢!
您可以使用 tf.gather
函数实现此目的。
mat_a = tf.constant([[0., 1., 2., 3., 4.],
[5., 6., 7., 8., 9.]])
mat_b = tf.constant([[0, 1], [1, 2]])
out = tf.gather(mat_a, mat_b, batch_dims=1)
out.numpy()
array([[0., 1.],
[6., 7.]], dtype=float32)
我有两个矩阵。矩阵 A 包含一些值,矩阵 B 包含索引。矩阵A和B的shape分别是(batch, values)和(batch, indices)
我的目标是 select 矩阵 A 的值基于矩阵 B 沿批量维度的索引。
例如:
# Matrix A
<tf.Tensor: shape=(2, 5), dtype=float32, numpy=
array([[0., 1., 2., 3., 4.],
[5., 6., 7., 8., 9.]], dtype=float32)>
# Matrix B
<tf.Tensor: shape=(2, 2), dtype=int32, numpy=
array([[0, 1],
[1, 2]], dtype=int32)>
# Expected Result
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[0., 1.],
[6., 7.]], dtype=int32)>
如何在 Tensorflow 中实现这一点?
非常感谢!
您可以使用 tf.gather
函数实现此目的。
mat_a = tf.constant([[0., 1., 2., 3., 4.],
[5., 6., 7., 8., 9.]])
mat_b = tf.constant([[0, 1], [1, 2]])
out = tf.gather(mat_a, mat_b, batch_dims=1)
out.numpy()
array([[0., 1.],
[6., 7.]], dtype=float32)