Pytorch:通过提取张量行进行一系列串联的计算和内存效率最高的方法?

Pytorch: Most computationally and memory efficient way to make a series of concatenations from extracting tensor rows?

说这是我的样本张量

sample = torch.tensor(
    [[2, 7, 3, 1, 1],
        [9, 5, 8, 2, 5],
        [0, 4, 0, 1, 4],
        [5, 4, 9, 0, 0]]
)

我想要一个新的张量,它将包含来自样本张量的 2 行的串联。

所以我有一个张量,其中包含我想为新张量连接成一行的行号对

cat_indices = torch.tensor([[0, 1], [1, 2], [0, 2], [2, 3]])

我目前使用的方法是这样的

torch.cat((sample[cat_indices[:,0]], sample[cat_indices[:,1]]), dim=1)

给出了想要的结果

tensor([[2, 7, 3, 1, 1, 9, 5, 8, 2, 5],
        [9, 5, 8, 2, 5, 0, 4, 0, 1, 4],
        [2, 7, 3, 1, 1, 0, 4, 0, 1, 4],
        [0, 4, 0, 1, 4, 5, 4, 9, 0, 0]])

这是执行此操作的最内存和计算效率最高的方法吗?我不确定,因为我正在对 cat_indices 进行两次调用,然后进行串联操作。

我觉得应该有一种方法可以通过某种视图来做到这一点。也许高级索引。我试过 sample[cat_indices[:,0], cat_indices[:,1]]sample[cat_indices[0], cat_indices[1]] 之类的东西,但我无法使视图正确显示。

你的应该是相当快的。另一种选择是

sample[cat_indices].reshape(cat_indices.shape[0],-1)

您必须对计算机的性能进行基准测试才能确定哪个更好。