将一个 class 的输出分享给另一个 class python
Share the output of one class to another class python
我有两个 DNN,第一个 returns 两个输出。我想在第二个 class 中使用其中一个输出代表另一个 DNN,如下例所示:
我想将输出 (x) 传递给第二个 class 以连接到另一个变量 (v)。我找到了将变量 (x) 作为全局变量的解决方案,但我需要另一个有效的解决方案
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
class Net(nn.Module):
def __init__(self):
..
def forward(self, x):
..
return x, z
class Net2(nn.Module):
def __init__(self):
..
def forward(self, v):
y = torch.cat(v, x)
return y
你不应该依赖 global 变量,你需要解决以下常见问题。您可以将 v
和 x
作为 Net2
的 forward
的参数传递。类似于:
class Net(nn.Module):
def forward(self, x):
z = x**2
return x, z
class Net2(nn.Module):
def forward(self, x, v):
y = torch.cat((v, x), dim=1)
return y
使用虚拟数据:
>>> net = Net()
>>> net2 = Net2()
>>> input1 = torch.rand(1,10)
>>> input2 = torch.rand(1,20)
第一个推论:
>>> x, z = net(input1)
第二个推断:
>>> out = net2(x, input2)
>>> out.shape
torch.Size([1, 30])
我有两个 DNN,第一个 returns 两个输出。我想在第二个 class 中使用其中一个输出代表另一个 DNN,如下例所示:
我想将输出 (x) 传递给第二个 class 以连接到另一个变量 (v)。我找到了将变量 (x) 作为全局变量的解决方案,但我需要另一个有效的解决方案
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
class Net(nn.Module):
def __init__(self):
..
def forward(self, x):
..
return x, z
class Net2(nn.Module):
def __init__(self):
..
def forward(self, v):
y = torch.cat(v, x)
return y
你不应该依赖 global 变量,你需要解决以下常见问题。您可以将 v
和 x
作为 Net2
的 forward
的参数传递。类似于:
class Net(nn.Module):
def forward(self, x):
z = x**2
return x, z
class Net2(nn.Module):
def forward(self, x, v):
y = torch.cat((v, x), dim=1)
return y
使用虚拟数据:
>>> net = Net()
>>> net2 = Net2()
>>> input1 = torch.rand(1,10)
>>> input2 = torch.rand(1,20)
第一个推论:
>>> x, z = net(input1)
第二个推断:
>>> out = net2(x, input2)
>>> out.shape
torch.Size([1, 30])