RuntimeError: Found dtype Double but expected Float - PyTorch

RuntimeError: Found dtype Double but expected Float - PyTorch

我是 pytorch 的新手,我正在使用强化学习为时间序列研究 DQN,我需要对时间序列和一些传感器读数进行复杂的观察,所以我合并了两个神经网络,我不确定这是否是是什么毁了我的 loss.backward 或其他东西。 我知道有多个标题相同的问题,但 none 对我有用,也许我遗漏了什么。
首先,这是我的网络:

class DQN(nn.Module):
  def __init__(self, list_shape, score_shape, n_actions):
    super(DQN, self).__init__()

    self.FeatureList =  nn.Sequential(
            nn.Conv1d(list_shape[1], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv1d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten()
        )
    
    self.FeatureScore = nn.Sequential(
            nn.Linear(score_shape[1], 512),
            nn.ReLU(),
            nn.Linear(512, 128)
        )
    
    t_list_test = torch.zeros(list_shape)
    t_score_test = torch.zeros(score_shape)
    merge_shape = self.FeatureList(t_list_test).shape[1] + self.FeatureScore(t_score_test).shape[1]
    
    self.FinalNN =  nn.Sequential(
            nn.Linear(merge_shape, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions),
    )
    
  def forward(self, list, score):
    listOut = self.FeatureList(list)
    scoreOut = self.FeatureScore(score)
    MergedTensor = torch.cat((listOut,scoreOut),1)
    return self.FinalNN(MergedTensor)

我有一个名为 calc_loss 的函数,在它的最后 return MSE 损失如下

  print(state_action_values.dtype)
  print(expected_state_action_values.dtype) 
  return nn.MSELoss()(state_action_values, expected_state_action_values)

并且打印分别显示 float32 和 float64。
当我 运行 loss.backward() 时出现错误,如下所示

LEARNING_RATE = 0.01
optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)

for i in range(50):
  optimizer.zero_grad()
  loss_v = calc_loss(sample(obs, 500, 200, 64), net, tgt_net)
  print(loss_v.dtype)
  print(loss_v)
  loss_v.backward()
  optimizer.step()

打印输出如下:
火炬.float64
张量(1887.4831,dtype=torch.float64,grad_fn=)

更新 1:
我尝试使用更简单的模型,但同样的问题,当我尝试将输入转换为 Float 时,出现错误:

RuntimeError: expected scalar type Double but found Float

是什么让模型期望加倍?

更新 2:
我试图在火炬导入后在顶部添加以下行,但 RuntimeError: Found dtype Double but expected Float

同样的问题
>>> torch.set_default_tensor_type(torch.FloatTensor)

但是当我使用 DoubleTensor 时,我得到: RuntimeError: Input type (torch.FloatTensor) and weight type (torch.DoubleTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor

问题不在于网络的输入,而在于 MSELoss 的标准,因此在将标准转换为浮动后它工作正常,如下所示

return nn.MSELoss()(state_action_values.float(), expected_state_action_values.float())

我决定把答案留给像我这样可能被卡住,没想到会检查损失准则数据类型的初学者