Folding in array mutations using multiprocessing in python

I have 218k+ images with 33 channels each, and I need to find the mean and variance of each channel. I have tried using multiprocessing, but it is unbearably slow. Here is a short code sample:

from functools import partial
from multiprocessing import Manager, Pool
from tifffile import imread  # assumption: any tif reader that returns an ndarray works here

def work(aggregates, genput):
    # receives a (channel, image) tuple from the generator
    channel, image = genput
    for row in image:
        for pixel in row:
            # use welford's to update a list of "aggregates" which will
            # later be finalized as means and variances of each channel
            aggregates[channel] = update(aggregates[channel], pixel)

def data_stream(df, data_root):
    '''Generator that returns the channel and image for each tif file'''
    for index, sample in df.iterrows():
        curr_img_path = data_root

        # read the image with all channels
        tif = imread(curr_img_path)  #33x64x64 array        
        for channel, image in enumerate(tif):
            yield (channel, image)     

# Pass over each image, updating the mean/variance aggregates for every channel
def preprocess_mv(df, data_root, channels=33, multiprocessing=True):
    '''Calculates mean and variance on the whole image set for use in deep_learn'''
    manager = Manager()
    aggregates = manager.list()

    for _ in range(channels):
        aggregates.append([0, 0, 0])

    proxy = partial(work, aggregates)

    pool = Pool(processes=8) 
    pool.imap(proxy, data_stream(df, data_root), chunksize=5000)
    pool.close()
    pool.join()

    # finalize data below
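
The update function is not shown here; assuming it is the standard single-sample Welford step on a (count, mean, M2) triple, it would look roughly like this:

def update(aggregate, new_value):
    '''Welford's single-sample update on a (count, mean, M2) triple.'''
    count, mean, M2 = aggregate
    count += 1
    delta = new_value - mean
    mean += delta / count
    M2 += delta * (new_value - mean)
    return (count, mean, M2)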

I suspect that the time spent pickling the aggregates array and shuttling it back and forth between the parent and child processes is very long, and that this is the main bottleneck. I can see this drawback completely wiping out the advantage of multiprocessing, since every child has to wait for the other children to pickle and unpickle the data. I have read that this is a limitation of the multiprocessing library, and from other posts I have read here I am starting to think this may be the best I can do. That said, does anyone have suggestions on how to improve it?

Also, I would like to know whether there are better libraries/tools for this task. A friend actually recommended Scala, and I have been looking into it as an option. I am very familiar with Python and would like to stay in that world if possible.

I was able to find a solution by digging deeper into multiprocessing.Array. I had to work out how to flatten my 2D array into a 1D array while keeping the indexing straight, but that turned out to be quite simple. I can now process 1,000 samples in 2 minutes instead of 4 hours, which I think is pretty good. I also had to write a custom function to print the array, but that was fairly straightforward. This implementation makes no guarantee against race conditions, but for my purposes it works well enough. You can easily add a lock by including it in init and passing it the same way the array is handled (using a global); see the sketch after the code below.

import numpy as np
from multiprocessing import Array, Pool

def init(arr):
    # give each worker process a module-global handle to the shared array
    global aggregates
    aggregates = arr

def work(genput):
    # receives a (sample, channel, image) tuple from the generator
    sample_no, channel, image = genput
    currAgg = (aggregates[3*channel], aggregates[3*channel + 1],
               aggregates[3*channel + 2])
    for row in image:
        for pixel in row:
            # use Welford's to compute the updated aggregate
            currAgg = update(currAgg, pixel)
    # new indexing scheme for the 1D array ("shaped" as 33x3)
    aggregates[3*channel] = currAgg[0]
    aggregates[3*channel + 1] = currAgg[1]
    aggregates[3*channel + 2] = currAgg[2]

def data_stream(df, data_root):
    '''Generator that returns the channel and image for each tif file'''
    ...
    yield (index, channel, image)


if __name__ == '__main__':

    aggs = Array('d', np.zeros(99))  # 99 slots = 33 channels x 3 aggregate values

    pool = Pool(initializer=init, initargs=(aggs,), processes=8)
    pool.imap(work, data_stream(df, data_root), chunksize=10)
    pool.close()
    pool.join()

#-----------finalize aggregates below
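
Each channel's three slots hold the running (count, mean, M2) triple, so finalizing is just reading them back out. A minimal sketch, assuming the population variance M2/count is what's wanted (use count - 1 for the sample variance):

def finalize_aggregates(aggregates, channels=33):
    '''Turn the flat shared array into per-channel (mean, variance) pairs.'''
    stats = []
    for channel in range(channels):
        count = aggregates[3*channel]
        mean = aggregates[3*channel + 1]
        M2 = aggregates[3*channel + 2]
        variance = M2 / count if count > 0 else float('nan')
        stats.append((mean, variance))
    return stats

And if the race on the shared array ever matters, a Lock can ride along through the initializer exactly like the array does. A hypothetical sketch:

from multiprocessing import Lock

def init(arr, lk):
    global aggregates, lock
    aggregates = arr
    lock = lk

# in __main__: pool = Pool(initializer=init, initargs=(aggs, Lock()), processes=8)
# and in work(), wrap the read and write-back of a channel's slots in "with lock:"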