如何使用 class 中的私有方法反序列化 PyTorch 保存的模型?
How to deserialize a PyTorch saved model with private methods inside a class?
我用PyTorch的保存方式序列化了一堆必不可少的对象。其中,有一个 class 在同一个 class 的 __init__ 中引用了一个私有方法。现在,在序列化之后,我无法反序列化(unpickle)文件,因为在 class 之外无法访问私有方法。知道如何解决或绕过它吗?我需要恢复保存到 class.
属性中的数据
File ".conda/envs/py37/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3331, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-1-a5666d77c70f>", line 1, in <module>
torch.load("snapshots/model.pth", map_location='cpu')
File ".conda/envs/py37/lib/python3.7/site-packages/torch/serialization.py", line 529, in load
return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
File ".conda/envs/py37/lib/python3.7/site-packages/torch/serialization.py", line 702, in _legacy_load
result = unpickler.load()
AttributeError: 'Trainer' object has no attribute '__iterator'
- 编辑-1:
这里有一段代码会产生我现在面临的问题。
import torch
class Test:
def __init__(self):
self.a = min
self.b = max
self.c = self.__private # buggy
def __private(self):
return None
test = Test()
torch.save({"test": test}, "file.pkl")
torch.load("file.pkl")
但是,如果从方法中删除私有属性,则不会出现任何错误。
import torch
class Test:
def __init__(self):
self.a = min
self.b = max
self.c = self.private # not buggy
def private(self):
return None
test = Test()
torch.save({"test": test}, "file.pkl")
torch.load("file.pkl")
此题与类似,但由于悬赏开放,无法标记为重复。
问题源于 Python 错误跟踪器上的这个开放问题:Objects referencing private-mangled names do not roundtrip properly under pickling, and is related to the way pickle handles name-mangling. More details on this answer: 。
此时,唯一的解决方法是不使用 __init__
中的私有方法。
这个问题是由于 name mangling — 解释器以下面的方式更改变量的名称,这使得当 class 是以后延长。其中
self.__private
已更改为 (self._className__privateMethodName)
self._Test__private
由于 name mangling 不适用于 dunder,其中名称必须以双下划线开头和结尾。
因此,为避免名称混淆,在末尾再添加两个下划线。
下面的代码片段应该可以解决您的问题。
import torch
class Test:
def __init__(self):
self.a = min
self.b = max
self.c = self.__private__
def __private__(self):
return None
test = Test()
torch.save({"test": test}, "file.pkl")
torch.load("file.pkl")
我用PyTorch的保存方式序列化了一堆必不可少的对象。其中,有一个 class 在同一个 class 的 __init__ 中引用了一个私有方法。现在,在序列化之后,我无法反序列化(unpickle)文件,因为在 class 之外无法访问私有方法。知道如何解决或绕过它吗?我需要恢复保存到 class.
属性中的数据 File ".conda/envs/py37/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3331, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-1-a5666d77c70f>", line 1, in <module>
torch.load("snapshots/model.pth", map_location='cpu')
File ".conda/envs/py37/lib/python3.7/site-packages/torch/serialization.py", line 529, in load
return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
File ".conda/envs/py37/lib/python3.7/site-packages/torch/serialization.py", line 702, in _legacy_load
result = unpickler.load()
AttributeError: 'Trainer' object has no attribute '__iterator'
- 编辑-1:
这里有一段代码会产生我现在面临的问题。
import torch
class Test:
def __init__(self):
self.a = min
self.b = max
self.c = self.__private # buggy
def __private(self):
return None
test = Test()
torch.save({"test": test}, "file.pkl")
torch.load("file.pkl")
但是,如果从方法中删除私有属性,则不会出现任何错误。
import torch
class Test:
def __init__(self):
self.a = min
self.b = max
self.c = self.private # not buggy
def private(self):
return None
test = Test()
torch.save({"test": test}, "file.pkl")
torch.load("file.pkl")
此题与
问题源于 Python 错误跟踪器上的这个开放问题:Objects referencing private-mangled names do not roundtrip properly under pickling, and is related to the way pickle handles name-mangling. More details on this answer:
此时,唯一的解决方法是不使用 __init__
中的私有方法。
这个问题是由于 name mangling — 解释器以下面的方式更改变量的名称,这使得当 class 是以后延长。其中
self.__private
已更改为 (self._className__privateMethodName)
self._Test__private
由于 name mangling 不适用于 dunder,其中名称必须以双下划线开头和结尾。
因此,为避免名称混淆,在末尾再添加两个下划线。
下面的代码片段应该可以解决您的问题。
import torch
class Test:
def __init__(self):
self.a = min
self.b = max
self.c = self.__private__
def __private__(self):
return None
test = Test()
torch.save({"test": test}, "file.pkl")
torch.load("file.pkl")