根据给定维度从 PyTorch 张量中拆分和提取值
Split and extract values from a PyTorch tensor according to given dimensions
我有一个 A
大小 torch.Size([32, 32, 3, 3])
的张量,我想拆分它并从中提取一个大小 torch.Size([16, 16, 3, 3])
的张量 B
。张量可以是 1d 或 4d,并且必须根据给定的新张量维度进行拆分。我已经能够生成目标维度,但我无法从源张量中拆分和提取值。我试过 torch.narrow
但它只需要 3 个参数,在很多情况下我需要 4 个。 torch.split
将 dim 作为 int,因为张量仅沿一维分裂。但我想沿多个维度拆分它。
您有多种选择:
- 多次使用
.split
- 多次使用
.narrow
- 使用切片
例如:
t = torch.rand(32, 32, 3, 3)
t0, t1 = t.split((16, 16), 0)
print(t0.shape, t1.shape)
>>> torch.Size([16, 32, 3, 3]) torch.Size([16, 32, 3, 3])
t00, t01 = t0.split((16, 16), 1)
print(t00.shape, t01.shape)
>>> torch.Size([16, 16, 3, 3]) torch.Size([16, 16, 3, 3])
t00_alt, t01_alt = t[:16, :16, :, :], t[16:, 16:, :, :]
print(t00_alt.shape, t01_alt.shape)
>>> torch.Size([16, 16, 3, 3]) torch.Size([16, 16, 3, 3])
我有一个 A
大小 torch.Size([32, 32, 3, 3])
的张量,我想拆分它并从中提取一个大小 torch.Size([16, 16, 3, 3])
的张量 B
。张量可以是 1d 或 4d,并且必须根据给定的新张量维度进行拆分。我已经能够生成目标维度,但我无法从源张量中拆分和提取值。我试过 torch.narrow
但它只需要 3 个参数,在很多情况下我需要 4 个。 torch.split
将 dim 作为 int,因为张量仅沿一维分裂。但我想沿多个维度拆分它。
您有多种选择:
- 多次使用
.split
- 多次使用
.narrow
- 使用切片
例如:
t = torch.rand(32, 32, 3, 3)
t0, t1 = t.split((16, 16), 0)
print(t0.shape, t1.shape)
>>> torch.Size([16, 32, 3, 3]) torch.Size([16, 32, 3, 3])
t00, t01 = t0.split((16, 16), 1)
print(t00.shape, t01.shape)
>>> torch.Size([16, 16, 3, 3]) torch.Size([16, 16, 3, 3])
t00_alt, t01_alt = t[:16, :16, :, :], t[16:, 16:, :, :]
print(t00_alt.shape, t01_alt.shape)
>>> torch.Size([16, 16, 3, 3]) torch.Size([16, 16, 3, 3])