批量处理多维张量的 Pytorch 索引

Pytorch-index on multiple dimension tensor in a batch

给定一维张量:

A = torch.tensor([1, 2, 3, 4])

假设我们有一些“索引器张量”

ind1 = torch.tensor([3, 0, 1])
ind2 = torch.tensor([[3, 0], [1, 2]])

因为我们 运行 A[ind1] & A[ind2]

我们得到结果 tensor([4, 1, 2]) & tensor([[4, 1],[2, 3]]) 它与索引张量(ind1 和 ind2)的形状相同,它的值是从张量 A 映射而来的。

我想问一下如何索引高维张量


目前我有一个解决方案:

对于 N 维张量 A,假设我们有索引张量 IND, IND 类似于 [[i11, i12, ... i1N], [i21, i22, ... i2N], ...[iM1, i22, ... iMN],其中 M 是索引元素的数量。 我们可以将IND分成N个张量,其中

IND_1 = torch.tensor([i11, i21, ... iM1])
...
IND_N = torch.tensor([i1N, i2N, ... iMN])

当我们 运行 A[IND_1, ... IND_N] 时,我们得到了 tensor(v1, v2, ... vM)

示例:

A = tensor([[1, 2], [3, 4]], [[5, 6], [7, 8]]])  # [2 * 2 * 2]
ind1 = tensor([1, 0, 1])
ind2 = tensor([1, 1, 0])
ind3 = tensor([0, 1, 0])
A[ind1, ind2, ind3]
=> tensor([7, 4, 5])
# and the good thing is you can control the shape of result tensor by modifying the inds' shape.
ind1 = tensor([[0, 0], [1, 0]])
ind2 = tensor([[1, 1], [0, 1]])
ind3 = tensor([[0, 1], [0, 0]])
A[ind1, ind2, ind3]
=> tensor([[3, 4],[5, 3]])  # same as inds' shape

谁有更优雅的解决方案?

1- 在扁平输入上使用解开索引的手动方法。

如果您想在任意数量的轴(A 的所有轴)上进行索引,那么一种直接的方法是展平所有维度并展开索引。假设 A 是 3D,我们想使用 ind1ind2ind3:

的堆栈对其进行索引
>>> ind = torch.stack((ind1, ind2, ind3))

您可以先使用 A 的步幅来解开索引:

>>> unraveled = torch.tensor(A.stride()) @ ind.flatten(1)

然后将 A 展平,用 unraveled 索引并整形为最终形式:

>>> A.flatten()[unraveled].reshape_as(ind[0])

2- 使用 ind.

的简单分割

您实际上可以使用 torch.chunk 执行相同的操作:

>>> A[ind.chunk(len(ind))][0]

或者 torch.split 是相同的:

>>> A[ind.split(1)][0]

3- single-axis 索引的初始答案。

让我们举一个最小的 multi-dimensional 例子,其中 A 是一个二维张量,定义为:

>>> A = torch.tensor([[1, 2, 3, 4],
                      [5, 6, 7, 8]])

根据您对问题的描述:

the same shape of index tensor and its value are mapped from tensor A.

然后索引张量需要与索引张量 A 具有相同的形状,因为它不再是平坦的。 否则,A(形状(2, 4))被ind1(形状(3,))索引的结果会是什么?

如果您在单个维度上建立索引,那么您可以利用 torch.gather:

>>> A.gather(1, ind2)
tensor([[4, 1],
        [6, 7]])