Setting results of torch.gather(...) calls
I have a 2D PyTorch tensor of shape n x m. I want to index the second dimension with a list of indices (which can be done with torch.gather), and then also assign new values to the indexed elements.
Example:
data = torch.tensor([[0,1,2], [3,4,5], [6,7,8]]) # shape (3,3)
indices = torch.tensor([1,2,1], dtype=torch.long).unsqueeze(-1) # shape (3,1)
# data tensor:
# tensor([[0, 1, 2],
# [3, 4, 5],
# [6, 7, 8]])
I want to select the specified index of each row (which would be [1, 5, 7]), but then also set those values to another number, e.g. 42.
I can select the desired column of each row with:
data.gather(1, indices)
tensor([[1],
[5],
[7]])
data.gather(1, indices)[:] = 42 # **This does NOT work**, since the result of gather
# does not use the same storage as the original tensor
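To make the failure above concrete: gather returns a freshly allocated tensor rather than a view, so writes to its result never reach data. A minimal check (a sketch, not part of the original question):

```python
import torch

data = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
indices = torch.tensor([1, 2, 1], dtype=torch.long).unsqueeze(-1)

gathered = data.gather(1, indices)
gathered[:] = 42  # writes into the copy only

# data is unchanged, and the two tensors use different storage
print(data[0, 1].item())                       # still 1
print(gathered.data_ptr() == data.data_ptr())  # False
```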
That is fine, but I now want to change these values, and have the change also affect the data tensor.
I can do what I want with the following, but it seems very unpythonic:
max_index = torch.max(indices)
for i in range(0, max_index + 1):
    mask = (indices == i).nonzero(as_tuple=True)[0]
    data[mask, i] = 42
print(data)
# tensor([[ 0, 42, 2],
# [ 3, 4, 42],
# [ 6, 42, 8]])
Any hints on how to do this more elegantly?
What you are looking for is torch.scatter_ with the value option.
Tensor.scatter_(dim, index, src, reduce=None) → Tensor
Writes all values from the tensor src into self at the indices specified in the index tensor. For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim.
With 2D tensors as input and dim=1, the operation is:
self[i][index[i][j]] = src[i][j]
The value argument is not mentioned there, though...
With value=42 and dim=1, this will have the following effect on data:
data[i][index[i][j]] = 42
Applied in-place here:
>>> data.scatter_(index=indices, dim=1, value=42)
>>> data
tensor([[ 0, 42, 2],
[ 3, 4, 42],
[ 6, 42, 8]])
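As a side note (not part of the original answer), the same in-place update can also be written with advanced indexing instead of scatter_ — a sketch:

```python
import torch

data = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
indices = torch.tensor([1, 2, 1], dtype=torch.long).unsqueeze(-1)

# A column of row numbers; broadcasting with indices (both shape (3, 1))
# makes data[rows, indices] address element (i, indices[i]) for every row i
rows = torch.arange(data.size(0)).unsqueeze(-1)
data[rows, indices] = 42
print(data)
# tensor([[ 0, 42,  2],
#         [ 3,  4, 42],
#         [ 6, 42,  8]])
```

Because the left-hand side is a genuine indexing expression on data, the assignment writes through to the original storage, and it generalizes easily to assigning a different value per row.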