如何在递归工作的池中维护全局进程?

How to maintain global processes in a pool working recursively?

我想实现一个递归并行算法,我希望只创建一次池,每个时间步执行一个作业,等待所有作业完成,然后再次调用进程,输入之前的输出,然后在下一个时间步再次相同,等等

我的问题是我已经实现了一个版本,每个时间步我都创建和终止池,但这非常慢,甚至比顺序版本还慢。当我尝试实现一个在开始时只创建一次池的版本时,我在尝试调用 join() 时遇到断言错误。

这是我的代码

def log_result(result):

    tempx , tempb, u = result

    X[:,u,np.newaxis], b[:,u,np.newaxis] = tempx , tempb


workers =  mp.Pool(processes = 4) 
for t in range(p,T):

    count = 0 #==========This is only master's job=============
    for l in range(p):
        for k in range(4):
            gn[count]=train[t-l-1,k]
            count+=1
    G = G*v +  gn @ gn.T#==================================

    if __name__ == '__main__':
        for i in range(4):
            workers.apply_async(OULtraining, args=(train[t,i], X[:,i,np.newaxis], b[:,i,np.newaxis], i, gn), callback = log_result)


        workers.join()   

X和b是我想直接在master内存更新的矩阵

这里出了什么问题,我收到断言错误?

我可以用池实现我想要或不想要的东西吗?

您不能加入未先关闭的池,因为 join() 将等待工作进程终止,而不是作业完成(https://docs.python.org/3.6/library/multiprocessing.html 第 17.2.2.9 节)。

但是因为这会关闭池,这不是您想要的,所以您不能使用它。所以join就out了,需要自己实现一个"wait until all jobs completed"。

在没有繁忙循环的情况下执行此操作的一种方法是使用队列。您也可以使用有界信号量,但它们不适用于所有操作系统。

counter = 0
lock_queue = multiprocessing.Queue()
counter_lock = multiprocessing.Lock()

def log_result(result):

    tempx , tempb, u = result

    X[:,u,np.newaxis], b[:,u,np.newaxis] = tempx , tempb
    with counter_lock:
        counter += 1
        if counter == 4:
            counter = 0
            lock_queue.put(42)



workers =  mp.Pool(processes = 4) 
for t in range(p,T):

    count = 0 #==========This is only master's job=============
    for l in range(p):
        for k in range(4):
            gn[count]=train[t-l-1,k]
            count+=1
    G = G*v +  gn @ gn.T#==================================

    if __name__ == '__main__':
        counter = 0
        for i in range(4):
            workers.apply_async(OULtraining, args=(train[t,i], X[:,i,np.newaxis], b[:,i,np.newaxis], i, gn), callback = log_result)


        lock_queue.get(block=True)

这会在提交作业之前重置全局计数器。一旦作业完成,您的回调就会增加一个全局计数器。当计数器达到 4(您的作业数)时,回调知道它已经处理了最后一个结果。然后在队列中发送一条虚拟消息。您的主程序正在 Queue.get() 处等待出现。

这允许您的主程序阻塞,直到所有作业完成,而无需关闭池。

如果您将 multiprocessing.Pool 替换为 concurrent.futures 中的 ProcessPoolExecutor,则可以跳过此部分并使用

concurrent.futures.wait(fs, timeout=None, return_when=ALL_COMPLETED)

阻塞,直到所有提交的任务都完成。从功能的角度来看,它们之间没有区别。 concurrent.futures 方法少了几行,但结果完全一样。