为什么 Numba 不改进这个递归函数

Why Numba doesn't improve this recursive function

我有一个结构非常简单的 true/false 值数组:

# the real array has hundreds of thousands of items
positions = np.array([True, False, False, False, True, True, True, True, False, False, False], dtype=np.bool)

我想遍历这个数组,输出发生变化的地方(true变成false或者相反)。为此,我汇总了两种不同的方法:

两个版本都给出了我想要的结果,但是 Numba 对一个版本的影响比另一个版本更大。使用 300k 值的虚拟数组,这里是性能结果:

Performance results with array of 300k elements

  • pure Python binary-search runs in 11 ms
  • pure Python iterative-search runs in 1.1 s (100x slower than binary-search)
  • Numba binary-search runs in 5 ms (2 times faster than pure Python equivalent)
  • Numba iterative-search runs in 900 µs (1,200 times faster than pure Python equivalent)

因此,当使用 Numba 时,binary_search 比 iterative_search 慢 5 倍,而理论上它应该快 100 倍(预计 运行 在 9 微秒内如果它被正确加速)。

如何使 Numba 像加速迭代搜索一样加速二进制搜索?

这两种方法的代码(连同示例 position 数组)在这个 public 要点上可用:https://gist.github.com/JivanRoquet/d58989aa0a4598e060ec2c705b9f3d8f

注意:Numba 在对象模式下不是 运行ning binary_search(),因为当提到 nopython=True 时,它不会抱怨并愉快地编译函数。

要点是,只有使用 Python 机制的逻辑部分可以加速——通过用一些等效的 C 逻辑替换它,从而消除 [=17] 的大部分复杂性(和灵活性) =] 运行时(我想这就是 Numba 所做的)。

NumPy 操作中的所有繁重工作已经在 C 中实现并且非常简单(因为 NumPy 数组是包含常规 C 类型的连续内存块)所以 Numba 只能剥离与 Python 机器接口的部分.

您的 "binary search" 算法做了更多的工作,并且在处理它时更多地使用了 NumPy 的矢量运算,因此可以通过这种方式加速的更少。

主要问题是您没有进行同类比较。 您提供的不是同一算法的迭代和递归版本。 您提出了两种根本不同的算法,恰好是 recursive/iterative.

特别是您在递归方法中更多地使用 NumPy 内置函数,所以难怪这两种方法存在如此惊人的差异。当您避免使用 NumPy 内置函数时,Numba JITting 更有效也就不足为奇了。 最终,递归算法似乎效率较低,因为在迭代方法避免的 np.all()np.any() 调用中存在一些 hidden 嵌套循环,因此即使如果您要用纯 Python 编写所有代码以便更有效地使用 Numba 进行加速,则递归方法会更慢。

一般来说,迭代方法比递归 等效方法 更快,因为它们避免了函数调用开销(与纯 Python 相比,JIT 加速函数的开销最小那些)。 所以我建议不要尝试以递归形式重写算法,只是发现它更慢。


编辑

假设一个简单的 np.diff() 就可以解决问题,Numba 仍然非常有用:

import numpy as np
import numba as nb


@nb.jit
def diff(arr):
    n = arr.size
    result = np.empty(n - 1, dtype=arr.dtype)
    for i in range(n - 1):
        result[i] = arr[i + 1] ^ arr[i]
    return result


positions = np.random.randint(0, 2, size=300_000, dtype=bool)
print(np.allclose(np.diff(positions), diff(positions)))
# True


%timeit np.diff(positions)
# 1000 loops, best of 3: 603 µs per loop
%timeit diff(positions)
# 10000 loops, best of 3: 43.3 µs per loop

使用 Numba 方法快 13 倍(当然,在此测试中,里程可能会有所不同)。

使用np.diff就可以找到数值变化的位置,不需要运行更复杂的算法,或者使用numba:

positions = np.array([True, False, False, False, True, True, True, True, False, False, False], dtype=np.bool)
dpos = np.diff(positions)
# array([ True, False, False,  True, False, False, False,  True, False, False])

这行得通,因为 False - True == -1np.bool(-1) == True

它在我的电池供电(= 由于节能模式而节流)和几年前的旧笔记本电脑上表现很好:

In [52]: positions = np.random.randint(0, 2, size=300_000, dtype=bool)          

In [53]: %timeit np.diff(positions)                                             
633 µs ± 4.09 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

我想在 numba 中编写自己的差异应该会产生类似的性能。

编辑:最后一个陈述是错误的,我使用 numba 实现了一个简单的 diff 函数,它比 numpy 快了 10 倍以上(但它显然也有很多功能较少,但应该足以完成此任务):

@numba.njit 
def ndiff(x): 
    s = x.size - 1 
    r = np.empty(s, dtype=x.dtype) 
    for i in range(s): 
        r[i] = x[i+1] - x[i] 
    return r

In [68]: np.all(ndiff(positions) == np.diff(positions))                            
Out[68]: True

In [69]: %timeit ndiff(positions)                                               
46 µs ± 138 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)