批量处理多维张量的 Pytorch 索引
Pytorch-index on multiple dimension tensor in a batch
给定一维张量:
A = torch.tensor([1, 2, 3, 4])
假设我们有一些“索引器张量”
ind1 = torch.tensor([3, 0, 1])
ind2 = torch.tensor([[3, 0], [1, 2]])
因为我们 运行 A[ind1]
& A[ind2]
我们得到结果 tensor([4, 1, 2])
& tensor([[4, 1],[2, 3]])
它与索引张量(ind1 和 ind2)的形状相同,它的值是从张量 A 映射而来的。
我想问一下如何索引高维张量?
目前我有一个解决方案:
对于 N 维张量 A,假设我们有索引张量 IND,
IND 类似于 [[i11, i12, ... i1N], [i21, i22, ... i2N], ...[iM1, i22, ... iMN]
,其中 M 是索引元素的数量。
我们可以将IND分成N个张量,其中
IND_1 = torch.tensor([i11, i21, ... iM1])
...
IND_N = torch.tensor([i1N, i2N, ... iMN])
当我们 运行 A[IND_1, ... IND_N]
时,我们得到了 tensor(v1, v2, ... vM)
示例:
A = tensor([[1, 2], [3, 4]], [[5, 6], [7, 8]]]) # [2 * 2 * 2]
ind1 = tensor([1, 0, 1])
ind2 = tensor([1, 1, 0])
ind3 = tensor([0, 1, 0])
A[ind1, ind2, ind3]
=> tensor([7, 4, 5])
# and the good thing is you can control the shape of result tensor by modifying the inds' shape.
ind1 = tensor([[0, 0], [1, 0]])
ind2 = tensor([[1, 1], [0, 1]])
ind3 = tensor([[0, 1], [0, 0]])
A[ind1, ind2, ind3]
=> tensor([[3, 4],[5, 3]]) # same as inds' shape
谁有更优雅的解决方案?
1- 在扁平输入上使用解开索引的手动方法。
如果您想在任意数量的轴(A
的所有轴)上进行索引,那么一种直接的方法是展平所有维度并展开索引。假设 A
是 3D,我们想使用 ind1
、ind2
和 ind3
:
的堆栈对其进行索引
>>> ind = torch.stack((ind1, ind2, ind3))
您可以先使用 A
的步幅来解开索引:
>>> unraveled = torch.tensor(A.stride()) @ ind.flatten(1)
然后将 A
展平,用 unraveled
索引并整形为最终形式:
>>> A.flatten()[unraveled].reshape_as(ind[0])
2- 使用 ind
.
的简单分割
您实际上可以使用 torch.chunk
执行相同的操作:
>>> A[ind.chunk(len(ind))][0]
或者 torch.split
是相同的:
>>> A[ind.split(1)][0]
3- single-axis 索引的初始答案。
让我们举一个最小的 multi-dimensional 例子,其中 A
是一个二维张量,定义为:
>>> A = torch.tensor([[1, 2, 3, 4],
[5, 6, 7, 8]])
根据您对问题的描述:
the same shape of index tensor and its value are mapped from tensor A
.
然后索引张量需要与索引张量 A
具有相同的形状,因为它不再是平坦的。 否则,A
(形状(2, 4)
)被ind1
(形状(3,)
)索引的结果会是什么?
如果您在单个维度上建立索引,那么您可以利用 torch.gather
:
>>> A.gather(1, ind2)
tensor([[4, 1],
[6, 7]])
给定一维张量:
A = torch.tensor([1, 2, 3, 4])
假设我们有一些“索引器张量”
ind1 = torch.tensor([3, 0, 1])
ind2 = torch.tensor([[3, 0], [1, 2]])
因为我们 运行 A[ind1]
& A[ind2]
我们得到结果 tensor([4, 1, 2])
& tensor([[4, 1],[2, 3]])
它与索引张量(ind1 和 ind2)的形状相同,它的值是从张量 A 映射而来的。
我想问一下如何索引高维张量?
目前我有一个解决方案:
对于 N 维张量 A,假设我们有索引张量 IND,
IND 类似于 [[i11, i12, ... i1N], [i21, i22, ... i2N], ...[iM1, i22, ... iMN]
,其中 M 是索引元素的数量。
我们可以将IND分成N个张量,其中
IND_1 = torch.tensor([i11, i21, ... iM1])
...
IND_N = torch.tensor([i1N, i2N, ... iMN])
当我们 运行 A[IND_1, ... IND_N]
时,我们得到了 tensor(v1, v2, ... vM)
示例:
A = tensor([[1, 2], [3, 4]], [[5, 6], [7, 8]]]) # [2 * 2 * 2]
ind1 = tensor([1, 0, 1])
ind2 = tensor([1, 1, 0])
ind3 = tensor([0, 1, 0])
A[ind1, ind2, ind3]
=> tensor([7, 4, 5])
# and the good thing is you can control the shape of result tensor by modifying the inds' shape.
ind1 = tensor([[0, 0], [1, 0]])
ind2 = tensor([[1, 1], [0, 1]])
ind3 = tensor([[0, 1], [0, 0]])
A[ind1, ind2, ind3]
=> tensor([[3, 4],[5, 3]]) # same as inds' shape
谁有更优雅的解决方案?
1- 在扁平输入上使用解开索引的手动方法。
如果您想在任意数量的轴(A
的所有轴)上进行索引,那么一种直接的方法是展平所有维度并展开索引。假设 A
是 3D,我们想使用 ind1
、ind2
和 ind3
:
>>> ind = torch.stack((ind1, ind2, ind3))
您可以先使用 A
的步幅来解开索引:
>>> unraveled = torch.tensor(A.stride()) @ ind.flatten(1)
然后将 A
展平,用 unraveled
索引并整形为最终形式:
>>> A.flatten()[unraveled].reshape_as(ind[0])
2- 使用 ind
.
的简单分割
您实际上可以使用 torch.chunk
执行相同的操作:
>>> A[ind.chunk(len(ind))][0]
或者 torch.split
是相同的:
>>> A[ind.split(1)][0]
3- single-axis 索引的初始答案。
让我们举一个最小的 multi-dimensional 例子,其中 A
是一个二维张量,定义为:
>>> A = torch.tensor([[1, 2, 3, 4],
[5, 6, 7, 8]])
根据您对问题的描述:
the same shape of index tensor and its value are mapped from tensor
A
.
然后索引张量需要与索引张量 A
具有相同的形状,因为它不再是平坦的。 否则,A
(形状(2, 4)
)被ind1
(形状(3,)
)索引的结果会是什么?
如果您在单个维度上建立索引,那么您可以利用 torch.gather
:
>>> A.gather(1, ind2)
tensor([[4, 1],
[6, 7]])