如何将 4 维 PyTorch 张量乘以 1 维张量?

How to 4 dimension PyTorch tensor multiply by 1 dimension tensor?

我正在尝试为混合训练编写函数。在这个 site 上,我找到了一些代码并适应了我以前的代码。但在原始代码中,只有一个随机变量为批次 (64) 生成。但我想要批量为每张图片随机取值。 批处理一个变量的代码:

def mixup_data(x, y, alpha=1.0):
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size()[0]
    index = torch.randperm(batch_size)

    mixed_x = lam * x + (1 - lam) * x[index,:]
    mixed_y = lam * y + (1 - lam) * y[index,:]

    return mixed_x, mixed_y

x 和 y 输入来自 pytorch DataLoader。 x 输入大小:torch.Size([64, 3, 256, 256]) y 输入大小:torch.Size([64, 3474])

此代码运行良好。然后我改成这样:

def mixup_data(x, y):
    batch_size = x.size()[0]
    lam = torch.rand(batch_size)
    index = torch.randperm(batch_size)

    mixed_x = lam[index] * x + (1 - lam[index]) * x[index,:]
    mixed_y = lam[index] * y + (1 - lam[index]) * y[index,:]

    return mixed_x, mixed_y

但是报错:RuntimeError: The size of tensor a (64) must match the size of tensor b (256) at non-singleton dimension 3

我如何理解代码的工作原理是它批量获取第一张图像并乘以 lam 张量中的第一个值(64 个值长)。我该怎么做?

您需要替换以下行:

lam = torch.rand(batch_size)

来自

lam = torch.rand(batch_size, 1, 1, 1)

使用您当前的代码,无法进行 lam[index] * x 乘法运算,因为 lam[index] 的大小为 torch.Size([64]),而 x 的大小为 torch.Size([64, 3, 256, 256])。因此,您需要将 lam[index] 的大小设置为 torch.Size([64, 1, 1, 1]) 以便它可以广播。

应对以下语句:

mixed_y = lam[index] * y + (1 - lam[index]) * y[index, :]

我们可以在语句之前重塑 lam 张量。

lam = lam.reshape(batch_size, 1)
mixed_y = lam[index] * y + (1 - lam[index]) * y[index, :]

问题是相乘的两个张量的大小不匹配。我们以 lam[index] * x 为例。尺寸如下:

  • x: torch.Size([64, 3, 256, 256])
  • lam[index]: torch.Size([64])

为了将它们相乘,它们应该具有相同的大小,其中 lam[index] 每个批次的 [3, 256, 256] 使用相同的值,因为您想将该批次中的每个元素与相同的值,但每个批次不同。

lam[index].view(batch_size, 1, 1, 1).expand_as(x)
# => Size: torch.Size([64, 3, 256, 256])

.expand_as(x) 重复奇异维度,使其具有与 x 相同的大小,详情请参阅 .expand() documentation

您不需要展开张量,因为如果存在奇异维度,PyTorch 会自动为您展开。即所谓的广播:PyTorch - Broadcasting Semantics。因此,torch.Size([64, 1, 1, 1]) 的大小足以将其与 x 相乘。

lam[index].view(batch_size, 1, 1, 1) * x

这同样适用于 y,但大小为 torch.Size([64, 1]),因为 y 的大小为 torch.Size([64, 3474])

mixed_x = lam[index].view(batch_size, 1, 1, 1) * x + (1 - lam[index]).view(batch_size, 1, 1, 1) * x[index, :]
mixed_y = lam[index].view(batch_size, 1) * y + (1 - lam[index]).view(batch_size, 1) * y[index, :]

请注意,lam[index] 仅重新排列 lam 的元素,但由于您是随机创建的,因此无论您是否重新排列都没有任何区别。唯一重要的是 xy 重新排列,就像在原始代码中一样。