Pytorch:CNN 在 torch.cat() 之后没有学到任何东西?

Pytorch: CNN don't learn anything after torch.cat()?

我尝试用这样的代码连接网络中的变量

    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)

    x = x.view(x.size(0), -1)
    x= torch.cat((x,angle),1) # from here I concat it.
    x = self.dropout1(self.relu1(self.bn1(self.fc1(x))))
    x = self.dropout2(self.relu2(self.bn2(self.fc2(x))))
    x = self.fc3(x)

然后我发现我的网络什么也没学到,总是给 acc 50% 左右。所以我打印 param.grad 并且如我所料,它们都是 nan。有没有人遇到过这个东西?

我运行之前没有连接的代码,效果很好。所以我想这就是问题所在,系统不会抛出任何错误或异常。如果需要任何其他备份信息,请告诉我。

谢谢。

错误可能出在您提供的代码之外的某个地方。尝试检查您的输入中是否有 nan,并检查损失函数是否不会导致 nan。