如何使用索引有效地获取张量中每一行的值?
How to efficiently get values of each row in a tensor using indices?
我有一个名为 my_tensor 的张量,形状为 [batch_size, seq_length]
,我有另一个名为 idx 的张量,形状为 [batch_size, 1]
,它由从 0 开始的索引和完成于 "seq_length".
我想使用 idx 中定义的索引提取 my_tensor 的每一行中的值。
我尝试使用 tf.gather_nd
和 tf.gather
但我没有成功。
考虑以下示例:
batch_size = 3
seq_length = 5
idx = [2, 0, 4]
my_tensor = tf.random.uniform(shape=(batch_size, seq_length))
我想获取
的值
[[0, 2],
[1, 0],
[3, 4]]
来自 my_tensor。
我必须对它们做进一步的处理,所以我想同时拥有它们(我不知道这是否可能)并且以一种有效的方式;但是,我想不出任何其他方法。
感谢任何帮助:)
诀窍是首先将您的索引集转换为布尔掩码,然后您可以使用它来减少 my_tensor
,正如您使用 boolean_mask 操作所描述的那样。
您可以通过 one-hot encoding idx
张量来完成此操作。
因此,idx = [2, 0, 4]
我们可以tf.one_hot(idx, seq_length)
将其转换为如下形式:
[ [0., 0., 1., 0., 0.],
[1., 0., 0., 0., 0.],
[0., 0., 0., 0., 1.] ]
然后,将它们放在一起,比如 my_tensor
:
[ [0.6413697 , 0.4079175 , 0.42499018, 0.3037368 , 0.8580252 ],
[0.8698617 , 0.29096508, 0.11531639, 0.25421357, 0.5844104 ],
[0.6442119 , 0.31816053, 0.6245482 , 0.7249261 , 0.7595779 ] ]
我们可以进行如下操作:
result = tf.boolean_mask(my_tensor, tf.one_hot(idx,seq_length))
给予:
[0.42499018, 0.8698617 , 0.7595779 ]
符合预期
我有一个名为 my_tensor 的张量,形状为 [batch_size, seq_length]
,我有另一个名为 idx 的张量,形状为 [batch_size, 1]
,它由从 0 开始的索引和完成于 "seq_length".
我想使用 idx 中定义的索引提取 my_tensor 的每一行中的值。
我尝试使用 tf.gather_nd
和 tf.gather
但我没有成功。
考虑以下示例:
batch_size = 3
seq_length = 5
idx = [2, 0, 4]
my_tensor = tf.random.uniform(shape=(batch_size, seq_length))
我想获取
的值[[0, 2],
[1, 0],
[3, 4]]
来自 my_tensor。
我必须对它们做进一步的处理,所以我想同时拥有它们(我不知道这是否可能)并且以一种有效的方式;但是,我想不出任何其他方法。
感谢任何帮助:)
诀窍是首先将您的索引集转换为布尔掩码,然后您可以使用它来减少 my_tensor
,正如您使用 boolean_mask 操作所描述的那样。
您可以通过 one-hot encoding idx
张量来完成此操作。
因此,idx = [2, 0, 4]
我们可以tf.one_hot(idx, seq_length)
将其转换为如下形式:
[ [0., 0., 1., 0., 0.],
[1., 0., 0., 0., 0.],
[0., 0., 0., 0., 1.] ]
然后,将它们放在一起,比如 my_tensor
:
[ [0.6413697 , 0.4079175 , 0.42499018, 0.3037368 , 0.8580252 ],
[0.8698617 , 0.29096508, 0.11531639, 0.25421357, 0.5844104 ],
[0.6442119 , 0.31816053, 0.6245482 , 0.7249261 , 0.7595779 ] ]
我们可以进行如下操作:
result = tf.boolean_mask(my_tensor, tf.one_hot(idx,seq_length))
给予:
[0.42499018, 0.8698617 , 0.7595779 ]
符合预期