从 pytorch 张量中按索引删除一行
Delete a row by index from pytorch tensor
我有一个大小为 torch.Size([4, 3, 2])
的 pytorch 张量
tensor([[[0.4003, 0.2742],
[0.9414, 0.1222],
[0.9624, 0.3063]],
[[0.9600, 0.5381],
[0.5758, 0.8458],
[0.6342, 0.5872]],
[[0.5891, 0.9453],
[0.8859, 0.6552],
[0.5120, 0.5384]],
[[0.3017, 0.9407],
[0.4887, 0.8097],
[0.9454, 0.6027]]])
我想删除第 2 行,使张量变为 torch.Size([3, 3, 2])
tensor([[[0.4003, 0.2742],
[0.9414, 0.1222],
[0.9624, 0.3063]],
[[0.5891, 0.9453],
[0.8859, 0.6552],
[0.5120, 0.5384]],
[[0.3017, 0.9407],
[0.4887, 0.8097],
[0.9454, 0.6027]]])
如何删除 3D 张量的第 n 行?
下面的操作选择了除一个“行”之外的所有行:
import torch
torch.manual_seed(2021)
row = 2
x = torch.rand((4, 3, 2))
new_x = x[torch.arange(1, x.shape[0]+1) != row, ...]
print(new_x.shape)
# >>> torch.Size([3, 3, 2])
print(x)
# > tensor([[[0.1304, 0.5134],
# > [0.7426, 0.7159],
# > [0.5705, 0.1653]],
# >
# > [[0.0443, 0.9628],
# > [0.2943, 0.0992],
# > [0.8096, 0.0169]],
# >
# > [[0.8222, 0.1242],
# > [0.7489, 0.3608],
# > [0.5131, 0.2959]],
# >
# > [[0.7834, 0.7405],
# > [0.8050, 0.3036],
# > [0.9942, 0.5025]]])
print(new_x)
# > tensor([[[0.1304, 0.5134],
# > [0.7426, 0.7159],
# > [0.5705, 0.1653]],
# >
# > [[0.8222, 0.1242],
# > [0.7489, 0.3608],
# > [0.5131, 0.2959]],
# >
# > [[0.7834, 0.7405],
# > [0.8050, 0.3036],
# > [0.9942, 0.5025]]])
import torch
x = torch.randn(size=(4,3,2))
row_exclude = 2
x = torch.cat((x[:row_exclude],x[row_exclude+1:]))
print(x.shape)
>>> torch.Size([3, 3, 2])
目前,我有这个缓慢的方法(它对我有用,因为我不经常调用这个函数)。
def delete_row_tensor(a, del_row, device):
n = a.cpu().detach().numpy()
n = np.delete(n, del_row, 0)
n = torch.from_numpy(n).to(device)
return n
我还在寻找高效的 torch 方法。
我有一个大小为 torch.Size([4, 3, 2])
tensor([[[0.4003, 0.2742],
[0.9414, 0.1222],
[0.9624, 0.3063]],
[[0.9600, 0.5381],
[0.5758, 0.8458],
[0.6342, 0.5872]],
[[0.5891, 0.9453],
[0.8859, 0.6552],
[0.5120, 0.5384]],
[[0.3017, 0.9407],
[0.4887, 0.8097],
[0.9454, 0.6027]]])
我想删除第 2 行,使张量变为 torch.Size([3, 3, 2])
tensor([[[0.4003, 0.2742],
[0.9414, 0.1222],
[0.9624, 0.3063]],
[[0.5891, 0.9453],
[0.8859, 0.6552],
[0.5120, 0.5384]],
[[0.3017, 0.9407],
[0.4887, 0.8097],
[0.9454, 0.6027]]])
如何删除 3D 张量的第 n 行?
下面的操作选择了除一个“行”之外的所有行:
import torch
torch.manual_seed(2021)
row = 2
x = torch.rand((4, 3, 2))
new_x = x[torch.arange(1, x.shape[0]+1) != row, ...]
print(new_x.shape)
# >>> torch.Size([3, 3, 2])
print(x)
# > tensor([[[0.1304, 0.5134],
# > [0.7426, 0.7159],
# > [0.5705, 0.1653]],
# >
# > [[0.0443, 0.9628],
# > [0.2943, 0.0992],
# > [0.8096, 0.0169]],
# >
# > [[0.8222, 0.1242],
# > [0.7489, 0.3608],
# > [0.5131, 0.2959]],
# >
# > [[0.7834, 0.7405],
# > [0.8050, 0.3036],
# > [0.9942, 0.5025]]])
print(new_x)
# > tensor([[[0.1304, 0.5134],
# > [0.7426, 0.7159],
# > [0.5705, 0.1653]],
# >
# > [[0.8222, 0.1242],
# > [0.7489, 0.3608],
# > [0.5131, 0.2959]],
# >
# > [[0.7834, 0.7405],
# > [0.8050, 0.3036],
# > [0.9942, 0.5025]]])
import torch
x = torch.randn(size=(4,3,2))
row_exclude = 2
x = torch.cat((x[:row_exclude],x[row_exclude+1:]))
print(x.shape)
>>> torch.Size([3, 3, 2])
目前,我有这个缓慢的方法(它对我有用,因为我不经常调用这个函数)。
def delete_row_tensor(a, del_row, device):
n = a.cpu().detach().numpy()
n = np.delete(n, del_row, 0)
n = torch.from_numpy(n).to(device)
return n
我还在寻找高效的 torch 方法。