在 pytorch 损失函数中修改数据的最快、最好(最快)方式?

Fastest, best (fastest) way to modify data in in a pytorch loss function?

我想尝试为 4 通道图像数据创建修改后的损失函数。

最好的拆分方式是什么torch.Size([64, 4, 128, 128])

torch.Size([64, 3, 128, 128]) torch.Size([64, 1, 128, 128])

您可以对第二个轴进行切片并提取两个张量:

>>> a, b = x[:, :3], x[:, 3:]
>>> a.shape, b.shape
(64, 3, 128, 128), (64, 1, 128, 128)

或者,您可以在第一个维度上应用 torch.split

>>> a, b = x.split(3, dim=1)
>>> a.shape, b.shape
(64, 3, 128, 128), (64, 1, 128, 128)

我能够使用拆分功能自行解决此问题。

给定一个基于图像的张量,例如:torch.Size([64, 4, 128, 128])

您可以在 dim 1 上拆分并给定静态长度。

self.E1 = torch.split(self.E, 3, 1)
print(self.E1[0].shape);
print(self.E1[1].shape);

给出:

torch.Size([64, 4, 128, 128])
torch.Size([64, 3, 128, 128])
torch.Size([64, 1, 128, 128])