Python: 如何用 ast 重载运算符
Python: How to overload operator with ast
我的问题是重新定义 +
-operator evaluating expression with ast.
我有一个表达式列表,使用 eval():
很容易解决
>>> expr = '1+2*3**4/5'
>>> print(eval(expr))
33.4
但我喜欢像这样重新定义列表和字典的 +
-运算符(加法):
expr = '[1,2,3]+[4,5,6]'
eval 的常规结果是
[1, 2, 3, 4, 5, 6]
但我想要
[5, 7, 9]
就像在 R 语言中一样。
同样适用于这样的词典:
expr = "{'a':1, 'b':2} + {'a':3, 'b':4}"
我想要
{'a': 4, 'b': 6}
简而言之,我想取代普通的 add 函数,即当操作数是 list 或 dict 时正确的操作。
我尝试使用 ast
和 NodeTransformer
但没有成功。有人可以帮助我吗?
制作您自己的列表class并在其上定义加法运算符:
class MyKindOfList(list):
def __add__(self, other):
return MyKindOfList(a + b for a, b in zip(self, other))
那么你可以这样做:
x = MyKindOfList([1, 2, 3])
y = MyKindOfList([4, 5, 6])
print (x + y) # prints [5, 7, 9]
即使使用 ast
模块,也不能重载内置 类 的 __add__
方法(如 list
和 dict
)。但是,您 可以 ,将所有添加项 x + y
重写为 your_custom_addition_function(x, y)
.
等函数调用
本质上,这是一个 3 步过程:
- 用
ast.parse
解析输入表达式。
- 使用
NodeTransformer
重写对函数调用的所有添加。
- 解析你自定义加法函数的源码,添加到步骤1得到的抽象语法树中
代码
import ast
def overload_add(syntax_tree):
# rewrite all additions to calls to our addition function
class SumTransformer(ast.NodeTransformer):
def visit_BinOp(self, node):
lhs = self.visit(node.left)
rhs = self.visit(node.right)
if not isinstance(node.op, ast.Add):
node.left = lhs
node.right = rhs
return node
name = ast.Name('__custom_add', ast.Load())
args = [lhs, rhs]
kwargs = []
return ast.Call(name, args, kwargs)
syntax_tree = SumTransformer().visit(syntax_tree)
syntax_tree = ast.fix_missing_locations(syntax_tree)
# inject the custom addition function into the sytnax tree
code = '''
def __custom_add(lhs, rhs):
if isinstance(lhs, list) and isinstance(rhs, list):
return [__custom_add(l, r) for l, r in zip(lhs, rhs)]
if isinstance(lhs, dict) and isinstance(rhs, dict):
keys = lhs.keys() | rhs.keys()
return {key: __custom_add(lhs.get(key, 0), rhs.get(key, 0)) for key in keys}
return lhs + rhs
'''
add_func = ast.parse(code).body[0]
syntax_tree.body.insert(0, add_func)
return syntax_tree
code = '''
print(1 + 2)
print([1, 2] + [3, 4])
print({'a': 1} + {'a': -2})
'''
syntax_tree = ast.parse(code)
syntax_tree = overload_add(syntax_tree)
codeobj = compile(syntax_tree, 'foo.py', 'exec')
exec(codeobj)
# output:
# 3
# [4, 6]
# {'a': -1}
注意事项
- 添加函数将被添加到名称为
__custom_add
的全局作用域 - 它可以像任何其他全局函数一样访问,并且可以被覆盖、隐藏、删除或以其他方式篡改。
根据 Aran-Fey 的建议并阅读 this link 的内容,我写了一个更具可读性的代码来解决问题
import ast
from itertools import zip_longest
def __custom_add(lhs, rhs):
if isinstance(lhs,list) and isinstance(rhs, list):
return [__custom_add(l, r) for l, r in zip_longest(lhs, rhs, fillvalue=0)]
if isinstance(lhs, dict) and isinstance(rhs, dict):
keys = lhs.keys() | rhs.keys()
return {key: __custom_add(lhs.get(key,0), rhs.get(key,0)) for key in keys}
return lhs + rhs
class SumTransformer(ast.NodeTransformer):
def visit_BinOp(self, node):
if isinstance(node.op, ast.Add):
new_node = ast.Call(func=ast.Name(id='__custom_add', ctx=ast.Load()),
args=[node.left, node.right],
keywords = [],
starargs = None, kwargs= None
)
ast.copy_location(new_node, node)
ast.fix_missing_locations(new_node)
return new_node
return node
expr = [
'(2 + 3 * 4)/2',
'[1, 2] + [3, 4]',
"{'a': 1} + {'a': -2}"
]
for e in expr:
syntax_tree = ast.parse(e, mode='eval')
syntax_tree = SumTransformer().visit(syntax_tree)
res = eval(compile(syntax_tree, '<ast>', 'eval'))
print(res)
# results
# 7.0
# [4, 6]
# {'a': -1}
感谢所有帮助过我的人
我的问题是重新定义 +
-operator evaluating expression with ast.
我有一个表达式列表,使用 eval():
>>> expr = '1+2*3**4/5'
>>> print(eval(expr))
33.4
但我喜欢像这样重新定义列表和字典的 +
-运算符(加法):
expr = '[1,2,3]+[4,5,6]'
eval 的常规结果是
[1, 2, 3, 4, 5, 6]
但我想要
[5, 7, 9]
就像在 R 语言中一样。
同样适用于这样的词典:
expr = "{'a':1, 'b':2} + {'a':3, 'b':4}"
我想要
{'a': 4, 'b': 6}
简而言之,我想取代普通的 add 函数,即当操作数是 list 或 dict 时正确的操作。
我尝试使用 ast
和 NodeTransformer
但没有成功。有人可以帮助我吗?
制作您自己的列表class并在其上定义加法运算符:
class MyKindOfList(list):
def __add__(self, other):
return MyKindOfList(a + b for a, b in zip(self, other))
那么你可以这样做:
x = MyKindOfList([1, 2, 3])
y = MyKindOfList([4, 5, 6])
print (x + y) # prints [5, 7, 9]
即使使用 ast
模块,也不能重载内置 类 的 __add__
方法(如 list
和 dict
)。但是,您 可以 ,将所有添加项 x + y
重写为 your_custom_addition_function(x, y)
.
本质上,这是一个 3 步过程:
- 用
ast.parse
解析输入表达式。 - 使用
NodeTransformer
重写对函数调用的所有添加。 - 解析你自定义加法函数的源码,添加到步骤1得到的抽象语法树中
代码
import ast
def overload_add(syntax_tree):
# rewrite all additions to calls to our addition function
class SumTransformer(ast.NodeTransformer):
def visit_BinOp(self, node):
lhs = self.visit(node.left)
rhs = self.visit(node.right)
if not isinstance(node.op, ast.Add):
node.left = lhs
node.right = rhs
return node
name = ast.Name('__custom_add', ast.Load())
args = [lhs, rhs]
kwargs = []
return ast.Call(name, args, kwargs)
syntax_tree = SumTransformer().visit(syntax_tree)
syntax_tree = ast.fix_missing_locations(syntax_tree)
# inject the custom addition function into the sytnax tree
code = '''
def __custom_add(lhs, rhs):
if isinstance(lhs, list) and isinstance(rhs, list):
return [__custom_add(l, r) for l, r in zip(lhs, rhs)]
if isinstance(lhs, dict) and isinstance(rhs, dict):
keys = lhs.keys() | rhs.keys()
return {key: __custom_add(lhs.get(key, 0), rhs.get(key, 0)) for key in keys}
return lhs + rhs
'''
add_func = ast.parse(code).body[0]
syntax_tree.body.insert(0, add_func)
return syntax_tree
code = '''
print(1 + 2)
print([1, 2] + [3, 4])
print({'a': 1} + {'a': -2})
'''
syntax_tree = ast.parse(code)
syntax_tree = overload_add(syntax_tree)
codeobj = compile(syntax_tree, 'foo.py', 'exec')
exec(codeobj)
# output:
# 3
# [4, 6]
# {'a': -1}
注意事项
- 添加函数将被添加到名称为
__custom_add
的全局作用域 - 它可以像任何其他全局函数一样访问,并且可以被覆盖、隐藏、删除或以其他方式篡改。
根据 Aran-Fey 的建议并阅读 this link 的内容,我写了一个更具可读性的代码来解决问题
import ast
from itertools import zip_longest
def __custom_add(lhs, rhs):
if isinstance(lhs,list) and isinstance(rhs, list):
return [__custom_add(l, r) for l, r in zip_longest(lhs, rhs, fillvalue=0)]
if isinstance(lhs, dict) and isinstance(rhs, dict):
keys = lhs.keys() | rhs.keys()
return {key: __custom_add(lhs.get(key,0), rhs.get(key,0)) for key in keys}
return lhs + rhs
class SumTransformer(ast.NodeTransformer):
def visit_BinOp(self, node):
if isinstance(node.op, ast.Add):
new_node = ast.Call(func=ast.Name(id='__custom_add', ctx=ast.Load()),
args=[node.left, node.right],
keywords = [],
starargs = None, kwargs= None
)
ast.copy_location(new_node, node)
ast.fix_missing_locations(new_node)
return new_node
return node
expr = [
'(2 + 3 * 4)/2',
'[1, 2] + [3, 4]',
"{'a': 1} + {'a': -2}"
]
for e in expr:
syntax_tree = ast.parse(e, mode='eval')
syntax_tree = SumTransformer().visit(syntax_tree)
res = eval(compile(syntax_tree, '<ast>', 'eval'))
print(res)
# results
# 7.0
# [4, 6]
# {'a': -1}
感谢所有帮助过我的人