根据给定维度从 PyTorch 张量中拆分和提取值

Split and extract values from a PyTorch tensor according to given dimensions

我有一个 A 大小 torch.Size([32, 32, 3, 3]) 的张量,我想拆分它并从中提取一个大小 torch.Size([16, 16, 3, 3]) 的张量 B。张量可以是 1d 或 4d,并且必须根据给定的新张量维度进行拆分。我已经能够生成目标维度,但我无法从源张量中拆分和提取值。我试过 torch.narrow 但它只需要 3 个参数,在很多情况下我需要 4 个。 torch.split 将 dim 作为 int,因为张量仅沿一维分裂。但我想沿多个维度拆分它。

您有多种选择:

  • 多次使用.split
  • 多次使用.narrow
  • 使用切片

例如:

t = torch.rand(32, 32, 3, 3)

t0, t1 = t.split((16, 16), 0)
print(t0.shape, t1.shape)
>>> torch.Size([16, 32, 3, 3]) torch.Size([16, 32, 3, 3])

t00, t01 = t0.split((16, 16), 1)
print(t00.shape, t01.shape)
>>> torch.Size([16, 16, 3, 3]) torch.Size([16, 16, 3, 3])

t00_alt, t01_alt = t[:16, :16, :, :], t[16:, 16:, :, :]
print(t00_alt.shape, t01_alt.shape)
>>> torch.Size([16, 16, 3, 3]) torch.Size([16, 16, 3, 3])