Pytorch 为模型分配固定参数

Pytorch Assigning fixed parameter to the model

我只对 2 个具有特定权重的卷积层之后的特征图感兴趣。

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()

        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(2, 2), stride=1, padding=1, bias=False),
            nn.AvgPool2d(kernel_size=(2, 2), stride=(2, 2))
        )
        with torch.no_grad():
            weights1 = torch.tensor([[0.2390, 0.1593], [0.5377, 0]])
            self.layer1.weight = nn.Parameter(weights1, requires_grad=False)

        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(2, 2), stride=1, padding=1, bias=False),
            nn.AvgPool2d(kernel_size=(2, 2), stride=(2, 2))
        )
        with torch.no_grad():
            weights2 = torch.tensor([[-0.2390, -0.3585], [-0.5377, 0.2390]])
            self.layer2.weight = nn.Parameter(weights2, requires_grad=False)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        return x

由于实际传感器实现的限制,我不得不像上面那样使用固定权重。 问题是输出不稳定。

>>>list(model.parameters())
[Parameter containing:
tensor([[0.2390, 0.1593],
        [0.5377, 0.0000]]), Parameter containing:
tensor([[[[-0.2701,  0.1602],
          [-0.0056, -0.0924]]]], requires_grad=True), Parameter containing:
tensor([[-0.2390, -0.3585],
        [-0.5377,  0.2390]]), Parameter containing:
tensor([[[[-0.0287,  0.2864],
          [ 0.3319, -0.3913]]]], requires_grad=True)]

以上是模型参数的结果,可以看到还有其他参数。

你知道如何修改参数吗?

您没有在不正确的对象上访问 weight 属性:self.layer1self.layer2 不是 nn.Conv2d 实例,它们是 nn.Sequential层。这样做实际上是将两个新张量(固定张量)注册到您的模块,添加了由 nn.Conv2d 层实例化的两个张量参数。

您应该将固定参数分别分配给 self.layer1[0] 和(self.layer2[0]):

self.layer1[0].weight = nn.Parameter(weights1, requires_grad=False)
# and
self.layer2[0].weight = nn.Parameter(weights2, requires_grad=False)

那么.parameters()会生成两个tensor参数:

>>> list
[Parameter containing:
 tensor([[0.2390, 0.1593],
         [0.5377, 0.0000]]), Parameter containing:
 tensor([[-0.2390, -0.3585],
         [-0.5377,  0.2390]])]