使用 pytorch 根据相应行的给定索引设置张量值

Setting values of a tensor based on given indices of corresponding rows using pytorch

我有一个张量 A,形状为 (M, N),还有另一个张量 B,形状为 (M, P),并且在相应行中具有给定索引的值A。现在我想将 A 的值与 B 中的相应索引设置为 0.

例如:

In[1]: import torch
       A = torch.tensor([range(1,11), range(1,11), range(1,11)])
       A
Out[1]: 
tensor([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10]])
In[2]: B = torch.tensor([[1,2], [2,3], [3,5]])
       B
Out[2]: 
tensor([[1, 2],
        [2, 3],
        [3, 5]])

中的objective是设置第一行索引为1,2,第二行为2,3,第三行为3,5的元素的值A0 的行,即将 A 设置为

tensor([[ 1,  0,  0,  4,  5,  6,  7,  8,  9, 10],
        [ 1,  2,  0,  0,  5,  6,  7,  8,  9, 10],
        [ 1,  2,  3,  0,  5,  0,  7,  8,  9, 10]])

我逐行应用for循环,也试过scatter:

zeros = torch.zeros(A.shape, dtype=torch.float).to("cuda")
A = A.scatter_(1, B, zeros)

这两种方法工作正常,但性能都很差。实际上,我根据之前的错误推断应该存在一些有效的方法。我最初使用 A[:, B] = 0。这会将出现在 B 中的所有索引设置为 0,而不管行。但是,在执行 A[:, B] = 0.

时训练速度显着提高

有什么方法可以更有效地实现这一点?

这是我会做的:

import torch
A = torch.tensor([range(1,11), range(1,11), range(1,11)])
B = torch.tensor([[1,2], [2,3], [3,5]])
r, c = B.shape
idx0 = torch.arange(r).reshape(-1, 1).repeat(1, c).flatten()
idx1 = B.flatten()
A[idx0, idx1] = 0

输出:

A = 
tensor([[ 1,  0,  0,  4,  5,  6,  7,  8,  9, 10],
        [ 1,  2,  0,  0,  5,  6,  7,  8,  9, 10],
        [ 1,  2,  3,  0,  5,  0,  7,  8,  9, 10]])