如何通过查看 AST 来判断来自生成器的正常 Python 函数?
How to tell normal Python function from generator by looking at the AST?
我需要检测 ast.FunctionDef
in Python 3 AST 是普通函数定义还是生成器定义。
我需要遍历 body 并寻找 ast.Yield
-s 还是有更简单的方法?
有一种偷偷摸摸的方法是用 compile
编译 AST 实例。代码对象附加了几个标志,其中之一是 'GENERATOR'
,您可以使用它们来区分它们。当然,这取决于某些编译标志,因此它不能真正跨 CPython 版本或实现移植
例如,使用非生成器函数:
func = """
def spam_func():
print("spam")
"""
# Create the AST instance for it
m = ast.parse(func)
# get the function code
# co_consts[0] is used because `m` is
# compiled as a module and we want the
# function object
fc = compile(m, '', 'exec').co_consts[0]
# get a string of the flags and
# check for membership
from dis import pretty_flags
'GENERATOR' in pretty_flags(fc.co_flags) # False
同样,对于 spam_gen
生成器,您会得到:
gen = """
def spam_gen():
yield "spammy"
"""
m = ast.parse(gen)
gc = compile(m, '', 'exec').co_consts[0]
'GENERATOR' in pretty_flags(gc.co_flags) # True
虽然这可能比您需要的更隐蔽,但遍历 AST 是另一个可行的选择,它可能更易于理解和移植。
如果您有一个函数对象而不是 AST,您始终可以使用 func.__code__.co_flags
:
执行相同的检查
def spam_gen():
yield "spammy"
from dis import pretty_flags
print(pretty_flags(spam_gen.__code__.co_flags))
# 'OPTIMIZED, NEWLOCALS, GENERATOR, NOFREE'
遍历 AST 比看起来更难——使用编译器可能是可行的方法。下面是一个示例,说明为什么寻找 Yield 节点并不像听起来那么简单。
>>> s1 = 'def f():\n yield'
>>> any(isinstance(node, ast.Yield) for node in ast.walk(ast.parse(s1)))
True
>>> dis.pretty_flags(compile(s1, '', 'exec').co_consts[0].co_flags)
'OPTIMIZED, NEWLOCALS, GENERATOR, NOFREE'
>>> s2 = 'def f():\n def g():\n yield'
>>> any(isinstance(node, ast.Yield) for node in ast.walk(ast.parse(s2)))
True
>>> dis.pretty_flags(compile(s2, '', 'exec').co_consts[0].co_flags)
'OPTIMIZED, NEWLOCALS, NOFREE'
AST 方法可能需要使用 NodeVisitor 来排除函数和 lambda 主体。
我需要检测 ast.FunctionDef
in Python 3 AST 是普通函数定义还是生成器定义。
我需要遍历 body 并寻找 ast.Yield
-s 还是有更简单的方法?
有一种偷偷摸摸的方法是用 compile
编译 AST 实例。代码对象附加了几个标志,其中之一是 'GENERATOR'
,您可以使用它们来区分它们。当然,这取决于某些编译标志,因此它不能真正跨 CPython 版本或实现移植
例如,使用非生成器函数:
func = """
def spam_func():
print("spam")
"""
# Create the AST instance for it
m = ast.parse(func)
# get the function code
# co_consts[0] is used because `m` is
# compiled as a module and we want the
# function object
fc = compile(m, '', 'exec').co_consts[0]
# get a string of the flags and
# check for membership
from dis import pretty_flags
'GENERATOR' in pretty_flags(fc.co_flags) # False
同样,对于 spam_gen
生成器,您会得到:
gen = """
def spam_gen():
yield "spammy"
"""
m = ast.parse(gen)
gc = compile(m, '', 'exec').co_consts[0]
'GENERATOR' in pretty_flags(gc.co_flags) # True
虽然这可能比您需要的更隐蔽,但遍历 AST 是另一个可行的选择,它可能更易于理解和移植。
如果您有一个函数对象而不是 AST,您始终可以使用 func.__code__.co_flags
:
def spam_gen():
yield "spammy"
from dis import pretty_flags
print(pretty_flags(spam_gen.__code__.co_flags))
# 'OPTIMIZED, NEWLOCALS, GENERATOR, NOFREE'
遍历 AST 比看起来更难——使用编译器可能是可行的方法。下面是一个示例,说明为什么寻找 Yield 节点并不像听起来那么简单。
>>> s1 = 'def f():\n yield'
>>> any(isinstance(node, ast.Yield) for node in ast.walk(ast.parse(s1)))
True
>>> dis.pretty_flags(compile(s1, '', 'exec').co_consts[0].co_flags)
'OPTIMIZED, NEWLOCALS, GENERATOR, NOFREE'
>>> s2 = 'def f():\n def g():\n yield'
>>> any(isinstance(node, ast.Yield) for node in ast.walk(ast.parse(s2)))
True
>>> dis.pretty_flags(compile(s2, '', 'exec').co_consts[0].co_flags)
'OPTIMIZED, NEWLOCALS, NOFREE'
AST 方法可能需要使用 NodeVisitor 来排除函数和 lambda 主体。