如何使用 F.interpolate 调整 PyTorch 中所有 4 个维度 (NCHW) 的大小?

How to resize all 4 dimensions (NCHW) in PyTorch with F.interpolate?

我一直在尝试弄清楚如何调整张量中批次、通道、高度和宽度维度的大小。目前我可以调整 Channels、Height 和 Width 尺寸,但 Batch 尺寸保持不变。

x = torch.ones(3,4,64,64)

x = F.interpolate(x.unsqueeze(0), size=(3,4,4), mode="trilinear").squeeze(0)

x.size() # (3,3,4,4) # batch dimension has not been resized.
# I need x to be resized so that it has a size of: (1,3,4,4)


# Is this a good idea?
x = x.permute(1,0,2,3)
x = F.interpolate(x.unsqueeze(0), size=(1, x.size(2), x.size(3)), mode="trilinear").squeeze(0)
x = x.permute(1,0,2,3)

x.size() # (1,3,4,4)

我应该置换张量以调整批处理维度的大小吗?或者以某种方式遍历它?

这似乎让我可以调整批次维度的大小:

x = x.permute(1,0,2,3)
x = F.interpolate(x.unsqueeze(0), size=(1, x.size(2), x.size(3)), mode="trilinear").squeeze(0)
x = x.permute(1,0,2,3)