How to use PNASNet5 as encoder in Unet in pytorch

I want to use PNASNet5Large as the encoder of my Unet. Below is my (wrong) attempt for PNASNet5Large, which does work for resnet:

class UNetResNet(nn.Module):
    def __init__(self, encoder_depth, num_classes, num_filters=32, dropout_2d=0.2,
                 pretrained=False, is_deconv=False):
        super().__init__()
        self.num_classes = num_classes
        self.dropout_2d = dropout_2d

        if encoder_depth == 34:
            self.encoder = torchvision.models.resnet34(pretrained=pretrained)
            bottom_channel_nr = 512
        elif encoder_depth == 101:
            self.encoder = torchvision.models.resnet101(pretrained=pretrained)
            bottom_channel_nr = 2048
        elif encoder_depth == 152: #this works
            self.encoder = torchvision.models.resnet152(pretrained=pretrained)
            bottom_channel_nr = 2048
        elif encoder_depth == 777: #coded version for the pnasnet
            self.encoder = PNASNet5Large()
            bottom_channel_nr = 4320 #this is unknown to me as well


        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Sequential(self.encoder.conv1,
                                   self.encoder.bn1,
                                   self.encoder.relu,
                                   self.pool)

        self.conv2 = self.encoder.layer1 #PNASNet5Large doesn't have such layers
        self.conv3 = self.encoder.layer2
        self.conv4 = self.encoder.layer3
        self.conv5 = self.encoder.layer4
        self.center = DecoderCenter(bottom_channel_nr, num_filters * 8 *2, num_filters * 8, False)

        self.dec5 =  DecoderBlock(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2, num_filters * 8,   is_deconv)
        self.dec4 = DecoderBlock(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec3 = DecoderBlock(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv)
        self.dec2 = DecoderBlock(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2,
                                   is_deconv)
        self.dec1 = DecoderBlock(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
        self.dec0 = ConvRelu(num_filters, num_filters)
        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)
        conv5 = self.conv5(conv4)
        center = self.center(conv5)
        dec5 = self.dec5(torch.cat([center, conv5], 1))
        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = self.dec1(dec2)
        dec0 = self.dec0(dec1)
        return self.final(F.dropout2d(dec0, p=self.dropout_2d))

1) How do I find out how many bottom channels the pnasnet has? It ends like this:

...
        self.cell_11 = Cell(in_channels_left=4320, out_channels_left=864,
                            in_channels_right=4320, out_channels_right=864)
        self.relu = nn.ReLU()
        self.avg_pool = nn.AvgPool2d(11, stride=1, padding=0)
        self.dropout = nn.Dropout(0.5)
        self.last_linear = nn.Linear(4320, num_classes)

Is 4320 the answer? in_channels_left and out_channels_left are new things to me.

2) Resnet has about 4 big layers that I use as the encoder in my Unet arch. How do I get similar layers from the pnasnet?

I am using pytorch 3.1, here is a link to the Pnasnet directory.

3) AttributeError: 'PNASNet5Large' object has no attribute 'conv1' - so there is no conv1 either.

UPD: tried something like this, but it failed:

class UNetPNASNet(nn.Module):
    def __init__(self, encoder_depth, num_classes, num_filters=32, dropout_2d=0.2,
                 pretrained=False, is_deconv=False):
        super().__init__()
        self.num_classes = num_classes
        self.dropout_2d = dropout_2d
        self.encoder = PNASNet5Large()
        bottom_channel_nr = 4320
        self.center = DecoderCenter(bottom_channel_nr, num_filters * 8 * 2, num_filters * 8, False)

        self.dec5  =  DecoderBlockV2(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2, num_filters * 8,   is_deconv)
        self.dec4  = DecoderBlockV2(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec3  = DecoderBlockV2(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv)
        self.dec2  = DecoderBlockV2(num_filters * 4 * 4, num_filters * 4 * 4, num_filters, is_deconv)
        self.dec1  = DecoderBlockV2(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
        self.dec0  = ConvRelu(num_filters, num_filters)
        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        features = self.encoder.features(x)
        relued_features = self.encoder.relu(features)
        avg_pooled_features = self.encoder.avg_pool(relued_features)
        center = self.center(avg_pooled_features)
        dec5 = self.dec5(torch.cat([center, avg_pooled_features], 1))
        dec4 = self.dec4(torch.cat([dec5, relued_features], 1))
        dec3 = self.dec3(torch.cat([dec4, features], 1))
        dec2 = self.dec2(dec3)
        dec1 = self.dec1(dec2)
        dec0 = self.dec0(dec1)
        return self.final(F.dropout2d(dec0, p=self.dropout_2d))

RuntimeError: Given input size: (4320x4x4). Calculated output size: (4320x-6x-6). Output size is too small at /opt/conda/conda-bld/pytorch_1525796793591/work/torch/lib/THCUNN/generic/SpatialAveragePooling.cu:63

So you want to use PNASNetLarge instead of ResNets as the encoder in your UNet architecture. Let's see how the ResNets are being used. In your __init__:

self.pool = nn.MaxPool2d(2, 2)
self.relu = nn.ReLU(inplace=True)
self.conv1 = nn.Sequential(self.encoder.conv1,
                           self.encoder.bn1,
                           self.encoder.relu,
                           self.pool)

self.conv2 = self.encoder.layer1
self.conv3 = self.encoder.layer2
self.conv4 = self.encoder.layer3
self.conv5 = self.encoder.layer4

So you are using the ResNet up to layer4, which is the last block before the average pooling, while the sizes you use for resnet are the ones after the average pooling, hence I assume a self.encoder.avgpool is missing after self.conv5 = self.encoder.layer4. The forward of a ResNet in torchvision.models looks like this:

def forward(self, x):
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)

    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)

    x = self.avgpool(x)
    x = x.view(x.size(0), -1)
    x = self.fc(x)

    return x
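
As a quick sanity check of those stages (a sketch only, using torchvision's resnet152 from your encoder_depth == 152 branch), you can print the sizes of the skip maps your UNet taps; their channel counts are 64, 256, 512, 1024 and 2048:

import torch
import torch.nn as nn
import torchvision

enc = torchvision.models.resnet152(pretrained=False)
x = torch.randn(1, 3, 224, 224)
x = nn.Sequential(enc.conv1, enc.bn1, enc.relu, enc.maxpool)(x)  # conv1 stage -> 64 channels
for layer in (enc.layer1, enc.layer2, enc.layer3, enc.layer4):   # conv2..conv5 stages
    x = layer(x)
    print(x.size())  # 256 -> 512 -> 1024 -> 2048 channels, spatial size halving from layer2 on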

I guess you want to adopt a similar solution for your PNASNet5Large (i.e. use the architecture up to the average pooling layer).

1) To find out how many bottom channels your PNASNet5Large has, look at the size of the output tensor after the average pooling, e.g. by feeding it a dummy tensor. Note also that, while ResNets typically work with an input size of (batch_size, 3, 224, 224), PNASNetLarge uses (batch_size, 3, 331, 331).

m = PNASNet5Large()
x1 = torch.randn(1, 3, 331, 331)
m.avg_pool(m.features(x1)).size()
torch.Size([1, 4320, 1, 1])

So yes, bottom_channel_nr = 4320 for your PNASNet.
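
This also explains the RuntimeError from your UPD, as far as I can tell from the layers you quoted: avg_pool is nn.AvgPool2d(11, stride=1, padding=0), i.e. it is sized for the 11x11 feature map that a 331x331 input produces. Your error reports a (4320x4x4) input to that pooling (so your images are much smaller than 331x331), and 4 - 11 + 1 = -6 is exactly the negative output size it complains about. You can verify the expected map size with the same dummy tensor:

m.features(x1).size()
# torch.Size([1, 4320, 11, 11])  -> the 11x11 map that avg_pool(11) reduces to the 1x1 above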

2) Since the architecture is completely different, you will have to modify your UNet's __init__ and forward. If you decide to go with the PNASNet, I'd recommend making a new class:

class UNetPNASNet(nn.Module):
    def __init__(self, encoder_depth, num_classes, num_filters=32, dropout_2d=0.2,
                 pretrained=False, is_deconv=False):
        super().__init__()
        self.num_classes = num_classes
        self.dropout_2d = dropout_2d
        self.encoder = PNASNet5Large()
        bottom_channel_nr = 4320
        self.center = DecoderCenter(bottom_channel_nr, num_filters * 8 * 2, num_filters * 8, False)

        self.dec5 = DecoderBlock(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec4 = DecoderBlock(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec3 = DecoderBlock(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv)
        self.dec2 = DecoderBlock(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2,
                                 is_deconv)
        self.dec1 = DecoderBlock(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
        self.dec0 = ConvRelu(num_filters, num_filters)
        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)
    def forward(self, x):
        features = self.encoder.features(x)
        relued_features = self.encoder.relu(features)
        avg_pooled_features = self.encoder.avg_pool(relued_features)
        center = self.center(avg_pooled_features)
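        # NOTE: conv2..conv5 below are placeholders for the encoder skip connections.
        # PNASNet5Large has no layer1..layer4, so these intermediate feature maps
        # (and the matching DecoderBlock channel sizes) still need to come from the
        # individual cells of the encoder -- see the sketch after this class.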
        dec5 = self.dec5(torch.cat([center, conv5], 1))
        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = self.dec1(dec2)
        dec0 = self.dec0(dec1)
        return self.final(F.dropout2d(dec0, p=self.dropout_2d))
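
For the skip connections mentioned in the comment above, one option is to replicate encoder.features() step by step and keep one map per resolution level. This is only a sketch, assuming the pretrainedmodels-style PNASNet5Large where features() chains the cells and each cell takes the two previous feature maps (with cell_4 and cell_8 being the reduction cells in that implementation); double-check the signatures against the Pnasnet file you linked. pnasnet_stages is a hypothetical helper name:

def pnasnet_stages(encoder, x):
    # Mirrors PNASNet5Large.features(), but returns intermediate maps as well.
    x_conv_0 = encoder.conv_0(x)
    x_stem_0 = encoder.cell_stem_0(x_conv_0)
    x_stem_1 = encoder.cell_stem_1(x_conv_0, x_stem_0)
    x_cell_0 = encoder.cell_0(x_stem_0, x_stem_1)
    x_cell_1 = encoder.cell_1(x_stem_1, x_cell_0)
    x_cell_2 = encoder.cell_2(x_cell_0, x_cell_1)
    x_cell_3 = encoder.cell_3(x_cell_1, x_cell_2)
    x_cell_4 = encoder.cell_4(x_cell_2, x_cell_3)    # reduction cell: spatial size halves
    x_cell_5 = encoder.cell_5(x_cell_3, x_cell_4)
    x_cell_6 = encoder.cell_6(x_cell_4, x_cell_5)
    x_cell_7 = encoder.cell_7(x_cell_5, x_cell_6)
    x_cell_8 = encoder.cell_8(x_cell_6, x_cell_7)    # reduction cell: spatial size halves
    x_cell_9 = encoder.cell_9(x_cell_7, x_cell_8)
    x_cell_10 = encoder.cell_10(x_cell_8, x_cell_9)
    x_cell_11 = encoder.cell_11(x_cell_9, x_cell_10)
    # One map per resolution level, coarsest last.
    return x_stem_1, x_cell_3, x_cell_7, x_cell_11

The channel counts of these intermediate maps will not match the bottom_channel_nr // 2, // 4, // 8 values inherited from the ResNet decoder, so print their .size() on a dummy input and adjust the DecoderBlock input channels accordingly.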

3) PNASNet5Large indeed has no conv1 attribute. You can check this with:
'conv1' in dict(m.named_modules())
False
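
and, to see which top-level blocks it does expose (the exact names depend on the Pnasnet implementation you linked, so treat this listing as indicative), you can print its named children:

[name for name, _ in m.named_children()]
# expected something like: ['conv_0', 'cell_stem_0', 'cell_stem_1', 'cell_0', ..., 'cell_11',
#                           'relu', 'avg_pool', 'dropout', 'last_linear']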