使用 pytorch 根据相应行的给定索引设置张量值
Setting values of a tensor based on given indices of corresponding rows using pytorch
我有一个张量 A
,形状为 (M, N)
,还有另一个张量 B
,形状为 (M, P)
,并且在相应行中具有给定索引的值A
。现在我想将 A
的值与 B
中的相应索引设置为 0
.
例如:
In[1]: import torch
A = torch.tensor([range(1,11), range(1,11), range(1,11)])
A
Out[1]:
tensor([[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
In[2]: B = torch.tensor([[1,2], [2,3], [3,5]])
B
Out[2]:
tensor([[1, 2],
[2, 3],
[3, 5]])
中的objective是设置第一行索引为1,2
,第二行为2,3
,第三行为3,5
的元素的值A
到 0
的行,即将 A
设置为
tensor([[ 1, 0, 0, 4, 5, 6, 7, 8, 9, 10],
[ 1, 2, 0, 0, 5, 6, 7, 8, 9, 10],
[ 1, 2, 3, 0, 5, 0, 7, 8, 9, 10]])
我逐行应用for循环,也试过scatter
:
zeros = torch.zeros(A.shape, dtype=torch.float).to("cuda")
A = A.scatter_(1, B, zeros)
这两种方法工作正常,但性能都很差。实际上,我根据之前的错误推断应该存在一些有效的方法。我最初使用 A[:, B] = 0
。这会将出现在 B
中的所有索引设置为 0
,而不管行。但是,在执行 A[:, B] = 0
.
时训练速度显着提高
有什么方法可以更有效地实现这一点?
这是我会做的:
import torch
A = torch.tensor([range(1,11), range(1,11), range(1,11)])
B = torch.tensor([[1,2], [2,3], [3,5]])
r, c = B.shape
idx0 = torch.arange(r).reshape(-1, 1).repeat(1, c).flatten()
idx1 = B.flatten()
A[idx0, idx1] = 0
输出:
A =
tensor([[ 1, 0, 0, 4, 5, 6, 7, 8, 9, 10],
[ 1, 2, 0, 0, 5, 6, 7, 8, 9, 10],
[ 1, 2, 3, 0, 5, 0, 7, 8, 9, 10]])
我有一个张量 A
,形状为 (M, N)
,还有另一个张量 B
,形状为 (M, P)
,并且在相应行中具有给定索引的值A
。现在我想将 A
的值与 B
中的相应索引设置为 0
.
例如:
In[1]: import torch
A = torch.tensor([range(1,11), range(1,11), range(1,11)])
A
Out[1]:
tensor([[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
In[2]: B = torch.tensor([[1,2], [2,3], [3,5]])
B
Out[2]:
tensor([[1, 2],
[2, 3],
[3, 5]])
中的objective是设置第一行索引为1,2
,第二行为2,3
,第三行为3,5
的元素的值A
到 0
的行,即将 A
设置为
tensor([[ 1, 0, 0, 4, 5, 6, 7, 8, 9, 10],
[ 1, 2, 0, 0, 5, 6, 7, 8, 9, 10],
[ 1, 2, 3, 0, 5, 0, 7, 8, 9, 10]])
我逐行应用for循环,也试过scatter
:
zeros = torch.zeros(A.shape, dtype=torch.float).to("cuda")
A = A.scatter_(1, B, zeros)
这两种方法工作正常,但性能都很差。实际上,我根据之前的错误推断应该存在一些有效的方法。我最初使用 A[:, B] = 0
。这会将出现在 B
中的所有索引设置为 0
,而不管行。但是,在执行 A[:, B] = 0
.
有什么方法可以更有效地实现这一点?
这是我会做的:
import torch
A = torch.tensor([range(1,11), range(1,11), range(1,11)])
B = torch.tensor([[1,2], [2,3], [3,5]])
r, c = B.shape
idx0 = torch.arange(r).reshape(-1, 1).repeat(1, c).flatten()
idx1 = B.flatten()
A[idx0, idx1] = 0
输出:
A =
tensor([[ 1, 0, 0, 4, 5, 6, 7, 8, 9, 10],
[ 1, 2, 0, 0, 5, 6, 7, 8, 9, 10],
[ 1, 2, 3, 0, 5, 0, 7, 8, 9, 10]])