Python: 如何用 ast 重载运算符

Python: How to overload operator with ast

我的问题是重新定义 +-operator evaluating expression with ast. 我有一个表达式列表,使用 eval():

很容易解决
>>> expr = '1+2*3**4/5'
>>> print(eval(expr))
33.4

但我喜欢像这样重新定义列表和字典的 +-运算符(加法):

expr = '[1,2,3]+[4,5,6]'

eval 的常规结果是

[1, 2, 3, 4, 5, 6]

但我想要

[5, 7, 9]

就像在 R 语言中一样。

同样适用于这样的词典:

expr = "{'a':1, 'b':2} + {'a':3, 'b':4}"

我想要

{'a': 4, 'b': 6}

简而言之,我想取代普通的 add 函数,即当操作数是 list 或 dict 时正确的操作。

我尝试使用 astNodeTransformer 但没有成功。有人可以帮助我吗?

制作您自己的列表class并在其上定义加法运算符:

class MyKindOfList(list):
    def __add__(self, other):
        return MyKindOfList(a + b for a, b in zip(self, other))

那么你可以这样做:

x = MyKindOfList([1, 2, 3])
y = MyKindOfList([4, 5, 6])

print (x + y)  # prints [5, 7, 9]

即使使用 ast 模块,也不能重载内置 类 的 __add__ 方法(如 listdict)。但是,您 可以 ,将所有添加项 x + y 重写为 your_custom_addition_function(x, y).

等函数调用

本质上,这是一个 3 步过程:

  1. ast.parse解析输入表达式。
  2. 使用 NodeTransformer 重写对函数调用的所有添加。
  3. 解析你自定义加法函数的源码,添加到步骤1得到的抽象语法树中

代码

import ast


def overload_add(syntax_tree):
    # rewrite all additions to calls to our addition function
    class SumTransformer(ast.NodeTransformer):
        def visit_BinOp(self, node):
            lhs = self.visit(node.left)
            rhs = self.visit(node.right)

            if not isinstance(node.op, ast.Add):
                node.left = lhs
                node.right = rhs
                return node

            name = ast.Name('__custom_add', ast.Load())
            args = [lhs, rhs]
            kwargs = []
            return ast.Call(name, args, kwargs)

    syntax_tree = SumTransformer().visit(syntax_tree)
    syntax_tree = ast.fix_missing_locations(syntax_tree)

    # inject the custom addition function into the sytnax tree
    code = '''
def __custom_add(lhs, rhs):
    if isinstance(lhs, list) and isinstance(rhs, list):
        return [__custom_add(l, r) for l, r in zip(lhs, rhs)]

    if isinstance(lhs, dict) and isinstance(rhs, dict):
        keys = lhs.keys() | rhs.keys()
        return {key: __custom_add(lhs.get(key, 0), rhs.get(key, 0)) for key in keys}

    return lhs + rhs
    '''
    add_func = ast.parse(code).body[0]
    syntax_tree.body.insert(0, add_func)

    return syntax_tree

code = '''
print(1 + 2)
print([1, 2] + [3, 4])
print({'a': 1} + {'a': -2})
'''
syntax_tree = ast.parse(code)
syntax_tree = overload_add(syntax_tree)
codeobj = compile(syntax_tree, 'foo.py', 'exec')
exec(codeobj)

# output:
# 3
# [4, 6]
# {'a': -1}

注意事项

  • 添加函数将被添加到名称为 __custom_add 的全局作用域 - 它可以像任何其他全局函数一样访问,并且可以被覆盖、隐藏、删除或以其他方式篡改。

根据 Aran-Fey 的建议并阅读 this link 的内容,我写了一个更具可读性的代码来解决问题

import ast
from itertools import zip_longest

def __custom_add(lhs, rhs):
    if isinstance(lhs,list) and isinstance(rhs, list):
        return [__custom_add(l, r) for l, r in zip_longest(lhs, rhs, fillvalue=0)]

    if isinstance(lhs, dict) and isinstance(rhs, dict):
        keys = lhs.keys() | rhs.keys()
        return {key: __custom_add(lhs.get(key,0), rhs.get(key,0)) for key in keys}

    return lhs + rhs

class SumTransformer(ast.NodeTransformer):

    def visit_BinOp(self, node):
        if isinstance(node.op, ast.Add):
            new_node = ast.Call(func=ast.Name(id='__custom_add', ctx=ast.Load()),
                            args=[node.left, node.right],
                            keywords = [],
                            starargs = None, kwargs= None
                            )
            ast.copy_location(new_node, node)
            ast.fix_missing_locations(new_node)
            return new_node

        return node

expr = [
    '(2 + 3 * 4)/2',
    '[1, 2] + [3, 4]',
    "{'a': 1} + {'a': -2}"
    ]


for e in expr:
    syntax_tree = ast.parse(e, mode='eval')
    syntax_tree = SumTransformer().visit(syntax_tree)
    res = eval(compile(syntax_tree, '<ast>', 'eval'))
    print(res)

# results

# 7.0
# [4, 6]
# {'a': -1}

感谢所有帮助过我的人