如何在火炬中向图形模块添加附加层

How to add additional layers to a graph module in torch

如何将新节点添加到 torch 的 nngraph 包中的图形模块 (gModule)?我尝试使用 add 函数,这将节点添加到 gModules 对象中的模块槽中。但是输出仍然取自前一个最后一个节点。

简化代码:

require "nn"
require "nngraph"

-- Function that builds a gModule
function buildModule(input_size,hidden_size)
    local x = nn.Identity()()
    local out = x - nn.Linear(input_size,hidden_size) - nn.Tanh()
    return nn.gModule({x},{out})
end

network = buildModule(5,3)
-- Additional layer to add
l2 = nn.Linear(3,10)
network:add(l2)

-- Expected a tensor of size 10 but got one with size 3
print(network:forward(torch.randn(5)))

gModule 实际上不应该被改变。它支持 :add 实际上是作为 nn.Container 的子 class 的副作用,而不是设计决定。一般来说,一旦您创建了 gModule,您就不应修改其内部结构,因为您必须修改一些内部属性才能使其正常工作。相反 - 如果你想添加一些东西 "on top" 只需定义新的容器,将前一个容器作为输入。

-- Function that builds a gModule
function buildModule(input_size,hidden_size)
    local x = nn.Identity()()
    local out = x - nn.Linear(input_size,hidden_size) - nn.Tanh()
    return nn.gModule({x},{out})
end

network = buildModule(5,3)

new_network = nn.Sequential()
new_network:add(network)
new_network:add(nn.Linear(3,10))