如何在多个函数中最佳地使用 numba?

How to use numba optimally accross multiple functions?

假设我有两个函数

def my_sub1(a):
    return a + 2

def my_main(a):
    a += 1
    b = mysub1(a)
    return b

我想使用像 Numba 这样的即时编译器来加快它们的速度。这会比我将所有内容重构为一个函数要慢吗

def my_main(a):
    a += 1
    b = a + 2
    return b

因为 Numba 可以在第二种情况下进行更深入的优化?当然我的实际功能要复杂得多。

如果 my_sub1 函数被调用不止一次,整个情况也会变得更加困难 - 重构(和维护将成为一种拖累)。 Numba 如何解决这个问题?

Tl;dr: Numba 能够内联其他 Numba 函数,并且它仅在使用本机类型时执行相对高级的过程间优化(在这种情况下,两个函数都同样快),但不适用于 Numpy 数组。


您可以分析 Numba 生成的汇编代码,以检查这两个函数是如何优化的。这是一个整数的例子:

import numba as nb

@nb.njit('int64(int64)')
def my_sub1(a):
    return a + 2

@nb.njit('int64(int64)')
def my_main(a):
    a += 1
    b = my_sub1(a)
    return b

open('my_sub1.asm', 'w').write(list(my_sub1.inspect_asm().values())[0])
open('my_main.asm', 'w').write(list(my_main.inspect_asm().values())[0])

这会产生两个程序集文件。如果比较这两个文件,您会发现唯一的实际区别(除了不同的名称)是第一个执行 addq , %rdx 而第二个执行 addq , %rdx。这意味着 Numba 成功地在 my_main 中内联了对 my_sub1 的调用并合并了求和。下面是汇编代码的重要部分:

_ZN8__main__12my_sub113Ex:
    addq    , %rdx
    movq    %rdx, (%rdi)
    xorl    %eax, %eax
    retq

_ZN8__main__12my_main14Ex:
    addq    , %rdx
    movq    %rdx, (%rdi)
    xorl    %eax, %eax
    retq

对于 64 位浮点数,只要使用 fastmath=True 结果是相同的,因为浮点加法不是关联的。

关于 Numpy 数组,生成的代码非常庞大,很难比较这两个代码。但是,my_sub1 函数似乎不再内联,Numba 似乎无法合并 Numpy 计算(生成的代码中存在用于两个数组求和的两个不同的矢量化循环)。请注意,这与许多 C/C++ 编译器所做的类似。因此,最好在代码的性能关键部分自行内联函数。