NumPy 继承:更改数据类型时丢失自定义 class 实例中的属性
NumPy inheriting: losing attributes in custom class instance when changing datatype
是否有更多通用和一般最佳实践来处理numpy更改数据类型导致实例属性丢失时的问题?
(MyCustomClass 继承自 np.ndarray)
问题:
array = MyCustomClass(shape=[4,4,4],dtype=np.float16)
array.variable = 5
array_uint8 = np.uint8(array)
print(array_uint8.variable)
>> AttributeError: 'MyCustomClass' object has no attribute 'variable'
注意,更改数据类型不会导致更改数组的 class,即 MyCustomClass
我的解决方案非常糟糕
class uint8(np.uint8):
def __new__(cls,*args, **kwargs):
instance = super().__new__(cls, *args, **kwargs)
print('instance class', instance.__class__)
# Moving attributes from one instance to another
instance.__dict__ = args[0].__dict__ # args[0] is MyCustomClass instance
return instance
>>> instance class <class '__main__.MyCustomClass'>
我更喜欢这样一种解决方案,即我根本不必修改 np.uint8(或其他 numpy classes)。
感谢@kwinkunks,我找到了解决方案。
正如@kwinkunks 所指出的,有必要定义__array_finalize__。
在里面,可以在 obj 之间交换 __dict__ -> self 代表新数组,obj 代表旧数组
class MyCustomClass(np.ndarray):
def __array_finalize__(self, obj):
if isinstance(obj, (MyCustomClass,)):
self.__dict__ = obj.__dict__
是否有更多通用和一般最佳实践来处理numpy更改数据类型导致实例属性丢失时的问题?
(MyCustomClass 继承自 np.ndarray)
问题:
array = MyCustomClass(shape=[4,4,4],dtype=np.float16)
array.variable = 5
array_uint8 = np.uint8(array)
print(array_uint8.variable)
>> AttributeError: 'MyCustomClass' object has no attribute 'variable'
注意,更改数据类型不会导致更改数组的 class,即 MyCustomClass
我的解决方案非常糟糕
class uint8(np.uint8):
def __new__(cls,*args, **kwargs):
instance = super().__new__(cls, *args, **kwargs)
print('instance class', instance.__class__)
# Moving attributes from one instance to another
instance.__dict__ = args[0].__dict__ # args[0] is MyCustomClass instance
return instance
>>> instance class <class '__main__.MyCustomClass'>
我更喜欢这样一种解决方案,即我根本不必修改 np.uint8(或其他 numpy classes)。
感谢@kwinkunks,我找到了解决方案。
正如@kwinkunks 所指出的,有必要定义__array_finalize__。 在里面,可以在 obj 之间交换 __dict__ -> self 代表新数组,obj 代表旧数组
class MyCustomClass(np.ndarray):
def __array_finalize__(self, obj):
if isinstance(obj, (MyCustomClass,)):
self.__dict__ = obj.__dict__