如何获取层中与 PyTorch 中的状态字典匹配的特征值?

How to get the value of a feature in a layer that match a the state dict in PyTorch?

我有一些cnn,我想从state dict中获取对应于某个键的某个中间层的值。 这怎么可能呢? 谢谢。

我认为您需要创建一个新的 class 来重新定义给定模型的前向传递。但是,您很可能需要创建有关模型架构的代码。你可以在这里找到一个例子:

class extract_layers():

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer

    def __call__(self, x):
        return self.forward(x)

    def forward(self, x):
        module = self.model._modules[self.target_layer]

        # get output of the desired layer
        features = module(x)

        # get output of the whole model
        x = self.model(x)

        return x, features


model = models.vgg19(pretrained=True)
target_layer = 'features'
extractor = extract_layers(model, target_layer)

image = Variable(torch.randn(1, 3, 244, 244))
x, features = extractor(image)

在这种情况下,我使用的是 pytorch models zoo 中提供的预定义 vgg19 网络。该网络的层结构分为两个模块,features 用于卷积部分,classifier 用于全连接部分。在这种情况下,由于 features 包裹了网络的所有卷积层,所以很简单。如果您的架构有多个名称不同的层,您将需要使用类似这样的东西来存储它们的输出:

 for name, module in self.model._modules.items():
    x = module(x)  # forward the module individually
    if name in self.target_layer:
        features = x  # store the output of the desired layer

此外,您应该记住,您需要重塑连接卷积部分和全连接部分的层的输出。如果你知道那个层的名称,应该很容易做到。