填充在 PyTorch 中的工作原理

How padding works in PyTorch

通常情况下,如果我很好地理解 Conv2D 层的 PyTorch 实现,填充参数会将卷积图像的形状扩展到输入的所有四个边都为零。因此,如果我们有一个形状为 (6,6) 的图像并设置 padding = 2strides = 2kernel = (5,5),输出将是一个形状为 (1,1) 的图像。然后,padding = 2 将填充零(向上 2 个,向下 2 个,向左 2 个,向右 2 个),生成形状为 (5,5)

的卷积图像

但是当运行以下脚本时:

import torch
from torch import nn
x = torch.ones(1,1,6,6)
y = nn.Conv2d(in_channels= 1, out_channels=1, 
              kernel_size= 5, stride = 2, 
              padding = 2,)(x)

我得到了以下输出:

y.shape
==> torch.Size([1, 1, 3, 3]) ("So shape of convolved image = (3,3) instead of (5,5)")

y[0][0]
==> tensor([[0.1892, 0.1718, 0.2627, 0.2627, 0.4423, 0.2906],
    [0.4578, 0.6136, 0.7614, 0.7614, 0.9293, 0.6835],
    [0.2679, 0.5373, 0.6183, 0.6183, 0.7267, 0.5638],
    [0.2679, 0.5373, 0.6183, 0.6183, 0.7267, 0.5638],
    [0.2589, 0.5793, 0.5466, 0.5466, 0.4823, 0.4467],
    [0.0760, 0.2057, 0.1017, 0.1017, 0.0660, 0.0411]],
   grad_fn=<SelectBackward>)

正常情况下应该填零。我很困惑。有人可以帮忙吗?

填充的是输入,而不是输出。在您的情况下,conv2d 层将在计算卷积运算之前在所有面上应用两个像素的填充。

为了便于说明,

>>> weight = torch.rand(1, 1, 5, 5)
  • 这里我们使用 padding=2:

    进行卷积
    >>> x = torch.ones(1,1,6,6)
    >>> F.conv2d(x, weight, stride=2, padding=2)
    tensor([[[[ 5.9152,  8.8923,  6.0984],
              [ 8.9397, 14.7627, 10.8613],
              [ 7.2708, 12.0152,  9.0840]]]])
    
  • 我们不使用任何填充,而是自己将其应用于输入:

    >>> x_padded = F.pad(x, (2,)*4)
    tensor([[[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 1., 1., 1., 1., 1., 1., 0., 0.],
              [0., 0., 1., 1., 1., 1., 1., 1., 0., 0.],
              [0., 0., 1., 1., 1., 1., 1., 1., 0., 0.],
              [0., 0., 1., 1., 1., 1., 1., 1., 0., 0.],
              [0., 0., 1., 1., 1., 1., 1., 1., 0., 0.],
              [0., 0., 1., 1., 1., 1., 1., 1., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]])
    
    >>> F.conv2d(x_padded, weight, stride=2)
    tensor([[[[ 5.9152,  8.8923,  6.0984],
              [ 8.9397, 14.7627, 10.8613],
              [ 7.2708, 12.0152,  9.0840]]]])