在pytorch中使用索引提取张量数据
Extracting tensor data with index in pytorch
我希望以某种方式对张量进行索引。
假设我的数据,X形张量(1, 3, 16, 9)
是
tensor([[[[ 0., 0., 0., 0., 1., 2., 0., 5., 6.],
[ 0., 0., 0., 1., 2., 3., 5., 6., 7.],
[ 0., 0., 0., 2., 3., 4., 6., 7., 8.],
[ 0., 0., 0., 3., 4., 0., 7., 8., 0.],
[ 0., 1., 2., 0., 5., 6., 0., 9., 10.],
[ 1., 2., 3., 5., 6., 7., 9., 10., 11.],
[ 2., 3., 4., 6., 7., 8., 10., 11., 12.],
[ 3., 4., 0., 7., 8., 0., 11., 12., 0.],
[ 0., 5., 6., 0., 9., 10., 0., 13., 14.],
[ 5., 6., 7., 9., 10., 11., 13., 14., 15.],
[ 6., 7., 8., 10., 11., 12., 14., 15., 16.],
[ 7., 8., 0., 11., 12., 0., 15., 16., 0.],
[ 0., 9., 10., 0., 13., 14., 0., 0., 0.],
[ 9., 10., 11., 13., 14., 15., 0., 0., 0.],
[10., 11., 12., 14., 15., 16., 0., 0., 0.],
[11., 12., 0., 15., 16., 0., 0., 0., 0.]],
[[ 0., 0., 0., 0., 17., 18., 0., 21., 22.],
[ 0., 0., 0., 17., 18., 19., 21., 22., 23.],
[ 0., 0., 0., 18., 19., 20., 22., 23., 24.],
[ 0., 0., 0., 19., 20., 0., 23., 24., 0.],
[ 0., 17., 18., 0., 21., 22., 0., 25., 26.],
[17., 18., 19., 21., 22., 23., 25., 26., 27.],
[18., 19., 20., 22., 23., 24., 26., 27., 28.],
[19., 20., 0., 23., 24., 0., 27., 28., 0.],
[ 0., 21., 22., 0., 25., 26., 0., 29., 30.],
[21., 22., 23., 25., 26., 27., 29., 30., 31.],
[22., 23., 24., 26., 27., 28., 30., 31., 32.],
[23., 24., 0., 27., 28., 0., 31., 32., 0.],
[ 0., 25., 26., 0., 29., 30., 0., 0., 0.],
[25., 26., 27., 29., 30., 31., 0., 0., 0.],
[26., 27., 28., 30., 31., 32., 0., 0., 0.],
[27., 28., 0., 31., 32., 0., 0., 0., 0.]],
[[ 0., 0., 0., 0., 33., 34., 0., 37., 38.],
[ 0., 0., 0., 33., 34., 35., 37., 38., 39.],
[ 0., 0., 0., 34., 35., 36., 38., 39., 40.],
[ 0., 0., 0., 35., 36., 0., 39., 40., 0.],
[ 0., 33., 34., 0., 37., 38., 0., 41., 42.],
[33., 34., 35., 37., 38., 39., 41., 42., 43.],
[34., 35., 36., 38., 39., 40., 42., 43., 44.],
[35., 36., 0., 39., 40., 0., 43., 44., 0.],
[ 0., 37., 38., 0., 41., 42., 0., 45., 46.],
[37., 38., 39., 41., 42., 43., 45., 46., 47.],
[38., 39., 40., 42., 43., 44., 46., 47., 48.],
[39., 40., 0., 43., 44., 0., 47., 48., 0.],
[ 0., 41., 42., 0., 45., 46., 0., 0., 0.],
[41., 42., 43., 45., 46., 47., 0., 0., 0.],
[42., 43., 44., 46., 47., 48., 0., 0., 0.],
[43., 44., 0., 47., 48., 0., 0., 0., 0.]]]]
我想要将 (row_index % n) == i
(比如 n = 4
和 i = 0 to 3
)保存在另一个张量 Y
.
中的那些行
例如,对于数据X[0][0]
:
[[ 0., 0., 0., 0., 1., 2., 0., 5., 6.],
[ 0., 0., 0., 1., 2., 3., 5., 6., 7.],
[ 0., 0., 0., 2., 3., 4., 6., 7., 8.],
[ 0., 0., 0., 3., 4., 0., 7., 8., 0.],
[ 0., 1., 2., 0., 5., 6., 0., 9., 10.],
[ 1., 2., 3., 5., 6., 7., 9., 10., 11.],
[ 2., 3., 4., 6., 7., 8., 10., 11., 12.],
[ 3., 4., 0., 7., 8., 0., 11., 12., 0.],
[ 0., 5., 6., 0., 9., 10., 0., 13., 14.],
[ 5., 6., 7., 9., 10., 11., 13., 14., 15.],
[ 6., 7., 8., 10., 11., 12., 14., 15., 16.],
[ 7., 8., 0., 11., 12., 0., 15., 16., 0.],
[ 0., 9., 10., 0., 13., 14., 0., 0., 0.],
[ 9., 10., 11., 13., 14., 15., 0., 0., 0.],
[10., 11., 12., 14., 15., 16., 0., 0., 0.],
[11., 12., 0., 15., 16., 0., 0., 0., 0.]]
我想要一个包含以下数据的张量,它基本上是 row_index % 4 == 0
(此处 i = 0
)所在行的集合:
[[ 0., 0., 0., 0., 1., 2., 0., 5., 6.],
[ 0., 1., 2., 0., 5., 6., 0., 9., 10.],
[ 0., 5., 6., 0., 9., 10., 0., 13., 14.],
[ 0., 9., 10., 0., 13., 14., 0., 0., 0.]]
同样,i = 1
、row_index % 4 == i
看起来像:
[[ 0., 0., 0., 1., 2., 3., 5., 6., 7.],
[ 1., 2., 3., 5., 6., 7., 9., 10., 11.],
[ 5., 6., 7., 9., 10., 11., 13., 14., 15.],
[ 9., 10., 11., 13., 14., 15., 0., 0., 0.]]
当 i = 2
、row_index % 4 == i
:
[[ 0., 0., 0., 2., 3., 4., 6., 7., 8.],
[ 2., 3., 4., 6., 7., 8., 10., 11., 12.],
[ 6., 7., 8., 10., 11., 12., 14., 15., 16.],
[10., 11., 12., 14., 15., 16., 0., 0., 0.]]
当 i = 3
、row_index % 4 == i
:
[[ 0., 0., 0., 3., 4., 0., 7., 8., 0.],
[ 3., 4., 0., 7., 8., 0., 11., 12., 0.],
[ 7., 8., 0., 11., 12., 0., 15., 16., 0.],
[11., 12., 0., 15., 16., 0., 0., 0., 0.]]
我已经尝试对其进行硬编码,但当数据变大并且大小变得动态时,它似乎并不实用,我认为会有更好的方法来实现它。
temp0 = data[0][0][0][:]
temp1 = data[0][0][4][:]
temp2 = data[0][0][8][:]
temp3 = data[0][0][12][:]
temp = torch.stack([temp0,temp1,temp2,temp3],dim = 0)
另外,如果结果能以一个张量返回就好了:
tensor Y = ([[[ 0., 0., 0., 0., 1., 2., 0., 5., 6.],
[ 0., 1., 2., 0., 5., 6., 0., 9., 10.],
[ 0., 5., 6., 0., 9., 10., 0., 13., 14.],
[ 0., 9., 10., 0., 13., 14., 0., 0., 0.]],
[[ 0., 0., 0., 1., 2., 3., 5., 6., 7.],
[ 1., 2., 3., 5., 6., 7., 9., 10., 11.],
[ 5., 6., 7., 9., 10., 11., 13., 14., 15.],
[ 9., 10., 11., 13., 14., 15., 0., 0., 0.]],
[[ 0., 0., 0., 2., 3., 4., 6., 7., 8.],
[ 2., 3., 4., 6., 7., 8., 10., 11., 12.],
[ 6., 7., 8., 10., 11., 12., 14., 15., 16.],
[10., 11., 12., 14., 15., 16., 0., 0., 0.]],
[[ 0., 0., 0., 3., 4., 0., 7., 8., 0.],
[ 3., 4., 0., 7., 8., 0., 11., 12., 0.],
[ 7., 8., 0., 11., 12., 0., 15., 16., 0.],
[11., 12., 0., 15., 16., 0., 0., 0., 0.]]])
首先,要获得每个爱国,你可以试试这个:
import torch
data = torch.tensor([[[[0., 0., 0., 0., 1., 2., 0., 5., 6.],
[0., 0., 0., 1., 2., 3., 5., 6., 7.],
[0., 0., 0., 2., 3., 4., 6., 7., 8.],
[0., 0., 0., 3., 4., 0., 7., 8., 0.],
[0., 1., 2., 0., 5., 6., 0., 9., 10.],
[1., 2., 3., 5., 6., 7., 9., 10., 11.],
[2., 3., 4., 6., 7., 8., 10., 11., 12.],
[3., 4., 0., 7., 8., 0., 11., 12., 0.],
[0., 5., 6., 0., 9., 10., 0., 13., 14.],
[5., 6., 7., 9., 10., 11., 13., 14., 15.],
[6., 7., 8., 10., 11., 12., 14., 15., 16.],
[7., 8., 0., 11., 12., 0., 15., 16., 0.],
[0., 9., 10., 0., 13., 14., 0., 0., 0.],
[9., 10., 11., 13., 14., 15., 0., 0., 0.],
[10., 11., 12., 14., 15., 16., 0., 0., 0.],
[11., 12., 0., 15., 16., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 17., 18., 0., 21., 22.],
[0., 0., 0., 17., 18., 19., 21., 22., 23.],
[0., 0., 0., 18., 19., 20., 22., 23., 24.],
[0., 0., 0., 19., 20., 0., 23., 24., 0.],
[0., 17., 18., 0., 21., 22., 0., 25., 26.],
[17., 18., 19., 21., 22., 23., 25., 26., 27.],
[18., 19., 20., 22., 23., 24., 26., 27., 28.],
[19., 20., 0., 23., 24., 0., 27., 28., 0.],
[0., 21., 22., 0., 25., 26., 0., 29., 30.],
[21., 22., 23., 25., 26., 27., 29., 30., 31.],
[22., 23., 24., 26., 27., 28., 30., 31., 32.],
[23., 24., 0., 27., 28., 0., 31., 32., 0.],
[0., 25., 26., 0., 29., 30., 0., 0., 0.],
[25., 26., 27., 29., 30., 31., 0., 0., 0.],
[26., 27., 28., 30., 31., 32., 0., 0., 0.],
[27., 28., 0., 31., 32., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 33., 34., 0., 37., 38.],
[0., 0., 0., 33., 34., 35., 37., 38., 39.],
[0., 0., 0., 34., 35., 36., 38., 39., 40.],
[0., 0., 0., 35., 36., 0., 39., 40., 0.],
[0., 33., 34., 0., 37., 38., 0., 41., 42.],
[33., 34., 35., 37., 38., 39., 41., 42., 43.],
[34., 35., 36., 38., 39., 40., 42., 43., 44.],
[35., 36., 0., 39., 40., 0., 43., 44., 0.],
[0., 37., 38., 0., 41., 42., 0., 45., 46.],
[37., 38., 39., 41., 42., 43., 45., 46., 47.],
[38., 39., 40., 42., 43., 44., 46., 47., 48.],
[39., 40., 0., 43., 44., 0., 47., 48., 0.],
[0., 41., 42., 0., 45., 46., 0., 0., 0.],
[41., 42., 43., 45., 46., 47., 0., 0., 0.],
[42., 43., 44., 46., 47., 48., 0., 0., 0.],
[43., 44., 0., 47., 48., 0., 0., 0., 0.]]]])
print(data.shape)
n, i = 4, 0
indices = [index for index in range(data.shape[2]) if index % n == i]
print(data[0, 0, indices])
对于这些张量的组合,您可以尝试使用:
n = 4
result = []
for i in range(n):
indices = [index for index in range(data.shape[2]) if index % n == i]
result.append(data[0, 0, indices])
final = torch.stack(result, dim=0)
您可以通过首先构建包含所选行的张量,然后使用 torch.gather
到 assemble 最终张量来实现此目的。
假设我们两个 lists I
和 N
分别包含 i
和 n
的值:
I = [0, 1, 2, 3]
N = [4, 4, 4, 4]
首先我们构建索引张量:
>>> index = torch.stack([(torch.arange(16) % n == i).nonzero() for i, n in zip(I, N)])
tensor([[[ 0],
[ 4],
[ 8],
[12]],
[[ 1],
[ 5],
[ 9],
[13]],
[[ 2],
[ 6],
[10],
[14]],
[[ 3],
[ 7],
[11],
[15]]])
然后需要进行一些扩展和重塑:
>>> index_ = index[None].flatten(1,2).expand(X.size(0), -1, X.size(-1))
tensor([[[ 0, 0, 0, 0, 0, 0, 0, 0, 0],
[ 4, 4, 4, 4, 4, 4, 4, 4, 4],
[ 8, 8, 8, 8, 8, 8, 8, 8, 8],
[12, 12, 12, 12, 12, 12, 12, 12, 12],
[ 1, 1, 1, 1, 1, 1, 1, 1, 1],
[ 5, 5, 5, 5, 5, 5, 5, 5, 5],
[ 9, 9, 9, 9, 9, 9, 9, 9, 9],
[13, 13, 13, 13, 13, 13, 13, 13, 13],
[ 2, 2, 2, 2, 2, 2, 2, 2, 2],
[ 6, 6, 6, 6, 6, 6, 6, 6, 6],
[10, 10, 10, 10, 10, 10, 10, 10, 10],
[14, 14, 14, 14, 14, 14, 14, 14, 14],
[ 3, 3, 3, 3, 3, 3, 3, 3, 3],
[ 7, 7, 7, 7, 7, 7, 7, 7, 7],
[11, 11, 11, 11, 11, 11, 11, 11, 11],
[15, 15, 15, 15, 15, 15, 15, 15, 15]]])
根据经验,我们希望 index_
具有与 X
相同的维数。
现在我们可以应用 torch.gather
并重塑最终形式:
>>> X.gather(1, index_).reshape(len(X), *index.shape[:2], -1)
tensor([[[[ 0., 0., 0., 0., 1., 2., 0., 5., 6.],
[ 0., 1., 2., 0., 5., 6., 0., 9., 10.],
[ 0., 5., 6., 0., 9., 10., 0., 13., 14.],
[ 0., 9., 10., 0., 13., 14., 0., 0., 0.]],
[[ 0., 0., 0., 1., 2., 3., 5., 6., 7.],
[ 1., 2., 3., 5., 6., 7., 9., 10., 11.],
[ 5., 6., 7., 9., 10., 11., 13., 14., 15.],
[ 9., 10., 11., 13., 14., 15., 0., 0., 0.]],
[[ 0., 0., 0., 2., 3., 4., 6., 7., 8.],
[ 2., 3., 4., 6., 7., 8., 10., 11., 12.],
[ 6., 7., 8., 10., 11., 12., 14., 15., 16.],
[10., 11., 12., 14., 15., 16., 0., 0., 0.]],
[[ 0., 0., 0., 3., 4., 0., 7., 8., 0.],
[ 3., 4., 0., 7., 8., 0., 11., 12., 0.],
[ 7., 8., 0., 11., 12., 0., 15., 16., 0.],
[11., 12., 0., 15., 16., 0., 0., 0., 0.]]]])
此方法可以扩展到批量张量:
>>> index = torch.stack([(torch.arange(16) % n == i).nonzero() for i, n in zip(I, N)])
>>> index_ = index[None,None].flatten(2,3).expand(X.size(0), X.size(1), -1, X.size(-1))
>>> X.gather(2, index_).reshape(*X.shape[:2], *index.shape[:2], -1)
我希望以某种方式对张量进行索引。
假设我的数据,X形张量(1, 3, 16, 9)
是
tensor([[[[ 0., 0., 0., 0., 1., 2., 0., 5., 6.],
[ 0., 0., 0., 1., 2., 3., 5., 6., 7.],
[ 0., 0., 0., 2., 3., 4., 6., 7., 8.],
[ 0., 0., 0., 3., 4., 0., 7., 8., 0.],
[ 0., 1., 2., 0., 5., 6., 0., 9., 10.],
[ 1., 2., 3., 5., 6., 7., 9., 10., 11.],
[ 2., 3., 4., 6., 7., 8., 10., 11., 12.],
[ 3., 4., 0., 7., 8., 0., 11., 12., 0.],
[ 0., 5., 6., 0., 9., 10., 0., 13., 14.],
[ 5., 6., 7., 9., 10., 11., 13., 14., 15.],
[ 6., 7., 8., 10., 11., 12., 14., 15., 16.],
[ 7., 8., 0., 11., 12., 0., 15., 16., 0.],
[ 0., 9., 10., 0., 13., 14., 0., 0., 0.],
[ 9., 10., 11., 13., 14., 15., 0., 0., 0.],
[10., 11., 12., 14., 15., 16., 0., 0., 0.],
[11., 12., 0., 15., 16., 0., 0., 0., 0.]],
[[ 0., 0., 0., 0., 17., 18., 0., 21., 22.],
[ 0., 0., 0., 17., 18., 19., 21., 22., 23.],
[ 0., 0., 0., 18., 19., 20., 22., 23., 24.],
[ 0., 0., 0., 19., 20., 0., 23., 24., 0.],
[ 0., 17., 18., 0., 21., 22., 0., 25., 26.],
[17., 18., 19., 21., 22., 23., 25., 26., 27.],
[18., 19., 20., 22., 23., 24., 26., 27., 28.],
[19., 20., 0., 23., 24., 0., 27., 28., 0.],
[ 0., 21., 22., 0., 25., 26., 0., 29., 30.],
[21., 22., 23., 25., 26., 27., 29., 30., 31.],
[22., 23., 24., 26., 27., 28., 30., 31., 32.],
[23., 24., 0., 27., 28., 0., 31., 32., 0.],
[ 0., 25., 26., 0., 29., 30., 0., 0., 0.],
[25., 26., 27., 29., 30., 31., 0., 0., 0.],
[26., 27., 28., 30., 31., 32., 0., 0., 0.],
[27., 28., 0., 31., 32., 0., 0., 0., 0.]],
[[ 0., 0., 0., 0., 33., 34., 0., 37., 38.],
[ 0., 0., 0., 33., 34., 35., 37., 38., 39.],
[ 0., 0., 0., 34., 35., 36., 38., 39., 40.],
[ 0., 0., 0., 35., 36., 0., 39., 40., 0.],
[ 0., 33., 34., 0., 37., 38., 0., 41., 42.],
[33., 34., 35., 37., 38., 39., 41., 42., 43.],
[34., 35., 36., 38., 39., 40., 42., 43., 44.],
[35., 36., 0., 39., 40., 0., 43., 44., 0.],
[ 0., 37., 38., 0., 41., 42., 0., 45., 46.],
[37., 38., 39., 41., 42., 43., 45., 46., 47.],
[38., 39., 40., 42., 43., 44., 46., 47., 48.],
[39., 40., 0., 43., 44., 0., 47., 48., 0.],
[ 0., 41., 42., 0., 45., 46., 0., 0., 0.],
[41., 42., 43., 45., 46., 47., 0., 0., 0.],
[42., 43., 44., 46., 47., 48., 0., 0., 0.],
[43., 44., 0., 47., 48., 0., 0., 0., 0.]]]]
我想要将 (row_index % n) == i
(比如 n = 4
和 i = 0 to 3
)保存在另一个张量 Y
.
例如,对于数据X[0][0]
:
[[ 0., 0., 0., 0., 1., 2., 0., 5., 6.],
[ 0., 0., 0., 1., 2., 3., 5., 6., 7.],
[ 0., 0., 0., 2., 3., 4., 6., 7., 8.],
[ 0., 0., 0., 3., 4., 0., 7., 8., 0.],
[ 0., 1., 2., 0., 5., 6., 0., 9., 10.],
[ 1., 2., 3., 5., 6., 7., 9., 10., 11.],
[ 2., 3., 4., 6., 7., 8., 10., 11., 12.],
[ 3., 4., 0., 7., 8., 0., 11., 12., 0.],
[ 0., 5., 6., 0., 9., 10., 0., 13., 14.],
[ 5., 6., 7., 9., 10., 11., 13., 14., 15.],
[ 6., 7., 8., 10., 11., 12., 14., 15., 16.],
[ 7., 8., 0., 11., 12., 0., 15., 16., 0.],
[ 0., 9., 10., 0., 13., 14., 0., 0., 0.],
[ 9., 10., 11., 13., 14., 15., 0., 0., 0.],
[10., 11., 12., 14., 15., 16., 0., 0., 0.],
[11., 12., 0., 15., 16., 0., 0., 0., 0.]]
我想要一个包含以下数据的张量,它基本上是 row_index % 4 == 0
(此处 i = 0
)所在行的集合:
[[ 0., 0., 0., 0., 1., 2., 0., 5., 6.],
[ 0., 1., 2., 0., 5., 6., 0., 9., 10.],
[ 0., 5., 6., 0., 9., 10., 0., 13., 14.],
[ 0., 9., 10., 0., 13., 14., 0., 0., 0.]]
同样,i = 1
、row_index % 4 == i
看起来像:
[[ 0., 0., 0., 1., 2., 3., 5., 6., 7.],
[ 1., 2., 3., 5., 6., 7., 9., 10., 11.],
[ 5., 6., 7., 9., 10., 11., 13., 14., 15.],
[ 9., 10., 11., 13., 14., 15., 0., 0., 0.]]
当 i = 2
、row_index % 4 == i
:
[[ 0., 0., 0., 2., 3., 4., 6., 7., 8.],
[ 2., 3., 4., 6., 7., 8., 10., 11., 12.],
[ 6., 7., 8., 10., 11., 12., 14., 15., 16.],
[10., 11., 12., 14., 15., 16., 0., 0., 0.]]
当 i = 3
、row_index % 4 == i
:
[[ 0., 0., 0., 3., 4., 0., 7., 8., 0.],
[ 3., 4., 0., 7., 8., 0., 11., 12., 0.],
[ 7., 8., 0., 11., 12., 0., 15., 16., 0.],
[11., 12., 0., 15., 16., 0., 0., 0., 0.]]
我已经尝试对其进行硬编码,但当数据变大并且大小变得动态时,它似乎并不实用,我认为会有更好的方法来实现它。
temp0 = data[0][0][0][:]
temp1 = data[0][0][4][:]
temp2 = data[0][0][8][:]
temp3 = data[0][0][12][:]
temp = torch.stack([temp0,temp1,temp2,temp3],dim = 0)
另外,如果结果能以一个张量返回就好了:
tensor Y = ([[[ 0., 0., 0., 0., 1., 2., 0., 5., 6.],
[ 0., 1., 2., 0., 5., 6., 0., 9., 10.],
[ 0., 5., 6., 0., 9., 10., 0., 13., 14.],
[ 0., 9., 10., 0., 13., 14., 0., 0., 0.]],
[[ 0., 0., 0., 1., 2., 3., 5., 6., 7.],
[ 1., 2., 3., 5., 6., 7., 9., 10., 11.],
[ 5., 6., 7., 9., 10., 11., 13., 14., 15.],
[ 9., 10., 11., 13., 14., 15., 0., 0., 0.]],
[[ 0., 0., 0., 2., 3., 4., 6., 7., 8.],
[ 2., 3., 4., 6., 7., 8., 10., 11., 12.],
[ 6., 7., 8., 10., 11., 12., 14., 15., 16.],
[10., 11., 12., 14., 15., 16., 0., 0., 0.]],
[[ 0., 0., 0., 3., 4., 0., 7., 8., 0.],
[ 3., 4., 0., 7., 8., 0., 11., 12., 0.],
[ 7., 8., 0., 11., 12., 0., 15., 16., 0.],
[11., 12., 0., 15., 16., 0., 0., 0., 0.]]])
首先,要获得每个爱国,你可以试试这个:
import torch
data = torch.tensor([[[[0., 0., 0., 0., 1., 2., 0., 5., 6.],
[0., 0., 0., 1., 2., 3., 5., 6., 7.],
[0., 0., 0., 2., 3., 4., 6., 7., 8.],
[0., 0., 0., 3., 4., 0., 7., 8., 0.],
[0., 1., 2., 0., 5., 6., 0., 9., 10.],
[1., 2., 3., 5., 6., 7., 9., 10., 11.],
[2., 3., 4., 6., 7., 8., 10., 11., 12.],
[3., 4., 0., 7., 8., 0., 11., 12., 0.],
[0., 5., 6., 0., 9., 10., 0., 13., 14.],
[5., 6., 7., 9., 10., 11., 13., 14., 15.],
[6., 7., 8., 10., 11., 12., 14., 15., 16.],
[7., 8., 0., 11., 12., 0., 15., 16., 0.],
[0., 9., 10., 0., 13., 14., 0., 0., 0.],
[9., 10., 11., 13., 14., 15., 0., 0., 0.],
[10., 11., 12., 14., 15., 16., 0., 0., 0.],
[11., 12., 0., 15., 16., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 17., 18., 0., 21., 22.],
[0., 0., 0., 17., 18., 19., 21., 22., 23.],
[0., 0., 0., 18., 19., 20., 22., 23., 24.],
[0., 0., 0., 19., 20., 0., 23., 24., 0.],
[0., 17., 18., 0., 21., 22., 0., 25., 26.],
[17., 18., 19., 21., 22., 23., 25., 26., 27.],
[18., 19., 20., 22., 23., 24., 26., 27., 28.],
[19., 20., 0., 23., 24., 0., 27., 28., 0.],
[0., 21., 22., 0., 25., 26., 0., 29., 30.],
[21., 22., 23., 25., 26., 27., 29., 30., 31.],
[22., 23., 24., 26., 27., 28., 30., 31., 32.],
[23., 24., 0., 27., 28., 0., 31., 32., 0.],
[0., 25., 26., 0., 29., 30., 0., 0., 0.],
[25., 26., 27., 29., 30., 31., 0., 0., 0.],
[26., 27., 28., 30., 31., 32., 0., 0., 0.],
[27., 28., 0., 31., 32., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 33., 34., 0., 37., 38.],
[0., 0., 0., 33., 34., 35., 37., 38., 39.],
[0., 0., 0., 34., 35., 36., 38., 39., 40.],
[0., 0., 0., 35., 36., 0., 39., 40., 0.],
[0., 33., 34., 0., 37., 38., 0., 41., 42.],
[33., 34., 35., 37., 38., 39., 41., 42., 43.],
[34., 35., 36., 38., 39., 40., 42., 43., 44.],
[35., 36., 0., 39., 40., 0., 43., 44., 0.],
[0., 37., 38., 0., 41., 42., 0., 45., 46.],
[37., 38., 39., 41., 42., 43., 45., 46., 47.],
[38., 39., 40., 42., 43., 44., 46., 47., 48.],
[39., 40., 0., 43., 44., 0., 47., 48., 0.],
[0., 41., 42., 0., 45., 46., 0., 0., 0.],
[41., 42., 43., 45., 46., 47., 0., 0., 0.],
[42., 43., 44., 46., 47., 48., 0., 0., 0.],
[43., 44., 0., 47., 48., 0., 0., 0., 0.]]]])
print(data.shape)
n, i = 4, 0
indices = [index for index in range(data.shape[2]) if index % n == i]
print(data[0, 0, indices])
对于这些张量的组合,您可以尝试使用:
n = 4
result = []
for i in range(n):
indices = [index for index in range(data.shape[2]) if index % n == i]
result.append(data[0, 0, indices])
final = torch.stack(result, dim=0)
您可以通过首先构建包含所选行的张量,然后使用 torch.gather
到 assemble 最终张量来实现此目的。
假设我们两个 lists I
和 N
分别包含 i
和 n
的值:
I = [0, 1, 2, 3]
N = [4, 4, 4, 4]
首先我们构建索引张量:
>>> index = torch.stack([(torch.arange(16) % n == i).nonzero() for i, n in zip(I, N)])
tensor([[[ 0],
[ 4],
[ 8],
[12]],
[[ 1],
[ 5],
[ 9],
[13]],
[[ 2],
[ 6],
[10],
[14]],
[[ 3],
[ 7],
[11],
[15]]])
然后需要进行一些扩展和重塑:
>>> index_ = index[None].flatten(1,2).expand(X.size(0), -1, X.size(-1))
tensor([[[ 0, 0, 0, 0, 0, 0, 0, 0, 0],
[ 4, 4, 4, 4, 4, 4, 4, 4, 4],
[ 8, 8, 8, 8, 8, 8, 8, 8, 8],
[12, 12, 12, 12, 12, 12, 12, 12, 12],
[ 1, 1, 1, 1, 1, 1, 1, 1, 1],
[ 5, 5, 5, 5, 5, 5, 5, 5, 5],
[ 9, 9, 9, 9, 9, 9, 9, 9, 9],
[13, 13, 13, 13, 13, 13, 13, 13, 13],
[ 2, 2, 2, 2, 2, 2, 2, 2, 2],
[ 6, 6, 6, 6, 6, 6, 6, 6, 6],
[10, 10, 10, 10, 10, 10, 10, 10, 10],
[14, 14, 14, 14, 14, 14, 14, 14, 14],
[ 3, 3, 3, 3, 3, 3, 3, 3, 3],
[ 7, 7, 7, 7, 7, 7, 7, 7, 7],
[11, 11, 11, 11, 11, 11, 11, 11, 11],
[15, 15, 15, 15, 15, 15, 15, 15, 15]]])
根据经验,我们希望 index_
具有与 X
相同的维数。
现在我们可以应用 torch.gather
并重塑最终形式:
>>> X.gather(1, index_).reshape(len(X), *index.shape[:2], -1)
tensor([[[[ 0., 0., 0., 0., 1., 2., 0., 5., 6.],
[ 0., 1., 2., 0., 5., 6., 0., 9., 10.],
[ 0., 5., 6., 0., 9., 10., 0., 13., 14.],
[ 0., 9., 10., 0., 13., 14., 0., 0., 0.]],
[[ 0., 0., 0., 1., 2., 3., 5., 6., 7.],
[ 1., 2., 3., 5., 6., 7., 9., 10., 11.],
[ 5., 6., 7., 9., 10., 11., 13., 14., 15.],
[ 9., 10., 11., 13., 14., 15., 0., 0., 0.]],
[[ 0., 0., 0., 2., 3., 4., 6., 7., 8.],
[ 2., 3., 4., 6., 7., 8., 10., 11., 12.],
[ 6., 7., 8., 10., 11., 12., 14., 15., 16.],
[10., 11., 12., 14., 15., 16., 0., 0., 0.]],
[[ 0., 0., 0., 3., 4., 0., 7., 8., 0.],
[ 3., 4., 0., 7., 8., 0., 11., 12., 0.],
[ 7., 8., 0., 11., 12., 0., 15., 16., 0.],
[11., 12., 0., 15., 16., 0., 0., 0., 0.]]]])
此方法可以扩展到批量张量:
>>> index = torch.stack([(torch.arange(16) % n == i).nonzero() for i, n in zip(I, N)])
>>> index_ = index[None,None].flatten(2,3).expand(X.size(0), X.size(1), -1, X.size(-1))
>>> X.gather(2, index_).reshape(*X.shape[:2], *index.shape[:2], -1)