Pytorch FloatClassNLLCriterion_updateOutput error

I get the following error message when computing the loss of my neural network: TypeError: FloatClassNLLCriterion_updateOutput received an invalid combination of arguments - got (int, torch.FloatTensor, !torch.FloatTensor!, torch.FloatTensor, bool, NoneType, torch.FloatTensor, int), but expected (int state, torch.FloatTensor input, torch.LongTensor target, torch.FloatTensor output, bool sizeAverage, [torch.FloatTensor weights or None], torch.FloatTensor total_weight, int ignore_index). It is raised on this line: loss = criterion(outputs, one_hot_target).

I have tried several things and searched online, but I can't seem to find my mistake. Does anyone have an idea?

The code I am using:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autograd

class Net(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(Net, self).__init__()
        self.l1 = nn.Linear(input_size, hidden_size)
        self.l2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        x = self.l1(x)
        x = F.tanh(x)
        x = self.l2(x)
        x = F.softmax(x)
        return x

mlp = Net(input_size,hidden_size,num_classes)

criterion = nn.NLLLoss()
optimizer = torch.optim.SGD(mlp.parameters(), lr=learning_rate)

for i in range(N-1,len(word_list)):

  # Define input vector x
  x = None


  for j in range(0,N-1):
      try:
          x = np.r_[x,w2v[word_list[j]]]
      except:
          x = w2v[word_list[j-N+1]]

  # Done with defining  x


  np.reshape(x,(len(x),1))
  x = autograd.Variable(torch.FloatTensor(x))
  optimizer.zero_grad()

  outputs = mlp(x)
  outputs = outputs.unsqueeze(0)
  outputs = outputs.transpose(0,1)

  index = w2i[word_list[i]]
  one_hot_target = np.zeros([num_classes,1],dtype=float)

  one_hot_target[index] = float(1)

  one_hot_target = autograd.Variable(torch.Tensor(one_hot_target))
  print (one_hot_target)
  loss = criterion(outputs,one_hot_target)
  #loss.backward()
  #optimizer.step()

The target tensor you are passing to the loss function is of type FloatTensor. That happens here: one_hot_target = autograd.Variable(torch.Tensor(one_hot_target))

This is because the default tensor type in PyTorch is FloatTensor. However, NLLLoss() expects a LongTensor as its target. You can double-check the example in the documentation, which you can find here: NLLLoss Pytorch docs.
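You can quickly confirm the default type yourself; this is just an illustrative snippet, independent of your code:

import torch

t = torch.Tensor([0.0, 1.0, 0.0])
print(t.type())         # torch.FloatTensor -- the default tensor type
print(t.long().type())  # torch.LongTensor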

You can simply convert the target tensor to a LongTensor like this:

one_hot_target = autograd.Variable(torch.Tensor(one_hot_target)).long()
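
For reference, the call pattern from the NLLLoss docs boils down to something like the sketch below. It uses the same Variable-based API as your code; log_probs, target and num_classes are placeholder names I made up, and the target is a LongTensor holding one class index per sample:

import torch
import torch.nn as nn
from torch import autograd

criterion = nn.NLLLoss()

# Stand-in for the network's log-probabilities, shape (batch, num_classes)
num_classes = 5
log_probs = autograd.Variable(torch.randn(1, num_classes))

# The documented target format: a LongTensor with one class index per sample
target = autograd.Variable(torch.LongTensor([3]))

loss = criterion(log_probs, target)
print(loss)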