从 pytorch 张量中按索引删除一行

Delete a row by index from pytorch tensor

我有一个大小为 torch.Size([4, 3, 2])

的 pytorch 张量
tensor([[[0.4003, 0.2742],
     [0.9414, 0.1222],
     [0.9624, 0.3063]],

    [[0.9600, 0.5381],
     [0.5758, 0.8458],
     [0.6342, 0.5872]],

    [[0.5891, 0.9453],
     [0.8859, 0.6552],
     [0.5120, 0.5384]],

    [[0.3017, 0.9407],
     [0.4887, 0.8097],
     [0.9454, 0.6027]]])

我想删除第 2 行,使张量变为 torch.Size([3, 3, 2])

tensor([[[0.4003, 0.2742],
     [0.9414, 0.1222],
     [0.9624, 0.3063]],

    [[0.5891, 0.9453],
     [0.8859, 0.6552],
     [0.5120, 0.5384]],

    [[0.3017, 0.9407],
     [0.4887, 0.8097],
     [0.9454, 0.6027]]])

如何删除 3D 张量的第 n 行?

下面的操作选择了除一个“行”之外的所有行:

import torch

torch.manual_seed(2021)

row = 2
x = torch.rand((4, 3, 2))

new_x = x[torch.arange(1, x.shape[0]+1) != row, ...]

print(new_x.shape)
# >>> torch.Size([3, 3, 2])

print(x)
# > tensor([[[0.1304, 0.5134],
# >          [0.7426, 0.7159],
# >          [0.5705, 0.1653]],
# > 
# >         [[0.0443, 0.9628],
# >          [0.2943, 0.0992],
# >          [0.8096, 0.0169]],
# > 
# >         [[0.8222, 0.1242],
# >          [0.7489, 0.3608],
# >          [0.5131, 0.2959]],
# > 
# >         [[0.7834, 0.7405],
# >          [0.8050, 0.3036],
# >          [0.9942, 0.5025]]])

print(new_x)
# > tensor([[[0.1304, 0.5134],
# >          [0.7426, 0.7159],
# >          [0.5705, 0.1653]],
# > 
# >         [[0.8222, 0.1242],
# >          [0.7489, 0.3608],
# >          [0.5131, 0.2959]],
# > 
# >         [[0.7834, 0.7405],
# >          [0.8050, 0.3036],
# >          [0.9942, 0.5025]]])
import torch
x = torch.randn(size=(4,3,2))

row_exclude = 2
x = torch.cat((x[:row_exclude],x[row_exclude+1:]))

print(x.shape)
>>> torch.Size([3, 3, 2])

目前,我有这个缓慢的方法(它对我有用,因为我不经常调用这个函数)。

def delete_row_tensor(a, del_row, device):
    n = a.cpu().detach().numpy()
    n = np.delete(n, del_row, 0)
    n = torch.from_numpy(n).to(device)
    return n

我还在寻找高效的 torch 方法。