如何在训练 CNN 时辨别哪个图像生成了特定的特征图?

How to discern which image generated a particular feature map, while training CNNs?

假设我将 3 张灰度图像输入 CNN,其组合形状为 3、28、28。此过程将为每个图像生成多个特征图。如何识别哪个特征图对应于特定图像。

这是一些代码 -

import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(256, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        print("Shape of x = ", x.shape)
        x = self.pool(F.relu(self.conv2(x)))
        print("Shape of x = ", x.shape)
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()

foo = torch.randn(3,1, 28, 28)

foo_cnn = net(foo)

例如,第一个卷积从 3 个图像生成 6 个特征图。有没有办法让我识别哪个特征图属于哪个图像,以便我可以对其进行一些操作。

为了区分哪个图像生成了哪些卷积特征图,必须将不同的输入图像拆分为批次维度 (#images=#batches),这样在应用任何卷积层时,它们会分别应用于每个图像,而不是不同输入图像的加权和,如果它们被分割成 channel/depth 维度,就会出现这种情况。

现在您不是将 3 张图像输入模型(在 pytorch 的眼中);这将要求输入的形状为:(3, 1, 28, 28) 用于灰度图像,(3, 3, 28, 28) 用于 RGB 图像。你正在做的是(在某种意义上)将 3 个图像连接到深度维度中,从而产生形状:(1, 3, 28, 28),因此 6 个输出特征图不能归因于特定图像(加权组合3,因为它们在深度维度上)。

因此,将输入重塑为 (3, 1, 28, 28) 并将 conv1 更改为 (1, 6, 5) 将导致以下输出:(3, 6, 12, 12) 因此,1st 1st batch (of the output) 中的 6 个特征图对应于 batch (of input) 中的第一张图像,而 2nd 6 个特征图对应于 2nd 中的图像批次等等。