Pytorch:通过提取张量行进行一系列串联的计算和内存效率最高的方法?
Pytorch: Most computationally and memory efficient way to make a series of concatenations from extracting tensor rows?
说这是我的样本张量
sample = torch.tensor(
[[2, 7, 3, 1, 1],
[9, 5, 8, 2, 5],
[0, 4, 0, 1, 4],
[5, 4, 9, 0, 0]]
)
我想要一个新的张量,它将包含来自样本张量的 2 行的串联。
所以我有一个张量,其中包含我想为新张量连接成一行的行号对
cat_indices = torch.tensor([[0, 1], [1, 2], [0, 2], [2, 3]])
我目前使用的方法是这样的
torch.cat((sample[cat_indices[:,0]], sample[cat_indices[:,1]]), dim=1)
给出了想要的结果
tensor([[2, 7, 3, 1, 1, 9, 5, 8, 2, 5],
[9, 5, 8, 2, 5, 0, 4, 0, 1, 4],
[2, 7, 3, 1, 1, 0, 4, 0, 1, 4],
[0, 4, 0, 1, 4, 5, 4, 9, 0, 0]])
这是执行此操作的最内存和计算效率最高的方法吗?我不确定,因为我正在对 cat_indices
进行两次调用,然后进行串联操作。
我觉得应该有一种方法可以通过某种视图来做到这一点。也许高级索引。我试过 sample[cat_indices[:,0], cat_indices[:,1]]
或 sample[cat_indices[0], cat_indices[1]]
之类的东西,但我无法使视图正确显示。
你的应该是相当快的。另一种选择是
sample[cat_indices].reshape(cat_indices.shape[0],-1)
您必须对计算机的性能进行基准测试才能确定哪个更好。
说这是我的样本张量
sample = torch.tensor(
[[2, 7, 3, 1, 1],
[9, 5, 8, 2, 5],
[0, 4, 0, 1, 4],
[5, 4, 9, 0, 0]]
)
我想要一个新的张量,它将包含来自样本张量的 2 行的串联。
所以我有一个张量,其中包含我想为新张量连接成一行的行号对
cat_indices = torch.tensor([[0, 1], [1, 2], [0, 2], [2, 3]])
我目前使用的方法是这样的
torch.cat((sample[cat_indices[:,0]], sample[cat_indices[:,1]]), dim=1)
给出了想要的结果
tensor([[2, 7, 3, 1, 1, 9, 5, 8, 2, 5],
[9, 5, 8, 2, 5, 0, 4, 0, 1, 4],
[2, 7, 3, 1, 1, 0, 4, 0, 1, 4],
[0, 4, 0, 1, 4, 5, 4, 9, 0, 0]])
这是执行此操作的最内存和计算效率最高的方法吗?我不确定,因为我正在对 cat_indices
进行两次调用,然后进行串联操作。
我觉得应该有一种方法可以通过某种视图来做到这一点。也许高级索引。我试过 sample[cat_indices[:,0], cat_indices[:,1]]
或 sample[cat_indices[0], cat_indices[1]]
之类的东西,但我无法使视图正确显示。
你的应该是相当快的。另一种选择是
sample[cat_indices].reshape(cat_indices.shape[0],-1)
您必须对计算机的性能进行基准测试才能确定哪个更好。