multiprocessing.Pool.map throws MemoryError

I am rewriting a reinforcement learning framework from serial execution to parallel (multiprocessing) to reduce training time. It works, but after a few hours of training a MemoryError is thrown. I tried adding gc.collect() after each loop; nothing changed.

Here is the for loop that uses multiprocessing:

for episode in episodes:
    env.episode = episode
    flex_list = [0, 1, 2]
    for machine in env.list_of_machines:
        flex_plan = []
        for time_step in range(0, env.steplength):
            flex_plan.append(random.choice(flex_list))
        machine.flex_plan = flex_plan
    env.current_step = 0
    steps = []
    state = env.reset(restricted=True)
    steps.append(state)

    # multiprocessing part; a condition selects a specific number of CPUs or 'all' of them
    ####################################################
    func_part = partial(parallel_pool, episode=episode, episodes=episodes,
                        env=env, agent=agent, state=state,
                        log_data_qvalues=log_data_qvalues,
                        log_data=log_data, steps=steps)
    if CPUs_used == 'all':
        mp.Pool().map(func_part, range(env.steplength-1))
    else:
        mp.Pool(CPUs_used).map(func_part, range(env.steplength-1))
    ############################################################
    # model is saved periodically, not only in the end
    save_interval = 100 #set episode interval to save models
    if (episode + 1) % save_interval == 0:
        agent.save_model(f'models/model_{filename}_{episode + 1}')
        print(f'model saved at episode {episode + 1}')

    plt.close()
    gc.collect()

Output after 26 episodes of training:

Episode: 26/100   Action: 1/11    Phase: 3/3    Measurement Count: 231/234   THD fake slack: 0.09487   Psoll: [0.02894068 0.00046048 0.         0.        ]    Laptime: 0.181
Episode: 26/100   Action: 1/11    Phase: 3/3    Measurement Count: 232/234   THD fake slack: 0.09488   Psoll: [0.02894068 0.00046048 0.         0.        ]    Laptime: 0.181
Episode: 26/100   Action: 1/11    Phase: 3/3    Measurement Count: 233/234   THD fake slack: 0.09489   Psoll: [0.02894068 0.00046048 0.         0.        ]    Laptime: 0.179
Traceback (most recent call last):
  File "C:/Users/Artur/Desktop/RL_framework/train.py", line 87, in <module>
    main()
  File "C:/Users/Artur/Desktop/RL_framework/train.py", line 77, in main
    duration = cf.training(episodes, env, agent, filename, topology=topology, multi_processing=multi_processing, CPUs_used=CPUs_used)
  File "C:\Users\Artur\Desktop\RL_framework\help_functions\custom_functions.py", line 166, in training
    save_interval = parallel_training(range(episodes), env, agent, log_data_qvalues, log_data, filename, CPUs_used)
  File "C:\Users\Artur\Desktop\RL_framework\help_functions\custom_functions.py", line 81, in parallel_training
    mp.Pool().map(func_part, range(env.steplength-1))
  File "C:\Users\Artur\Anaconda\lib\multiprocessing\pool.py", line 268, in map
    return self._map_async(func, iterable, mapstar, chunksize).get()
  File "C:\Users\Artur\Anaconda\lib\multiprocessing\pool.py", line 657, in get
    raise self._value
  File "C:\Users\Artur\Anaconda\lib\multiprocessing\pool.py", line 431, in _handle_tasks
    put(task)
  File "C:\Users\Artur\Anaconda\lib\multiprocessing\connection.py", line 206, in send
    self._send_bytes(_ForkingPickler.dumps(obj))
  File "C:\Users\Artur\Anaconda\lib\multiprocessing\reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
MemoryError

Is there a way to fix this?

I believe your memory fills up because the pools of processes you create inside the loop are never cleaned up after they finish running.

From the documentation:

Warning: multiprocessing.pool objects have internal resources that need to be properly managed (like any other resource) by using the pool as a context manager or by calling close() and terminate() manually. Failure to do this can lead to the process hanging on finalization. Note that it is not correct to rely on the garbage collector to destroy the pool as CPython does not assure that the finalizer of the pool will be called (see object.__del__() for more information).
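
You can see the effect with a toy loop (a sketch; multiprocessing.active_children() is a real stdlib call, but exactly how quickly the abandoned pools pile up depends on when, or whether, the garbage collector finalizes them):

import multiprocessing as mp

def noop(x):
    return x

if __name__ == '__main__':
    for i in range(3):
        pool = mp.Pool()  # spawns a fresh set of worker processes every iteration
        pool.map(noop, range(4))
        # no close()/terminate(): the workers may stay alive until the pool
        # object happens to be finalized, which CPython does not guarantee
        print(f'iteration {i}: {len(mp.active_children())} child processes alive')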

I suggest you refactor your code a little:

# set the CPUs_used to a desired number or None to use all available CPUs
with mp.Pool(processes=CPUs_used) as p:
    p.map(func_part, range(env.steplength-1))
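
If creating a fresh pool every episode turns out to be costly, you could also go one step further: create the pool once, outside the episode loop, and reuse it across episodes (a sketch based on your loop above; func_part is still pickled and sent to the workers on every map call):

# a sketch: one pool for the whole training run, cleaned up when the block exits
with mp.Pool(processes=CPUs_used) as pool:
    for episode in episodes:
        # ... per-episode setup as in your loop ...
        func_part = partial(parallel_pool, episode=episode, episodes=episodes,
                            env=env, agent=agent, state=state,
                            log_data_qvalues=log_data_qvalues,
                            log_data=log_data, steps=steps)
        pool.map(func_part, range(env.steplength - 1))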

Alternatively, you can call .close() and .join() manually, whichever best fits your coding style.
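
That would look something like this (same assumptions as above; close() stops new work from being submitted, and join() waits for the workers to exit):

p = mp.Pool(processes=CPUs_used)
try:
    p.map(func_part, range(env.steplength - 1))
finally:
    p.close()  # no more tasks will be submitted to the pool
    p.join()   # wait for the worker processes to exit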