检查调用脚本是否使用 "if __name__ == "__main__"(以符合多处理要求)

Check if calling script used "if __name__ == "__main__" (to comply with multiprocessing requirement)

我写了一个包,在其中一个函数中使用了 multiprocessing.Pool

由于这个原因,必须(如“主模块的安全导入”下的here中指定)最外层调用函数可以安全导入例如无需启动新进程。这通常是使用 if __name__ == "__main__": 语句实现的,如上面 link 中明确解释的那样。

我的理解(如有错误请指正)是multiprocessing引入了最外层的调用模块。因此,如果这不是“导入安全的”,这将启动一个新进程,该进程将再次导入最外层的模块,递归地依此类推,直到一切崩溃。

如果最外面的模块在启动 main 函数时不是“导入安全的”,它通常会挂起而不打印任何警告、错误、消息等。

由于使用 if __name__ == "__main__": 通常不是强制性的,并且用户通常并不总是知道包内使用的所有模块,所以我想在我的函数开始时检查用户是否遵守此规定要求,如果没有,提出 warning/error.

这可能吗?我该怎么做?

为了举例说明这一点,请考虑以下示例。

假设我开发了 my_module.py 并且我分享了它 online/in 我的公司。

# my_module.py
from multiprocessing import Pool

def f(x):
    return x*x

def my_function(x_max):
    with Pool(5) as p:
        print(p.map(f, range(x_max)))

如果用户(不是我)将自己的脚本写成:

# Script_of_a_good_user.py
from my_module import my_function

if __name__ == '__main__':
    my_function(10)

一切正常,输出按预期打印。

然而,如果粗心的用户将他的脚本写成:

# Script_of_a_careless_user.py
from my_module import my_function

my_function(10)

然后进程挂起,没有输出产生,但没有向用户发出错误消息或警告。

my_functionBEFORE 打开 Pool 之前有没有办法检查用户是否在其脚本中使用了 if __name__ == '__main__': 条件并且,如果没有,提出一个错误说它应该这样做?

注意:我认为这种行为只是 Windows 机器上的一个问题,其中 fork() 不可用,如 here 所述。

我认为您能做的最好的事情就是尝试执行代码并在失败时提供提示。像这样:

# my_module.py
import sys  # Use sys.stderr to print to the error stream.
from multiprocessing import Pool

def f(x):
    return x*x

def my_function(x_max):
    try:    
        with Pool(5) as p:
            print(p.map(f, range(x_max)))
    except RuntimeError as e:
        print("Whoops! Did you perhaps forget to put the code in `if __name__ == '__main__'`?", file=sys.stderr)
        raise e

这当然不是 100% 的解决方案,因为代码抛出 RuntimeError.

可能还有其他几个原因

如果它不引发 RuntimeError,一个丑陋的解决方案是明确强制用户输入模块名称。

# my_module.py
from multiprocessing import Pool

def f(x):
    return x*x

def my_function(x_max, module):
    """`module` must be set to `__name__`, for example `my_function(10, __name__)`"""
    if module == '__main__':    
        with Pool(5) as p:
            print(p.map(f, range(x_max)))
    else:
        raise Exception("This can only be called from the main module.")

并将其命名为:

# Script_of_a_careless_user.py
from my_module import my_function
my_function(10, __name__)

这对用户来说非常明确。

您可以使用 traceback 模块来检查堆栈并找到您要查找的信息。解析顶层框架,寻找代码中的主屏蔽

我假设当您使用 .pyc 文件并且无法访问源代码时这会失败,但我假设开发人员会在执行任何操作之前先以常规方式测试他们的代码一种包装,所以我认为可以安全地假设您的错误消息会在需要时打印出来。

带有详细消息的版本:

import traceback
import re

def called_from_main_shield():
    print("Calling introspect")
    tb = traceback.extract_stack()
    print(traceback.format_stack())
    print(f"line={tb[0].line} lineno={tb[0].lineno} file={tb[0].filename}")
    try:
        with open(tb[0].filename, mode="rt") as f:
            found_main_shield = False
            for i, line in enumerate(f):
                if re.search(r"__name__.*['\"]__main__['\"]", line):
                    found_main_shield = True
                if i == tb[0].lineno:
                    print(f"found_main_shield={found_main_shield}")
                    return found_main_shield
    except:
        print("Coulnd't inspect stack, let's pretend the code is OK...")
        return True

print(called_from_main_shield())

if __name__ == "__main__":
    print(called_from_main_shield())

在输出中,我们看到第一个调用 called_from_main_shield returns False,而第二个 returns True:

$ python3 introspect.py
Calling introspect
['  File "introspect.py", line 24, in <module>\n    print(called_from_main_shield())\n', '  File "introspect.py", lin
e 7, in called_from_main_shield\n    print(traceback.format_stack())\n']
line=print(called_from_main_shield()) lineno=24 file=introspect.py
found_main_shield=False
False
Calling introspect
['  File "introspect.py", line 27, in <module>\n    print(called_from_main_shield())\n', '  File "introspect.py", lin
e 7, in called_from_main_shield\n    print(traceback.format_stack())\n']
line=print(called_from_main_shield()) lineno=27 file=introspect.py
found_main_shield=True
True

更简洁的版本:

def called_from_main_shield():
    tb = traceback.extract_stack()
    try:
        with open(tb[0].filename, mode="rt") as f:
            found_main_shield = False
            for i, line in enumerate(f):
                if re.search(r"__name__.*['\"]__main__['\"]", line):
                    found_main_shield = True
                if i == tb[0].lineno:
                    return found_main_shield
    except:
        return True

现在,像我这样使用re.search()虽然不是很优雅,但应该足够可靠了。警告:因为我在我的主脚本中定义了这个函数,所以我必须确保该行不匹配自身,这就是为什么我使用 ['\"] 来匹配引号而不是使用像 [=20= 这样更简单的 RE ].无论您选择什么,只要确保它足够灵活以匹配该代码的所有合法变体,这就是我的目标。