nn.CDivTable 在调用向后时抛出错误是否有正当理由?

Is there a valid reason why nn.CDivTable throws error when backward is called?

我最近开始使用 Torch 框架和 Lua 脚本语言来研究神经网络。我已经掌握了线性网络的基础知识,所以我尝试了一些更复杂但足够简单的方法:

我的想法是我有 3 个输入,我必须选择前两个,将它们相除,然后将结果转发给线性模块。所以,我制作了这个小脚本:

require "nn";
require "optim";

local N = 3;

local input = torch.Tensor{
    {1, 2, 3},
    {9, 20, 20},
    {9, 300, 1},
};

local output = torch.Tensor(N);
for i=1, N do
    output[i] = 1;
end

local ratioPerceptron = nn.Sequential();
ratioPerceptron:add(nn.Narrow(1, 1, 2));
ratioPerceptron:add(nn.CDivTable());
ratioPerceptron:add(nn.Reshape(N, 1));
ratioPerceptron:add(nn.Linear(1, 1));
ratioPerceptron:add(nn.Sigmoid());

local criterion = nn.BCECriterion();
local params, gradParams = ratioPerceptron:getParameters();
local optimState = {learningRate = 0.01};

local maxIteration = 100000;
for i=1, maxIteration do
    local function f(params)
        gradParams:zero();

        local outputs = ratioPerceptron:forward(input);
        local loss = criterion:forward(outputs, output);
        local dloss_doutputs = criterion:backward(outputs, output);
        ratioPerceptron:backward(input, dloss_doutputs);

        return loss, gradParams;
    end

    optim.sgd(f, params, optimState);
end

训练期间调用 backward 时失败:

CDivTable.lua:21: both torch.LongStorage and (null) have no addition operator

但是,如果我从顺序模块中删除 CDivTable,并将 nn.Reshape 和 nn.Linear 更改为二维输入(因为我们删除了 CDivTable,它将二维输入分成一维输出)像这样:

local ratioPerceptron = nn.Sequential();
ratioPerceptron:add(nn.Narrow(1, 1, 2));
ratioPerceptron:add(nn.Reshape(N, 2));
ratioPerceptron:add(nn.Linear(2, 1));
ratioPerceptron:add(nn.Sigmoid());

训练无误结束...有没有其他方法可以划分两个选定的输入并将结果转发给线性模块?

模块CDivTable将table作为输入并将第一个table的元素除以第二个table的元素。在这里,您为网络提供单一输入,而不是两个输入的 table。我相信这就是 null 错误的原因。 Torch 无法理解您的输入(由两个向量组成)应被视为两个向量的 table。它只能看到大小为 2x3 的张量!因此,您必须告诉 Torch 从输入中创建一个 table。因此,您可以使用模块 SplitTable(dim) 将输入沿维度 dim.

拆分为 tables

在窄模块后插入这一行ratioPerceptron:add(nn.SplitTable(1))

local ratioPerceptron = nn.Sequential();
ratioPerceptron:add(nn.Narrow(1, 1, 2));
ratioPerceptron:add(nn.SplitTable(1))
ratioPerceptron:add(nn.CDivTable());
ratioPerceptron:add(nn.Reshape(N, 1));
ratioPerceptron:add(nn.Linear(1, 1));
ratioPerceptron:add(nn.Sigmoid());

此外,当您遇到此类错误时,我建议您通过放置 print 语句来查看网络计算的内容:在添加创建模块的行之前插入一行 print(ratioPerceptron:forward(input))一个错误。