如何使用索引有效地获取张量中每一行的值?

How to efficiently get values of each row in a tensor using indices?

我有一个名为 my_tensor 的张量,形状为 [batch_size, seq_length],我有另一个名为 idx 的张量,形状为 [batch_size, 1],它由从 0 开始的索引和完成于 "seq_length".

我想使用 idx 中定义的索引提取 my_tensor 的每一行中的值。

我尝试使用 tf.gather_ndtf.gather 但我没有成功。

考虑以下示例:

batch_size = 3
seq_length = 5
idx = [2, 0, 4]

my_tensor = tf.random.uniform(shape=(batch_size, seq_length))

我想获取

的值
[[0, 2],
 [1, 0],
 [3, 4]]

来自 my_tensor。

我必须对它们做进一步的处理,所以我想同时拥有它们(我不知道这是否可能)并且以一种有效的方式;但是,我想不出任何其他方法。

感谢任何帮助:)

诀窍是首先将您的索引集转换为布尔掩码,然后您可以使用它来减少 my_tensor,正如您使用 boolean_mask 操作所描述的那样。

您可以通过 one-hot encoding idx 张量来完成此操作。

因此,idx = [2, 0, 4] 我们可以tf.one_hot(idx, seq_length) 将其转换为如下形式:

[ [0., 0., 1., 0., 0.],
  [1., 0., 0., 0., 0.],
  [0., 0., 0., 0., 1.] ]

然后,将它们放在一起,比如 my_tensor:

[ [0.6413697 , 0.4079175 , 0.42499018, 0.3037368 , 0.8580252 ],
  [0.8698617 , 0.29096508, 0.11531639, 0.25421357, 0.5844104 ],
  [0.6442119 , 0.31816053, 0.6245482 , 0.7249261 , 0.7595779 ] ]

我们可以进行如下操作:

result = tf.boolean_mask(my_tensor, tf.one_hot(idx,seq_length))

给予:

[0.42499018, 0.8698617 , 0.7595779 ]

符合预期