CNN模型学不好

CNN model is not learn well

我开始在 PyTorch 中学习 CNN 实现,我尝试构建 CNN 来处理 4 类 从 0 到 3 的灰度图像。我的开始精度约为 0.55。我得到的最大准确度是 ~ 0.683%。

我用不同的 lr 和 batch_size 值尝试了 SGD 和 Adam 优化器,但准确率仍然很低。

我使用数据增强来创建更多样本,大约 4k。

我无法进一步提高准确性,想知道是否可以就我需要更改 CNN 结构以提高准确性获得一些建议。 损失开始于:损失:[1.497]附近然后减少:损失:[0.001]然后围绕这个值上下波动。

我花了时间阅读类似的问题,但没有运气。 我正在为我的 loss_fn 使用 nn.CrossEntropyLoss()。我不对密集层使用 softmax。

这是CNN模型的总结:

-------------------------------------------------------------
        Layer (type)               Output Shape         Param #
=============================================================
            Conv2d-1         [-1, 32, 128, 128]             320
              ReLU-2         [-1, 32, 128, 128]               0
       BatchNorm2d-3         [-1, 32, 128, 128]              64
         MaxPool2d-4           [-1, 32, 64, 64]               0
            Conv2d-5           [-1, 64, 64, 64]          18,496
              ReLU-6           [-1, 64, 64, 64]               0
       BatchNorm2d-7           [-1, 64, 64, 64]             128
         MaxPool2d-8           [-1, 64, 32, 32]               0
            Conv2d-9          [-1, 128, 32, 32]          73,856
             ReLU-10          [-1, 128, 32, 32]               0
      BatchNorm2d-11          [-1, 128, 32, 32]             256
        MaxPool2d-12          [-1, 128, 16, 16]               0
          Flatten-13                [-1, 32768]               0
           Linear-14                  [-1, 512]      16,777,728
             ReLU-15                  [-1, 512]               0
          Dropout-16                  [-1, 512]               0
           Linear-17                    [-1, 4]           2,052
============================================================

非常感谢您的帮助。

火车集中有多少张图片?测试集?图片的尺寸是多少?你如何看待图像分类的难度?你觉得应该简单还是困难?

根据你的数字,你过度拟合了,因为你的损失接近于 0(这意味着没有什么会反向传播到权重,即你的模型不会再改变)和你的 68.3%(这是一个错字吧?)来自测试集(我想)。所以你训练网络没有任何问题,这是一个好点。

然后你可以在网上搜索反过拟合的方法,这里有一些“经典”的可能性: - 你可以提高 dropout 参数
- 放置一些正则化器(L1 或 L2)来约束学习
- 使用验证集提前停止
- 使用经典 and/or 轻量级卷积网络 (resnet,inception) with/without 预训练权重。后者还取决于您的图像类型(自然、生物医学……)
- ...或多或少难以实施

此外,从技术上讲,您使用的是 softmax 层,因为它包含在 pytorch 的交叉熵损失中。