如何为 ModuleList 中的每个模块命名?
How can I give a name to each module in ModuleList?
我的模型中有以下组件:
# Build one feed-forward classification head per entry of self.tasks and
# register all of them on the model through a single nn.ModuleList.
feedfnn = []
for task_name, num_class in self.tasks:
    # Layers are collected as (name, module) pairs in the order they run.
    # Every name must be unique: OrderedDict keeps a single entry per key, so
    # calling both activations 'tanh' would silently drop the second one and
    # leave 'dense2' without a non-linearity.
    layers = [
        ('dropout1', nn.Dropout(self.config.dropout_fc)),
        ('dense1', nn.Linear(self.config.nhid * self.num_directions * 8, self.config.fc_dim)),
    ]
    if self.config.nonlinear_fc:
        layers.append(('tanh1', nn.Tanh()))
    layers.append(('dropout2', nn.Dropout(self.config.dropout_fc)))
    layers.append(('dense2', nn.Linear(self.config.fc_dim, self.config.fc_dim)))
    if self.config.nonlinear_fc:
        layers.append(('tanh2', nn.Tanh()))
    layers.append(('dropout3', nn.Dropout(self.config.dropout_fc)))
    # The output width of the last layer is the task's number of classes.
    layers.append(('dense3', nn.Linear(self.config.fc_dim, num_class)))
    ffnn = nn.Sequential(OrderedDict(layers))
    feedfnn.append(ffnn)
# nn.ModuleList registers each head as a child module, keyed by its index.
self.ffnn = nn.ModuleList(feedfnn)
当我打印我的模型时,我得到以上组件的描述:
(ffnn): ModuleList (
(0): Sequential (
(dropout1): Dropout (p = 0)
(dense1): Linear (4096 -> 512)
(dropout2): Dropout (p = 0)
(dense2): Linear (512 -> 512)
(dropout3): Dropout (p = 0)
(dense3): Linear (512 -> 2)
)
(1): Sequential (
(dropout1): Dropout (p = 0)
(dense1): Linear (4096 -> 512)
(dropout2): Dropout (p = 0)
(dense2): Linear (512 -> 512)
(dropout3): Dropout (p = 0)
(dense3): Linear (512 -> 3)
)
(2): Sequential (
(dropout1): Dropout (p = 0)
(dense1): Linear (4096 -> 512)
(dropout2): Dropout (p = 0)
(dense2): Linear (512 -> 512)
(dropout3): Dropout (p = 0)
(dense3): Linear (512 -> 3)
)
)
我可以使用 `(task1): Sequential`、`(task2): Sequential` 之类的特定名称来代替 `(0): Sequential`、`(1): Sequential` 吗?
这很简单。只需从一个空的 `ModuleList` 开始,然后使用 `add_module` 逐个添加带名称的子模块即可。例如:
import torch.nn as nn
from collections import OrderedDict
def make_named_sequential():
    """Return a new feed-forward stack whose layers all carry unique names."""
    # Keys must be unique: OrderedDict keeps a single entry per key, so naming
    # both activations 'tanh' would silently drop one of the eight layers.
    return nn.Sequential(OrderedDict([
        ('dropout1', nn.Dropout(0.1)),
        ('dense1', nn.Linear(10, 10)),
        ('tanh1', nn.Tanh()),
        ('dropout2', nn.Dropout(0.1)),
        ('dense2', nn.Linear(10, 10)),
        ('tanh2', nn.Tanh()),
        ('dropout3', nn.Dropout(0.1)),
        ('dense3', nn.Linear(10, 10)),
    ]))


# Start from an empty ModuleList and register each child under an explicit
# name with add_module instead of the automatic '0', '1', ... indices.
final_module_list = nn.ModuleList()
a_sequential_module_with_names = make_named_sequential()
final_module_list.add_module('Stage 1', a_sequential_module_with_names)
# Build a separate instance for the second stage; registering the same object
# under two names would make both stages share one set of parameters.
final_module_list.add_module('Stage 2', make_named_sequential())
其余阶段以此类推。
我的模型中有以下组件:
# Build one feed-forward classification head per entry of self.tasks and
# register all of them on the model through a single nn.ModuleList.
feedfnn = []
for task_name, num_class in self.tasks:
    # Layers are collected as (name, module) pairs in the order they run.
    # Every name must be unique: OrderedDict keeps a single entry per key, so
    # calling both activations 'tanh' would silently drop the second one and
    # leave 'dense2' without a non-linearity.
    layers = [
        ('dropout1', nn.Dropout(self.config.dropout_fc)),
        ('dense1', nn.Linear(self.config.nhid * self.num_directions * 8, self.config.fc_dim)),
    ]
    if self.config.nonlinear_fc:
        layers.append(('tanh1', nn.Tanh()))
    layers.append(('dropout2', nn.Dropout(self.config.dropout_fc)))
    layers.append(('dense2', nn.Linear(self.config.fc_dim, self.config.fc_dim)))
    if self.config.nonlinear_fc:
        layers.append(('tanh2', nn.Tanh()))
    layers.append(('dropout3', nn.Dropout(self.config.dropout_fc)))
    # The output width of the last layer is the task's number of classes.
    layers.append(('dense3', nn.Linear(self.config.fc_dim, num_class)))
    ffnn = nn.Sequential(OrderedDict(layers))
    feedfnn.append(ffnn)
# nn.ModuleList registers each head as a child module, keyed by its index.
self.ffnn = nn.ModuleList(feedfnn)
当我打印我的模型时,我得到以上组件的描述:
(ffnn): ModuleList (
(0): Sequential (
(dropout1): Dropout (p = 0)
(dense1): Linear (4096 -> 512)
(dropout2): Dropout (p = 0)
(dense2): Linear (512 -> 512)
(dropout3): Dropout (p = 0)
(dense3): Linear (512 -> 2)
)
(1): Sequential (
(dropout1): Dropout (p = 0)
(dense1): Linear (4096 -> 512)
(dropout2): Dropout (p = 0)
(dense2): Linear (512 -> 512)
(dropout3): Dropout (p = 0)
(dense3): Linear (512 -> 3)
)
(2): Sequential (
(dropout1): Dropout (p = 0)
(dense1): Linear (4096 -> 512)
(dropout2): Dropout (p = 0)
(dense2): Linear (512 -> 512)
(dropout3): Dropout (p = 0)
(dense3): Linear (512 -> 3)
)
)
我可以使用 `(task1): Sequential`、`(task2): Sequential` 之类的特定名称来代替 `(0): Sequential`、`(1): Sequential` 吗?
这很简单。只需从一个空的 `ModuleList` 开始,然后使用 `add_module` 逐个添加带名称的子模块即可。例如:
import torch.nn as nn
from collections import OrderedDict
def make_named_sequential():
    """Return a new feed-forward stack whose layers all carry unique names."""
    # Keys must be unique: OrderedDict keeps a single entry per key, so naming
    # both activations 'tanh' would silently drop one of the eight layers.
    return nn.Sequential(OrderedDict([
        ('dropout1', nn.Dropout(0.1)),
        ('dense1', nn.Linear(10, 10)),
        ('tanh1', nn.Tanh()),
        ('dropout2', nn.Dropout(0.1)),
        ('dense2', nn.Linear(10, 10)),
        ('tanh2', nn.Tanh()),
        ('dropout3', nn.Dropout(0.1)),
        ('dense3', nn.Linear(10, 10)),
    ]))


# Start from an empty ModuleList and register each child under an explicit
# name with add_module instead of the automatic '0', '1', ... indices.
final_module_list = nn.ModuleList()
a_sequential_module_with_names = make_named_sequential()
final_module_list.add_module('Stage 1', a_sequential_module_with_names)
# Build a separate instance for the second stage; registering the same object
# under two names would make both stages share one set of parameters.
final_module_list.add_module('Stage 2', make_named_sequential())
其余阶段以此类推。