pytorch nn.module 如何保存子模块
how pytorch nn.module save submodule
我对 pytorch nn.module 的工作原理有一些疑问
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.sub_module = nn.Linear(10, 5)
self.value = 3
net = Net()
print(net.__dict__)
输出
{'_modules': OrderedDict([('sub_module', Linear (10 -> 5))]), 'value': 3, ...}
我知道class的每个属性都应该存储在__dict__中,为什么value(a int value)在里面,但是sub_module(a nn.Module) 不是,而是 sub_module 存储在 _modules
我阅读了 nn.Module 实现的代码,但我没有弄明白。有人有什么想法吗?
谢谢!!
我会尽量保持简单。
每次您在 class Net
中创建一个新项目时,例如:self.sub_module = nn.Linear(10, 5)
它会调用其父 class 的方法 __setattr__
,在这种情况下 nn.Module
。然后,在 __setattr__
方法内部,参数被存储到它们所属的字典中。在这种情况下,因为 nn.Linear
是一个模块,所以它存储在 _modules
字典中。
这是在 Module
class https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L389
中执行此操作的一段代码
我对 pytorch nn.module 的工作原理有一些疑问
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.sub_module = nn.Linear(10, 5)
self.value = 3
net = Net()
print(net.__dict__)
输出
{'_modules': OrderedDict([('sub_module', Linear (10 -> 5))]), 'value': 3, ...}
我知道class的每个属性都应该存储在__dict__中,为什么value(a int value)在里面,但是sub_module(a nn.Module) 不是,而是 sub_module 存储在 _modules
我阅读了 nn.Module 实现的代码,但我没有弄明白。有人有什么想法吗?
谢谢!!
我会尽量保持简单。
每次您在 class Net
中创建一个新项目时,例如:self.sub_module = nn.Linear(10, 5)
它会调用其父 class 的方法 __setattr__
,在这种情况下 nn.Module
。然后,在 __setattr__
方法内部,参数被存储到它们所属的字典中。在这种情况下,因为 nn.Linear
是一个模块,所以它存储在 _modules
字典中。
这是在 Module
class https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L389