torch.gather(...) 调用的设置结果

Setting results of torch.gather(...) calls

我有一个形状为 n x m 的二维 pytorch 张量。我想使用索引列表(可以使用 torch.gather 完成)对第二个维度进行索引,然后 然后也将新值 设置为索引的结果。

示例:

data = torch.tensor([[0,1,2], [3,4,5], [6,7,8]]) # shape (3,3)
indices = torch.tensor([1,2,1], dtype=torch.long).unsqueeze(-1) # shape (3,1)
# data tensor:
# tensor([[0, 1, 2],
#         [3, 4, 5],
#         [6, 7, 8]])

我想 select 每行的指定索引(这将是 [1,5,7] 但随后还将这些值设置为另一个数字 - 例如 42

我可以 select 通过以下方式逐行显示所需的列:

data.gather(1, indices)
tensor([[1],
        [5],
        [7]])
data.gather(1, indices)[:] = 42 # **This does NOT work**, since the result of gather 
                                # does not use the same storage as the original tensor

很好,但我现在想更改这些值,并且更改也会影响 data 张量。

我可以用它做我想做的事,但它似乎非常不符合 pythonic:

max_index = torch.max(indices)
for i in range(0, max_index + 1):
  mask = (indices == i).nonzero(as_tuple=True)[0]
  data[mask, i] = 42
print(data)
# tensor([[ 0, 42,  2],
#         [ 3,  4, 42],
#         [ 6, 42,  8]])

关于如何更优雅地做到这一点的任何提示?

您要找的是带有 value 选项的 torch.scatter_

Tensor.scatter_(dim, index, src, reduce=None) → Tensor
Writes all values from the tensor src into self at the indices specified in the index tensor. For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim.

With 2D tensors as input and dim=1, the operation is:
self[i][index[i][j]] = src[i][j]

虽然没有提到值参数...


对于value=42dim=1,这将对数据产生以下影响:

data[i][index[i][j]] = 42

此处就地应用:

>>> data.scatter_(index=indices, dim=1, value=42)
>>> data
tensor([[ 0, 42,  2],
        [ 3,  4, 42],
        [ 6, 42,  8]])