撤消张量 dstack 并按列重新堆叠

Undoing tensor dstack and restack column-wise

我有两个张量,ab:

import torch

a = torch.tensor(([1,2],
                  [3,4],
                  [5,6],
                  [7,8]))

b = torch.tensor(([0,0],
                  [1,1],
                  [2,2],
                  [3,3]))

我可以水平或深度堆叠。

d = torch.dstack([a, b])
h = torch.hstack([a, b])

现在,是否有任何 PyTorch 函数,最好是在一行中,我可以应用到 d 以获得 h?听起来我想撤消深度堆叠,然后按列重新堆叠它们。我试过重塑和展平,但都不起作用,因为它们都破坏了值的顺序。

在你的情况下使用 torch.unbind

import torch

a = torch.tensor(([1,2],
                  [3,4],
                  [5,6],
                  [7,8]))

b = torch.tensor(([0,0],
                  [1,1],
                  [2,2],
                  [3,3]))

d = torch.dstack([a, b])
h = torch.hstack(torch.unbind(d,2)) # get h from d