Synthesizing 1x1 convolution layer with fully connected layers

I am trying to synthesize a 1x1 convolution layer from fully connected layers, meaning that a fully connected neural network decides the parameters of a 1x1 convolution layer. This is my approach:

import torch
import torch.nn as nn

class Network(nn.Module):
  def __init__(self, len_input, num_kernels):
    super().__init__()
    self.input_layers = nn.Sequential(
        nn.Linear(len_input, num_kernels * 2),
        nn.ReLU(),
        nn.Linear(num_kernels * 2, num_kernels),
        nn.ReLU()
    )

    self.synthesized_conv = nn.Conv2d(
        in_channels=3, out_channels=num_kernels, bias=False, kernel_size=1)

    self.conv_layers = nn.Sequential(
        nn.ReLU(),
        nn.Conv2d(in_channels=num_kernels, out_channels=3, kernel_size=1)
    )

  def forward(self, x1, img):
    x = self.input_layers(x1.float())
    with torch.no_grad():
        self.synthesized_conv.weight = nn.Parameter(x.reshape_as(self.synthesized_conv.weight))
    generated = self.conv_layers(self.synthesized_conv(img))
    return generated
  

As you can see, I initialize a 1x1 conv layer called "synthesized_conv" and try to replace its parameter reference with the output of a fully connected network called "self.input_layers". However, the gradients do not seem to flow through the fully connected network, only through the convolution layers. The parameter histograms of the fully connected layers look like this:

These histograms are a strong indication that the fully connected part is not learning at all, most likely because of the way the fully connected network's output is used to update the convolution parameters. Can anyone help me do this without breaking the autograd graph?

The problem is that you keep redefining the model's weight attribute over and over: wrapping the output of input_layers in a new nn.Parameter (and doing so under torch.no_grad) creates a fresh leaf tensor, which severs the connection to the computation graph of input_layers, so no gradients can reach it. A more direct solution is to use the functional approach, torch.nn.functional.conv2d:

import torch.nn.functional as F

class Network(nn.Module):
  def __init__(self, len_input, num_kernels):
    super().__init__()
    self.input_layers = nn.Sequential(
        nn.Linear(len_input, num_kernels * 2),
        nn.ReLU(),
        nn.Linear(num_kernels * 2, num_kernels * 3),
        nn.ReLU())

    self.synthesized_conv = nn.Conv2d(
        in_channels=3, out_channels=num_kernels, kernel_size=1)

    self.conv_layers = nn.Sequential(
        nn.ReLU(),
        nn.Conv2d(in_channels=num_kernels, out_channels=3, kernel_size=1))

  def forward(self, x1, img):
    x = self.input_layers(x1.float())
    w = x.reshape_as(self.synthesized_conv.weight)
    generated = F.conv2d(img, w)
    return generated

Also, I believe your input_layers must output num_kernels * 3 components in total, since your synthesized convolution has three input channels.
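
For reference, the weight of that 1x1 convolution has shape (num_kernels, 3, 1, 1), i.e. num_kernels * 3 values, which is what x gets reshaped into. A quick check (num_kernels=4 is an arbitrary example value):

>>> conv = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=1, bias=False)
>>> conv.weight.shape
torch.Size([4, 3, 1, 1])  # 4 * 3 * 1 * 1 = 12 values to synthesize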

Here is a test example:

>>> model = Network(10,3)
>>> out = model(torch.rand(1,10), torch.rand(1,3,16,16))
>>> out.shape, out.grad_fn
(torch.Size([1, 3, 16, 16]), <ThnnConv2DBackward at 0x7fe5d8e41450>)

Of course, the parameters of synthesized_conv will never change, since they are never used to compute the output. You can remove self.synthesized_conv altogether:

class Network(nn.Module):
  def __init__(self, len_input, num_kernels):
    super().__init__()
    self.input_layers = nn.Sequential(
        nn.Linear(len_input, num_kernels * 2),
        nn.ReLU(),
nn.Linear(num_kernels * 2, num_kernels * 3),
        nn.ReLU())

    self.syn_conv_shape = (num_kernels, 3, 1, 1)

    self.conv_layers = nn.Sequential(
        nn.ReLU(),
        nn.Conv2d(in_channels=num_kernels, out_channels=3, kernel_size=1))

  def forward(self, x1, img):
    x = self.input_layers(x1.float())
    generated = F.conv2d(img, x.reshape(self.syn_conv_shape))
    return generated
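
To check that this version fixes the original issue, you can run a backward pass and verify that the fully connected layers now receive gradients (a minimal sketch; the random inputs and the mean() loss are only for illustration):

>>> model = Network(10, 3)
>>> out = model(torch.rand(1, 10), torch.rand(1, 3, 16, 16))
>>> out.mean().backward()
>>> model.input_layers[0].weight.grad is not None
True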