使用两个参数在pytorch中调用自定义模块

Calling a custom module in pythorch with two parameters

我尝试创建三个自定义模块,如下所示:

import torch

class VerySimple(torch.nn.Module):
  def __init__(self):
    super(VerySimple, self).__init__()

  def forward(self, x):
    return x * 3.0

class VerySimple2(torch.nn.Module):
  def __init__(self):
    super(VerySimple, self).__init__()

  def forward(self, x, y):
    return x * y * 3.0

之后我创建了两个非常简单的网络:

vs = VerySimple()
vs2 = VerySimple2()
print(vs(2.0))
print(vs2(2.0, 3.0))

当我调用它输出 6.0 和 18.0 时,示例按预期工作

现在我试着创造一些更有趣的东西:

class Simple2(torch.nn.Module):
  def __init__(self):
    super(Simple2, self).__init__()
    self.model1 = torch.nn.Sequential(
        torch.nn.Linear(1, 3),
        torch.nn.ReLU(),
        torch.nn.Linear(3, 1)
    )
    self.model2 = torch.nn.Sequential(
        torch.nn.Linear(1, 3),
        torch.nn.ReLU(),
        torch.nn.Linear(3, 1)
    )

  def forward(self, x, y):
    x1 = self.model1(x)
    y2 = self.model2(y)
    return torch.cat((x1,y2),1)

但是现在当我收到一个带有以下代码的“AttributeError”时:

s2 = Simple2()
s2(2,3)

我对 s2(2,3) 做错了什么?

或者:s2(2,3) 的最小工作示例是什么?

根据要求,我在此处添加完整日志:

6.0
18.0
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-3-f3b0dc51220a> in <module>
     43 
     44 s2 = Simple2()
---> 45 s2(2,3)

/opt/app-root/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

<ipython-input-3-f3b0dc51220a> in forward(self, x, y)
     38 
     39   def forward(self, x, y):
---> 40     x1 = self.model1(x)
     41     y2 = self.model2(y)
     42     return torch.cat((x1,y1),1)

/opt/app-root/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

/opt/app-root/lib/python3.6/site-packages/torch/nn/modules/container.py in forward(self, input)
    115     def forward(self, input):
    116         for module in self:
--> 117             input = module(input)
    118         return input
    119 

/opt/app-root/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

/opt/app-root/lib/python3.6/site-packages/torch/nn/modules/linear.py in forward(self, input)
     91 
     92     def forward(self, input: Tensor) -> Tensor:
---> 93         return F.linear(input, self.weight, self.bias)
     94 
     95     def extra_repr(self) -> str:

/opt/app-root/lib/python3.6/site-packages/torch/nn/functional.py in linear(input, weight, bias)
   1686         if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
   1687             return handle_torch_function(linear, tens_ops, input, weight, bias=bias)
-> 1688     if input.dim() == 2 and bias is not None:
   1689         # fused op is marginally faster
   1690         ret = torch.addmm(bias, input, weight.t())

AttributeError: 'int' object has no attribute 'dim'

我尝试了下面来自 Tamir 的张量示例:

x = torch.tensor([2.0])
y = torch.tensor([3.0])
s2(x,y)

但我最终遇到了这个错误:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-8-c61a0803c9b9> in <module>
     43 
     44 s2 = Simple2()
---> 45 s2(torch.tensor([2.0]), torch.tensor([3.0]))
     46 

/opt/app-root/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

<ipython-input-8-c61a0803c9b9> in forward(self, x, y)
     40     x1 = self.model1(x)
     41     y2 = self.model2(y)
---> 42     return torch.cat((x1,y2),1)
     43 
     44 s2 = Simple2()

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

最后注意: 我不得不将 Simple2 示例修改为此,而不是让它与 Tamir 的解决方案一起工作:

class Simple2(torch.nn.Module):
  def __init__(self):
    super(Simple2, self).__init__()
    self.model1 = torch.nn.Sequential(
        torch.nn.Linear(1, 3),
        torch.nn.ReLU(),
        torch.nn.Linear(3, 1)
    )
    self.model2 = torch.nn.Sequential(
        torch.nn.Linear(1, 3),
        torch.nn.ReLU(),
        torch.nn.Linear(3, 1)
    )

  def forward(self, x, y):
    x1 = self.model1(x)
    y2 = self.model2(y)
    # replaced this with row below: return torch.cat((x1,y2),1)
    return x1 + y2

这可能是一个类型问题,Pytorch Linear 和 ReLU 层期望 Tensors 作为输入,而您传递的是整数。 做一些像

x = torch.tensor([2])
y = torch.tensor([3])
s2(x,y)