如何动态索引pytorch中的张量?
How to dynamically index the tensor in pytorch?
比如我得到张量:
tensor = torch.rand(12, 512, 768)
我得到了一个索引列表,说它是:
[0,2,3,400,5,32,7,8,321,107,100,511]
我希望 select 给定索引列表,维度 2 上的 512 个元素中的 1 个元素。然后张量的大小将变为(12, 1, 768)
。
有办法吗?
是的,可以直接使用索引切片,然后使用torch.unsqueeze()
将2D张量提升为3D:
# inputs
In [6]: tensor = torch.rand(12, 512, 768)
In [7]: idx_list = [0,2,3,400,5,32,7,8,321,107,100,511]
# slice using the index and then put a singleton dimension along axis 1
In [8]: for idx in idx_list:
...: sampled_tensor = torch.unsqueeze(tensor[:, idx, :], 1)
...: print(sampled_tensor.shape)
...:
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
或者,如果您想要更简洁的代码并且不想使用 torch.unsqueeze()
,则使用:
In [11]: for idx in idx_list:
...: sampled_tensor = tensor[:, [idx], :]
...: print(sampled_tensor.shape)
...:
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
注意:如果您只想对 idx_list
中的一个 idx
执行此切片,则无需使用 for
循环
还有一种方法只使用 PyTorch 并使用 indexing 和 torch.split
:
避免循环
tensor = torch.rand(12, 512, 768)
# create tensor with idx
idx_list = [0,2,3,400,5,32,7,8,321,107,100,511]
# convert list to tensor
idx_tensor = torch.tensor(idx_list)
# indexing and splitting
list_of_tensors = tensor[:, idx_tensor, :].split(1, dim=1)
当您调用 tensor[:, idx_tensor, :]
时,您将得到一个形状为:
(12, len_of_idx_list, 768)
的张量。
第二个维度取决于您的索引数量。
使用 torch.split
这个张量被分成一个 list 形状的张量:(12, 1, 768)
.
所以最后 list_of_tensors
包含形状的张量:
[torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768])]
您可以使用tensor.gather()
函数:
tensor = torch.rand(12, 512, 768)
ind = torch.tensor([0,2,3,400,5,32,7,8,321,107,100,511]).unsqueeze(1).unsqueeze(-1).expand(-1,-1,768) # shape (12,1,768)
tensor.gather(dim = 1, index = ind) # # shape (12,1,768)
比如我得到张量:
tensor = torch.rand(12, 512, 768)
我得到了一个索引列表,说它是:
[0,2,3,400,5,32,7,8,321,107,100,511]
我希望 select 给定索引列表,维度 2 上的 512 个元素中的 1 个元素。然后张量的大小将变为(12, 1, 768)
。
有办法吗?
是的,可以直接使用索引切片,然后使用torch.unsqueeze()
将2D张量提升为3D:
# inputs
In [6]: tensor = torch.rand(12, 512, 768)
In [7]: idx_list = [0,2,3,400,5,32,7,8,321,107,100,511]
# slice using the index and then put a singleton dimension along axis 1
In [8]: for idx in idx_list:
...: sampled_tensor = torch.unsqueeze(tensor[:, idx, :], 1)
...: print(sampled_tensor.shape)
...:
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
或者,如果您想要更简洁的代码并且不想使用 torch.unsqueeze()
,则使用:
In [11]: for idx in idx_list:
...: sampled_tensor = tensor[:, [idx], :]
...: print(sampled_tensor.shape)
...:
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
注意:如果您只想对 idx_list
idx
执行此切片,则无需使用 for
循环
还有一种方法只使用 PyTorch 并使用 indexing 和 torch.split
:
tensor = torch.rand(12, 512, 768)
# create tensor with idx
idx_list = [0,2,3,400,5,32,7,8,321,107,100,511]
# convert list to tensor
idx_tensor = torch.tensor(idx_list)
# indexing and splitting
list_of_tensors = tensor[:, idx_tensor, :].split(1, dim=1)
当您调用 tensor[:, idx_tensor, :]
时,您将得到一个形状为:
(12, len_of_idx_list, 768)
的张量。
第二个维度取决于您的索引数量。
使用 torch.split
这个张量被分成一个 list 形状的张量:(12, 1, 768)
.
所以最后 list_of_tensors
包含形状的张量:
[torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768])]
您可以使用tensor.gather()
函数:
tensor = torch.rand(12, 512, 768)
ind = torch.tensor([0,2,3,400,5,32,7,8,321,107,100,511]).unsqueeze(1).unsqueeze(-1).expand(-1,-1,768) # shape (12,1,768)
tensor.gather(dim = 1, index = ind) # # shape (12,1,768)