multiprocessing.Pool.map() drops attribute of subclassed ndarray

When using `multiprocessing.Pool().map()` on a list of instances of a `numpy.ndarray` subclass, the custom attribute of the subclass is dropped.

The following minimal example, based on the numpy docs subclassing example, reproduces the problem:

from multiprocessing import Pool
import numpy as np


class MyArray(np.ndarray):

    def __new__(cls, input_array, info=None):
        obj = np.asarray(input_array).view(cls)
        obj.info = info
        return obj

    def __array_finalize__(self, obj):
        if obj is None: return
        self.info = getattr(obj, 'info', None)

def sum_worker(x):
    return sum(x), x.info

if __name__ == '__main__':
    arr_list = [MyArray(np.random.rand(3), info=f'foo_{i}') for i in range(10)]
    with Pool() as p:
        p.map(sum_worker, arr_list)

The attribute `info` is dropped:

AttributeError: 'MyArray' object has no attribute 'info'

Using the built-in `map()` works fine:

arr_list = [MyArray(np.random.rand(3), info=f'foo_{i}') for i in range(10)]
list(map(sum_worker, arr_list))

The purpose of the method `__array_finalize__()` is that objects retain the attribute after slicing:

arr = MyArray([1,2,3], info='foo')
subarr = arr[:2]
print(subarr.info)

But for `Pool.map()` this approach somehow does not work...

Because multiprocessing uses pickle to serialize the data to and from the separate processes, this is essentially a duplicate of this question.
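The attribute loss can be reproduced with a plain pickle round trip, with no `Pool` involved at all — a minimal sketch reusing the `MyArray` class from the question:

```python
import pickle

import numpy as np


class MyArray(np.ndarray):

    def __new__(cls, input_array, info=None):
        obj = np.asarray(input_array).view(cls)
        obj.info = info
        return obj

    def __array_finalize__(self, obj):
        # During unpickling the array is created via the explicit
        # constructor, so obj is None and info is never set.
        if obj is None: return
        self.info = getattr(obj, 'info', None)


arr = MyArray([1, 2, 3], info='foo')
restored = pickle.loads(pickle.dumps(arr))
print(type(restored).__name__)       # the subclass itself survives
print(hasattr(restored, 'info'))     # False: the attribute was lost
```

This shows that `Pool.map()` is not the culprit; the default ndarray pickling simply does not know about the extra attribute.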

Adapting the accepted solution from that question, your example becomes:

from multiprocessing import Pool
import numpy as np

class MyArray(np.ndarray):

    def __new__(cls, input_array, info=None):
        obj = np.asarray(input_array).view(cls)
        obj.info = info
        return obj

    def __array_finalize__(self, obj):
        if obj is None: return
        self.info = getattr(obj, 'info', None)

    def __reduce__(self):
        pickled_state = super(MyArray, self).__reduce__()
        new_state = pickled_state[2] + (self.info,)
        return (pickled_state[0], pickled_state[1], new_state)

    def __setstate__(self, state):
        self.info = state[-1]
        super(MyArray, self).__setstate__(state[0:-1])

def sum_worker(x):
    return sum(x), x.info

if __name__ == '__main__':
    arr_list = [MyArray(np.random.rand(3), info=f'foo_{i}') for i in range(10)]
    with Pool() as p:
        p.map(sum_worker, arr_list)

Note that the second answer there suggests you can use `pathos.multiprocessing` with your original, unadapted code, because pathos uses dill instead of pickle. However, this did not work when I tested it.
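As a quick sanity check (my own addition, not part of the original answer), a plain pickle round trip with the extended class now preserves `info`:

```python
import pickle

import numpy as np


class MyArray(np.ndarray):

    def __new__(cls, input_array, info=None):
        obj = np.asarray(input_array).view(cls)
        obj.info = info
        return obj

    def __array_finalize__(self, obj):
        if obj is None: return
        self.info = getattr(obj, 'info', None)

    def __reduce__(self):
        # Append the custom attribute to the state tuple ndarray produces.
        pickled_state = super().__reduce__()
        new_state = pickled_state[2] + (self.info,)
        return (pickled_state[0], pickled_state[1], new_state)

    def __setstate__(self, state):
        # Pop our attribute off the end, hand the rest to ndarray.
        self.info = state[-1]
        super().__setstate__(state[0:-1])


arr = MyArray([1, 2, 3], info='foo')
restored = pickle.loads(pickle.dumps(arr))
print(restored.info)  # 'foo'
```

Since `Pool.map()` goes through exactly this pickle machinery, a passing round trip here is what makes the worker calls succeed.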