`multiprocessing.Pool` 在独立函数中

`multiprocessing.Pool` in stand-alone functions

由于我的 last version 没有收到任何回复,请用更具体的问题重新提出这个问题。

我正在尝试制作一个可导入的函数来对具有一系列(长)时间历史的数据帧进行平稳小波变换。实际的理论并不重要(我什至可能没有完全正确地使用它),重要的部分是我将时间历史分解成块并使用 multiprocessing.Pool 将它们提供给多个线程。

import pandas as pd
import numpy as np
import pywt
from multiprocessing import Pool
import functools

def swt_block(arr, level = 8, wvlt = 'haar'):
    block_length = arr.shape[0]
    if block_length == 2**level:
        d = pywt.swt(arr, wvlt, axis = 0)
    elif block_length < 2**level:
        arr_ = np.pad(arr, 
                      ((0, 2**level - block_length), (0,0)), 
                      'constant', constant_values = 0)
        d = pywt.swt(arr_, wvlt, axis = 0)
    else:
        raise ValueError('block of length ' + str(arr.shape[0]) + ' too large for swt of level ' + str(level))
    out = []
    for lvl in d:
        for coeff in lvl:
            out.append(coeff)
    return np.concatenate(out, axis = -1)[:block_length]


def swt(df, wvlt = 'haar', level = 8, processors = 4):
    block_length = 2**level
    with Pool(processors) as p:
        data = p.map(functools.partial(swt_block, level = level, wvlt = wvlt), 
                     [i.values for _, i in df.groupby(np.arange(len(df)) // block_length)])
    data = np.concatenate(data, axis = 0) 
    header = pd.MultiIndex.from_product([list(range(level)),
                                     [0, 1],
                                     df.columns], 
                                     names = ['level', 'coef', 'channel'])
    df_out = pd.DataFrame(data, index = df.index, colummns = header)

    return df_out

我之前已经在独立脚本中完成了此操作,因此如果第二个函数只是包裹在 if __name__ == '__main__': 中的裸代码,代码就可以工作,如果我添加一个类似的块,代码确实可以在脚本中工作脚本的结尾。但是,如果我在解释器中导入甚至只是 运行 以上内容,然后执行

df_swt = swt(df)

事情无限期地挂起。我确定它是 multiprocessing 上的某种护栏,以防止我对线程做一些愚蠢的事情,但我真的不想将这段代码复制到一堆其他脚本中。包括其他标签,以防它们以某种方式成为罪魁祸首。

首先要明确一点,您正在创建多个 进程 而不是 线程。如果您对线程特别感兴趣,请将导入更改为:from multiprocessing.dummy import Pool.

来自multiprocessing introduction:

multiprocessing is a package that supports spawning processes using an API similar to the threading module.

来自multprocessing.dummy section:

multiprocessing.dummy replicates the API of multiprocessing but is no more than a wrapper around the threading module.

现在,我能够重现您的问题(根据您之前的链接问题)并且确实发生了同样的事情。 运行 互动 shell 东西干脆挂了。

然而,有趣的是,运行通过windowscmd,这个错误层出不穷出现在屏幕上:

RuntimeError:
        An attempt has been made to start a new process before the
        current process has finished its bootstrapping phase.

        This probably means that you are not using fork to start your
        child processes and you have forgotten to use the proper idiom
        in the main module:

            if __name__ == '__main__':
                freeze_support()
                ...

        The "freeze_support()" line can be omitted if the program
        is not going to be frozen to produce an executable.

所以,作为一个疯狂的猜测,我添加到 importing 模块:

if __name__ == "__main__":

而且…………成功了!

为了消除疑虑,我将 post 在此提供我使用的确切文件,以便您(希望)重新创建解决方案...

multi.py中:

from multiprocessing import Pool

def __foo(x):
    return x**2

def bar(list_of_inputs):
    with Pool() as p:
        out = p.map(__foo, list_of_inputs)
    print(out)

if __name__ == "__main__":
    bar(list(range(50)))

tests.py中:

from multi import bar

l = list(range(50))

if __name__ == "__main__":
    bar(l)

当 运行 这两个文件中的任何一个时输出(在 shell 中和通过 cmd):

[0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, 144, 169, 196, 225, 256, 289, 324, 361, 400, 441, 484, 529, 576, 625, 676, 729, 784, 841, 900, 961, 1024, 1089, 1156, 1225, 1296, 1369, 1444, 1521, 1600, 1681, 1764, 1849, 1936, 2025, 2116, 2209, 2304, 2401]

更新: 我在文档中找不到任何具体证据来说明 为什么 这个问题会发生,但是,显然它有与创建新流程和 importing of the main module.

有关

正如本答案开头所讨论的,您似乎打算在您的意图中使用线程,但不知道您正在使用进程。如果确实如此,那么使用实际线程将解决您的问题,并且不需要您更改任何内容,除了导入语句(更改为:from multiprocessing.dummy import Pool)。使用线程,您在主模块和导入模块中都没有定义 if __name__ == "__main__": 的限制。所以这应该有效:

multi.py中:

from multiprocessing.dummy import Pool

def __foo(x):
    return x**2

def bar(list_of_inputs):
    with Pool() as p:
        out = p.map(__foo, list_of_inputs)
    print(out)

if __name__ == "__main__":
    bar(list(range(50)))

tests.py中:

from multi import bar

l = list(range(50))

bar(l)

我真的希望这可以帮助您解决问题,如果可以,请告诉我。