nn.CDivTable 在调用向后时抛出错误是否有正当理由?
Is there a valid reason why nn.CDivTable throws error when backward is called?
我最近开始使用 Torch 框架和 Lua 脚本语言来研究神经网络。我已经掌握了线性网络的基础知识,所以我尝试了一些更复杂但足够简单的方法:
我的想法是我有 3 个输入,我必须选择前两个,将它们相除,然后将结果转发给线性模块。所以,我制作了这个小脚本:
require "nn";
require "optim";
local N = 3;
local input = torch.Tensor{
{1, 2, 3},
{9, 20, 20},
{9, 300, 1},
};
local output = torch.Tensor(N);
for i=1, N do
output[i] = 1;
end
local ratioPerceptron = nn.Sequential();
ratioPerceptron:add(nn.Narrow(1, 1, 2));
ratioPerceptron:add(nn.CDivTable());
ratioPerceptron:add(nn.Reshape(N, 1));
ratioPerceptron:add(nn.Linear(1, 1));
ratioPerceptron:add(nn.Sigmoid());
local criterion = nn.BCECriterion();
local params, gradParams = ratioPerceptron:getParameters();
local optimState = {learningRate = 0.01};
local maxIteration = 100000;
for i=1, maxIteration do
local function f(params)
gradParams:zero();
local outputs = ratioPerceptron:forward(input);
local loss = criterion:forward(outputs, output);
local dloss_doutputs = criterion:backward(outputs, output);
ratioPerceptron:backward(input, dloss_doutputs);
return loss, gradParams;
end
optim.sgd(f, params, optimState);
end
训练期间调用 backward 时失败:
CDivTable.lua:21: both torch.LongStorage and (null) have no addition operator
但是,如果我从顺序模块中删除 CDivTable,并将 nn.Reshape 和 nn.Linear 更改为二维输入(因为我们删除了 CDivTable,它将二维输入分成一维输出)像这样:
local ratioPerceptron = nn.Sequential();
ratioPerceptron:add(nn.Narrow(1, 1, 2));
ratioPerceptron:add(nn.Reshape(N, 2));
ratioPerceptron:add(nn.Linear(2, 1));
ratioPerceptron:add(nn.Sigmoid());
训练无误结束...有没有其他方法可以划分两个选定的输入并将结果转发给线性模块?
模块CDivTable
将table作为输入并将第一个table的元素除以第二个table的元素。在这里,您为网络提供单一输入,而不是两个输入的 table。我相信这就是 null
错误的原因。 Torch 无法理解您的输入(由两个向量组成)应被视为两个向量的 table。它只能看到大小为 2x3
的张量!因此,您必须告诉 Torch 从输入中创建一个 table。因此,您可以使用模块 SplitTable(dim)
将输入沿维度 dim
.
拆分为 tables
在窄模块后插入这一行ratioPerceptron:add(nn.SplitTable(1))
:
local ratioPerceptron = nn.Sequential();
ratioPerceptron:add(nn.Narrow(1, 1, 2));
ratioPerceptron:add(nn.SplitTable(1))
ratioPerceptron:add(nn.CDivTable());
ratioPerceptron:add(nn.Reshape(N, 1));
ratioPerceptron:add(nn.Linear(1, 1));
ratioPerceptron:add(nn.Sigmoid());
此外,当您遇到此类错误时,我建议您通过放置 print
语句来查看网络计算的内容:在添加创建模块的行之前插入一行 print(ratioPerceptron:forward(input))
一个错误。
我最近开始使用 Torch 框架和 Lua 脚本语言来研究神经网络。我已经掌握了线性网络的基础知识,所以我尝试了一些更复杂但足够简单的方法:
我的想法是我有 3 个输入,我必须选择前两个,将它们相除,然后将结果转发给线性模块。所以,我制作了这个小脚本:
require "nn";
require "optim";
local N = 3;
local input = torch.Tensor{
{1, 2, 3},
{9, 20, 20},
{9, 300, 1},
};
local output = torch.Tensor(N);
for i=1, N do
output[i] = 1;
end
local ratioPerceptron = nn.Sequential();
ratioPerceptron:add(nn.Narrow(1, 1, 2));
ratioPerceptron:add(nn.CDivTable());
ratioPerceptron:add(nn.Reshape(N, 1));
ratioPerceptron:add(nn.Linear(1, 1));
ratioPerceptron:add(nn.Sigmoid());
local criterion = nn.BCECriterion();
local params, gradParams = ratioPerceptron:getParameters();
local optimState = {learningRate = 0.01};
local maxIteration = 100000;
for i=1, maxIteration do
local function f(params)
gradParams:zero();
local outputs = ratioPerceptron:forward(input);
local loss = criterion:forward(outputs, output);
local dloss_doutputs = criterion:backward(outputs, output);
ratioPerceptron:backward(input, dloss_doutputs);
return loss, gradParams;
end
optim.sgd(f, params, optimState);
end
训练期间调用 backward 时失败:
CDivTable.lua:21: both torch.LongStorage and (null) have no addition operator
但是,如果我从顺序模块中删除 CDivTable,并将 nn.Reshape 和 nn.Linear 更改为二维输入(因为我们删除了 CDivTable,它将二维输入分成一维输出)像这样:
local ratioPerceptron = nn.Sequential();
ratioPerceptron:add(nn.Narrow(1, 1, 2));
ratioPerceptron:add(nn.Reshape(N, 2));
ratioPerceptron:add(nn.Linear(2, 1));
ratioPerceptron:add(nn.Sigmoid());
训练无误结束...有没有其他方法可以划分两个选定的输入并将结果转发给线性模块?
模块CDivTable
将table作为输入并将第一个table的元素除以第二个table的元素。在这里,您为网络提供单一输入,而不是两个输入的 table。我相信这就是 null
错误的原因。 Torch 无法理解您的输入(由两个向量组成)应被视为两个向量的 table。它只能看到大小为 2x3
的张量!因此,您必须告诉 Torch 从输入中创建一个 table。因此,您可以使用模块 SplitTable(dim)
将输入沿维度 dim
.
在窄模块后插入这一行ratioPerceptron:add(nn.SplitTable(1))
:
local ratioPerceptron = nn.Sequential();
ratioPerceptron:add(nn.Narrow(1, 1, 2));
ratioPerceptron:add(nn.SplitTable(1))
ratioPerceptron:add(nn.CDivTable());
ratioPerceptron:add(nn.Reshape(N, 1));
ratioPerceptron:add(nn.Linear(1, 1));
ratioPerceptron:add(nn.Sigmoid());
此外,当您遇到此类错误时,我建议您通过放置 print
语句来查看网络计算的内容:在添加创建模块的行之前插入一行 print(ratioPerceptron:forward(input))
一个错误。