带有 attrs 的 PyTorch 模块无法获取参数列表
PyTorch Module with attrs cannot get parameter list
attr 的包以某种方式破坏了 pytorch 的 parameter()
模块方法。我想知道是否有人有任何解决方法或解决方案,以便这两个包可以无缝集成?
如果没有,请教一下github到post哪个问题要解决?我的直觉是 post 把它放到 attr 的 github 上,但堆栈跟踪几乎完全与 pytorch 的代码库相关。
Python 3.7.3
attrs== 19.1.0
torch==1.1.0.post2
torchvision==0.3.0
import attr
import torch
class RegularModule(torch.nn.Module):
pass
@attr.s
class AttrsModule(torch.nn.Module):
pass
module = RegularModule()
print(list(module.parameters()))
module = AttrsModule()
print(list(module.parameters()))
实际输出为:
$python attrs_pytorch.py
[]
Traceback (most recent call last):
File "attrs_pytorch.py", line 18, in <module>
print(list(module.parameters()))
File "/usr/local/anaconda3/envs/bgg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 814, in parameters
for name, param in self.named_parameters(recurse=recurse):
File "/usr/local/anaconda3/envs/bgg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 840, in named_parameters
for elem in gen:
File "/usr/local/anaconda3/envs/bgg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 784, in _named_members
for module_prefix, module in modules:
File "/usr/local/anaconda3/envs/bgg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 975, in named_modules
if self not in memo:
TypeError: unhashable type: 'AttrsModule'
预期输出为:
$python attrs_pytorch.py
[]
[]
您可以使用一种解决方法并使用 dataclasses
(您应该这样做,因为它在标准 Python 库中,因为您显然正在使用 3.7
)。虽然我认为简单 __init__
更具可读性。可以使用 attrs
库(禁用散列)做类似的事情,我更喜欢尽可能使用标准库的解决方案。
原因(如果您设法处理与散列相关的错误)是您正在调用 torch.nn.Module.__init__()
,它生成 _parameters
属性和其他特定于框架的数据。
首先用dataclasses
解决散列问题:
@dataclasses.dataclass(eq=False)
class AttrsModule(torch.nn.Module):
pass
这解决了 hashing
问题,如 documentation
部分所述,关于 hash
和 eq
:
By default, dataclass() will not implicitly add a hash() method
unless it is safe to do so.
PyTorch 需要它,因此该模型可以在 C++ 支持下使用(如果我错了请纠正我),此外:
If eq is false, hash() will be left untouched meaning the
hash() method of the superclass will be used (if the superclass is object, this means it will fall back to id-based hashing).
所以你可以使用 torch.nn.Module
__hash__
函数(如果出现任何进一步的错误,请参阅数据类的文档)。
这给您留下了错误:
AttributeError: 'AttrsModule' object has no attribute '_parameters'
因为torch.nn.Module
构造函数没有被调用。快速而肮脏的修复:
@dataclasses.dataclass(eq=False)
class AttrsModule(torch.nn.Module):
def __post_init__(self):
super().__init__()
__post_init__
是在 __init__
之后调用的函数(谁会猜到),您可以在其中初始化 torch 特定参数。
不过,我还是建议反对 同时使用这两个模块。例如,你正在使用你的代码破坏 PyTorch 的 __repr__
,所以 repr=False
应该传递给 dataclasses.dataclass
构造函数,它给出了这个最终代码(我希望消除库之间的明显冲突):
import dataclasses
import torch
class RegularModule(torch.nn.Module):
pass
@dataclasses.dataclass(eq=False, repr=False)
class AttrsModule(torch.nn.Module):
def __post_init__(self):
super().__init__()
module = RegularModule()
print(list(module.parameters()))
module = AttrsModule()
print(list(module.parameters()))
有关 attrs
的更多信息,请参阅 答案和他的博客 post。
attrs
有一个关于散列性的章节,其中还解释了 Python 中散列的陷阱:https://www.attrs.org/en/stable/hashing.html
您必须决定哪种行为适合您的具体问题。有关更多一般信息,请查看 https://hynek.me/articles/hashes-and-equality/ — 结果表明 Python.
中的散列非常棘手
attr 的包以某种方式破坏了 pytorch 的 parameter()
模块方法。我想知道是否有人有任何解决方法或解决方案,以便这两个包可以无缝集成?
如果没有,请教一下github到post哪个问题要解决?我的直觉是 post 把它放到 attr 的 github 上,但堆栈跟踪几乎完全与 pytorch 的代码库相关。
Python 3.7.3
attrs== 19.1.0
torch==1.1.0.post2
torchvision==0.3.0
import attr
import torch
class RegularModule(torch.nn.Module):
pass
@attr.s
class AttrsModule(torch.nn.Module):
pass
module = RegularModule()
print(list(module.parameters()))
module = AttrsModule()
print(list(module.parameters()))
实际输出为:
$python attrs_pytorch.py
[]
Traceback (most recent call last):
File "attrs_pytorch.py", line 18, in <module>
print(list(module.parameters()))
File "/usr/local/anaconda3/envs/bgg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 814, in parameters
for name, param in self.named_parameters(recurse=recurse):
File "/usr/local/anaconda3/envs/bgg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 840, in named_parameters
for elem in gen:
File "/usr/local/anaconda3/envs/bgg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 784, in _named_members
for module_prefix, module in modules:
File "/usr/local/anaconda3/envs/bgg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 975, in named_modules
if self not in memo:
TypeError: unhashable type: 'AttrsModule'
预期输出为:
$python attrs_pytorch.py
[]
[]
您可以使用一种解决方法并使用 dataclasses
(您应该这样做,因为它在标准 Python 库中,因为您显然正在使用 3.7
)。虽然我认为简单 __init__
更具可读性。可以使用 attrs
库(禁用散列)做类似的事情,我更喜欢尽可能使用标准库的解决方案。
原因(如果您设法处理与散列相关的错误)是您正在调用 torch.nn.Module.__init__()
,它生成 _parameters
属性和其他特定于框架的数据。
首先用dataclasses
解决散列问题:
@dataclasses.dataclass(eq=False)
class AttrsModule(torch.nn.Module):
pass
这解决了 hashing
问题,如 documentation
部分所述,关于 hash
和 eq
:
By default, dataclass() will not implicitly add a hash() method unless it is safe to do so.
PyTorch 需要它,因此该模型可以在 C++ 支持下使用(如果我错了请纠正我),此外:
If eq is false, hash() will be left untouched meaning the hash() method of the superclass will be used (if the superclass is object, this means it will fall back to id-based hashing).
所以你可以使用 torch.nn.Module
__hash__
函数(如果出现任何进一步的错误,请参阅数据类的文档)。
这给您留下了错误:
AttributeError: 'AttrsModule' object has no attribute '_parameters'
因为torch.nn.Module
构造函数没有被调用。快速而肮脏的修复:
@dataclasses.dataclass(eq=False)
class AttrsModule(torch.nn.Module):
def __post_init__(self):
super().__init__()
__post_init__
是在 __init__
之后调用的函数(谁会猜到),您可以在其中初始化 torch 特定参数。
不过,我还是建议反对 同时使用这两个模块。例如,你正在使用你的代码破坏 PyTorch 的 __repr__
,所以 repr=False
应该传递给 dataclasses.dataclass
构造函数,它给出了这个最终代码(我希望消除库之间的明显冲突):
import dataclasses
import torch
class RegularModule(torch.nn.Module):
pass
@dataclasses.dataclass(eq=False, repr=False)
class AttrsModule(torch.nn.Module):
def __post_init__(self):
super().__init__()
module = RegularModule()
print(list(module.parameters()))
module = AttrsModule()
print(list(module.parameters()))
有关 attrs
的更多信息,请参阅
attrs
有一个关于散列性的章节,其中还解释了 Python 中散列的陷阱:https://www.attrs.org/en/stable/hashing.html
您必须决定哪种行为适合您的具体问题。有关更多一般信息,请查看 https://hynek.me/articles/hashes-and-equality/ — 结果表明 Python.
中的散列非常棘手