NumPy 继承:更改数据类型时丢失自定义 class 实例中的属性

NumPy inheriting: losing attributes in custom class instance when changing datatype

是否有更多通用和一般最佳实践来处理numpy更改数据类型导致实例属性丢失时的问题?

(MyCustomClass 继承自 np.ndarray)

问题:

array = MyCustomClass(shape=[4,4,4],dtype=np.float16)
array.variable = 5
array_uint8 = np.uint8(array)
print(array_uint8.variable)
>> AttributeError: 'MyCustomClass' object has no attribute 'variable'

注意,更改数据类型不会导致更改数组的 class,即 MyCustomClass

我的解决方案非常糟糕

class uint8(np.uint8):
    def __new__(cls,*args, **kwargs):
        instance = super().__new__(cls, *args, **kwargs)
        print('instance class', instance.__class__)

        # Moving attributes from one instance to another
        instance.__dict__ = args[0].__dict__ # args[0] is MyCustomClass instance
        return instance

>>> instance class <class '__main__.MyCustomClass'>

我更喜欢这样一种解决方案,即我根本不必修改 np.uint8(或其他 numpy classes)。

感谢@kwinkunks,我找到了解决方案。

正如@kwinkunks 所指出的,有必要定义__array_finalize__。 在里面,可以在 obj 之间交换 __dict__ -> self 代表新数组,obj 代表旧数组

class MyCustomClass(np.ndarray):

    def __array_finalize__(self, obj):
        if isinstance(obj, (MyCustomClass,)):
            self.__dict__ = obj.__dict__