如何在 PyTorch 中对二维进行 select 索引?
How to select index over two dimension in PyTorch?
给定 a = torch.randn(3, 2, 4, 5)
,我如何 select 像 (2, :, 0, :), (1, :, 1, :), (2, :, 2, :), (0, :, 3, :)
这样的子张量(生成的张量大小为 (2, 4, 5)
或 (4, 2, 5)
?
而 a[2, :, 0, :]
给出
0.5580 -0.0337 1.0048 -0.5044 0.6784
-1.6117 1.0084 1.1886 0.1278 0.3739
[torch.FloatTensor of size 2x5]
然而,a[[2, 1, 2, 0], :, [0, 1, 2, 3], :]
给出
TypeError: Performing basic indexing on a tensor and encountered an error indexing dim 0 with an object of type list. The only supported types are integers, slices, numpy scalars, or if indexing with a torch.LongTensor or torch.ByteTensor only a single Tensor may be passed.
虽然 numpy
returns 一个 (4, 2, 5)
张量成功。
对你有用吗?
import torch
a = torch.randn(3, 2, 4, 5)
print(a.size())
b = [a[2, :, 0, :], a[1, :, 1, :], a[2, :, 2, :], a[0, :, 3, :]]
b = torch.stack(b, 0)
print(b.size()) # torch.Size([4, 2, 5])
给定 a = torch.randn(3, 2, 4, 5)
,我如何 select 像 (2, :, 0, :), (1, :, 1, :), (2, :, 2, :), (0, :, 3, :)
这样的子张量(生成的张量大小为 (2, 4, 5)
或 (4, 2, 5)
?
而 a[2, :, 0, :]
给出
0.5580 -0.0337 1.0048 -0.5044 0.6784
-1.6117 1.0084 1.1886 0.1278 0.3739
[torch.FloatTensor of size 2x5]
然而,a[[2, 1, 2, 0], :, [0, 1, 2, 3], :]
给出
TypeError: Performing basic indexing on a tensor and encountered an error indexing dim 0 with an object of type list. The only supported types are integers, slices, numpy scalars, or if indexing with a torch.LongTensor or torch.ByteTensor only a single Tensor may be passed.
虽然 numpy
returns 一个 (4, 2, 5)
张量成功。
对你有用吗?
import torch
a = torch.randn(3, 2, 4, 5)
print(a.size())
b = [a[2, :, 0, :], a[1, :, 1, :], a[2, :, 2, :], a[0, :, 3, :]]
b = torch.stack(b, 0)
print(b.size()) # torch.Size([4, 2, 5])