为什么 Numba 不改进这个递归函数
Why Numba doesn't improve this recursive function
我有一个结构非常简单的 true/false 值数组:
# the real array has hundreds of thousands of items
positions = np.array([True, False, False, False, True, True, True, True, False, False, False], dtype=np.bool)
我想遍历这个数组,输出发生变化的地方(true变成false或者相反)。为此,我汇总了两种不同的方法:
- 递归二分查找(看所有值是否相同,如果不相同,一分为二,再递归)
- 纯迭代搜索(遍历所有元素并与 previous/next 比较)
两个版本都给出了我想要的结果,但是 Numba 对一个版本的影响比另一个版本更大。使用 300k 值的虚拟数组,这里是性能结果:
Performance results with array of 300k elements
- pure Python binary-search runs in 11 ms
- pure Python iterative-search runs in 1.1 s (100x slower than binary-search)
- Numba binary-search runs in 5 ms (2 times faster than pure Python equivalent)
- Numba iterative-search runs in 900 µs (1,200 times faster than pure Python equivalent)
因此,当使用 Numba 时,binary_search 比 iterative_search 慢 5 倍,而理论上它应该快 100 倍(预计 运行 在 9 微秒内如果它被正确加速)。
如何使 Numba 像加速迭代搜索一样加速二进制搜索?
这两种方法的代码(连同示例 position
数组)在这个 public 要点上可用:https://gist.github.com/JivanRoquet/d58989aa0a4598e060ec2c705b9f3d8f
注意:Numba 在对象模式下不是 运行ning binary_search()
,因为当提到 nopython=True
时,它不会抱怨并愉快地编译函数。
要点是,只有使用 Python 机制的逻辑部分可以加速——通过用一些等效的 C 逻辑替换它,从而消除 [=17] 的大部分复杂性(和灵活性) =] 运行时(我想这就是 Numba 所做的)。
NumPy 操作中的所有繁重工作已经在 C 中实现并且非常简单(因为 NumPy 数组是包含常规 C 类型的连续内存块)所以 Numba 只能剥离与 Python 机器接口的部分.
您的 "binary search" 算法做了更多的工作,并且在处理它时更多地使用了 NumPy 的矢量运算,因此可以通过这种方式加速的更少。
主要问题是您没有进行同类比较。
您提供的不是同一算法的迭代和递归版本。
您提出了两种根本不同的算法,恰好是 recursive/iterative.
特别是您在递归方法中更多地使用 NumPy 内置函数,所以难怪这两种方法存在如此惊人的差异。当您避免使用 NumPy 内置函数时,Numba JITting 更有效也就不足为奇了。
最终,递归算法似乎效率较低,因为在迭代方法避免的 np.all()
和 np.any()
调用中存在一些 hidden 嵌套循环,因此即使如果您要用纯 Python 编写所有代码以便更有效地使用 Numba 进行加速,则递归方法会更慢。
一般来说,迭代方法比递归 等效方法 更快,因为它们避免了函数调用开销(与纯 Python 相比,JIT 加速函数的开销最小那些)。
所以我建议不要尝试以递归形式重写算法,只是发现它更慢。
编辑
假设一个简单的 np.diff()
就可以解决问题,Numba 仍然非常有用:
import numpy as np
import numba as nb
@nb.jit
def diff(arr):
n = arr.size
result = np.empty(n - 1, dtype=arr.dtype)
for i in range(n - 1):
result[i] = arr[i + 1] ^ arr[i]
return result
positions = np.random.randint(0, 2, size=300_000, dtype=bool)
print(np.allclose(np.diff(positions), diff(positions)))
# True
%timeit np.diff(positions)
# 1000 loops, best of 3: 603 µs per loop
%timeit diff(positions)
# 10000 loops, best of 3: 43.3 µs per loop
使用 Numba 方法快 13 倍(当然,在此测试中,里程可能会有所不同)。
使用np.diff
就可以找到数值变化的位置,不需要运行更复杂的算法,或者使用numba
:
positions = np.array([True, False, False, False, True, True, True, True, False, False, False], dtype=np.bool)
dpos = np.diff(positions)
# array([ True, False, False, True, False, False, False, True, False, False])
这行得通,因为 False - True == -1
和 np.bool(-1) == True
。
它在我的电池供电(= 由于节能模式而节流)和几年前的旧笔记本电脑上表现很好:
In [52]: positions = np.random.randint(0, 2, size=300_000, dtype=bool)
In [53]: %timeit np.diff(positions)
633 µs ± 4.09 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
我想在 numba
中编写自己的差异应该会产生类似的性能。
编辑:最后一个陈述是错误的,我使用 numba
实现了一个简单的 diff 函数,它比 numpy
快了 10 倍以上(但它显然也有很多功能较少,但应该足以完成此任务):
@numba.njit
def ndiff(x):
s = x.size - 1
r = np.empty(s, dtype=x.dtype)
for i in range(s):
r[i] = x[i+1] - x[i]
return r
In [68]: np.all(ndiff(positions) == np.diff(positions))
Out[68]: True
In [69]: %timeit ndiff(positions)
46 µs ± 138 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
我有一个结构非常简单的 true/false 值数组:
# the real array has hundreds of thousands of items
positions = np.array([True, False, False, False, True, True, True, True, False, False, False], dtype=np.bool)
我想遍历这个数组,输出发生变化的地方(true变成false或者相反)。为此,我汇总了两种不同的方法:
- 递归二分查找(看所有值是否相同,如果不相同,一分为二,再递归)
- 纯迭代搜索(遍历所有元素并与 previous/next 比较)
两个版本都给出了我想要的结果,但是 Numba 对一个版本的影响比另一个版本更大。使用 300k 值的虚拟数组,这里是性能结果:
Performance results with array of 300k elements
- pure Python binary-search runs in 11 ms
- pure Python iterative-search runs in 1.1 s (100x slower than binary-search)
- Numba binary-search runs in 5 ms (2 times faster than pure Python equivalent)
- Numba iterative-search runs in 900 µs (1,200 times faster than pure Python equivalent)
因此,当使用 Numba 时,binary_search 比 iterative_search 慢 5 倍,而理论上它应该快 100 倍(预计 运行 在 9 微秒内如果它被正确加速)。
如何使 Numba 像加速迭代搜索一样加速二进制搜索?
这两种方法的代码(连同示例 position
数组)在这个 public 要点上可用:https://gist.github.com/JivanRoquet/d58989aa0a4598e060ec2c705b9f3d8f
注意:Numba 在对象模式下不是 运行ning binary_search()
,因为当提到 nopython=True
时,它不会抱怨并愉快地编译函数。
要点是,只有使用 Python 机制的逻辑部分可以加速——通过用一些等效的 C 逻辑替换它,从而消除 [=17] 的大部分复杂性(和灵活性) =] 运行时(我想这就是 Numba 所做的)。
NumPy 操作中的所有繁重工作已经在 C 中实现并且非常简单(因为 NumPy 数组是包含常规 C 类型的连续内存块)所以 Numba 只能剥离与 Python 机器接口的部分.
您的 "binary search" 算法做了更多的工作,并且在处理它时更多地使用了 NumPy 的矢量运算,因此可以通过这种方式加速的更少。
主要问题是您没有进行同类比较。 您提供的不是同一算法的迭代和递归版本。 您提出了两种根本不同的算法,恰好是 recursive/iterative.
特别是您在递归方法中更多地使用 NumPy 内置函数,所以难怪这两种方法存在如此惊人的差异。当您避免使用 NumPy 内置函数时,Numba JITting 更有效也就不足为奇了。
最终,递归算法似乎效率较低,因为在迭代方法避免的 np.all()
和 np.any()
调用中存在一些 hidden 嵌套循环,因此即使如果您要用纯 Python 编写所有代码以便更有效地使用 Numba 进行加速,则递归方法会更慢。
一般来说,迭代方法比递归 等效方法 更快,因为它们避免了函数调用开销(与纯 Python 相比,JIT 加速函数的开销最小那些)。 所以我建议不要尝试以递归形式重写算法,只是发现它更慢。
编辑
假设一个简单的 np.diff()
就可以解决问题,Numba 仍然非常有用:
import numpy as np
import numba as nb
@nb.jit
def diff(arr):
n = arr.size
result = np.empty(n - 1, dtype=arr.dtype)
for i in range(n - 1):
result[i] = arr[i + 1] ^ arr[i]
return result
positions = np.random.randint(0, 2, size=300_000, dtype=bool)
print(np.allclose(np.diff(positions), diff(positions)))
# True
%timeit np.diff(positions)
# 1000 loops, best of 3: 603 µs per loop
%timeit diff(positions)
# 10000 loops, best of 3: 43.3 µs per loop
使用 Numba 方法快 13 倍(当然,在此测试中,里程可能会有所不同)。
使用np.diff
就可以找到数值变化的位置,不需要运行更复杂的算法,或者使用numba
:
positions = np.array([True, False, False, False, True, True, True, True, False, False, False], dtype=np.bool)
dpos = np.diff(positions)
# array([ True, False, False, True, False, False, False, True, False, False])
这行得通,因为 False - True == -1
和 np.bool(-1) == True
。
它在我的电池供电(= 由于节能模式而节流)和几年前的旧笔记本电脑上表现很好:
In [52]: positions = np.random.randint(0, 2, size=300_000, dtype=bool)
In [53]: %timeit np.diff(positions)
633 µs ± 4.09 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
我想在 numba
中编写自己的差异应该会产生类似的性能。
编辑:最后一个陈述是错误的,我使用 numba
实现了一个简单的 diff 函数,它比 numpy
快了 10 倍以上(但它显然也有很多功能较少,但应该足以完成此任务):
@numba.njit
def ndiff(x):
s = x.size - 1
r = np.empty(s, dtype=x.dtype)
for i in range(s):
r[i] = x[i+1] - x[i]
return r
In [68]: np.all(ndiff(positions) == np.diff(positions))
Out[68]: True
In [69]: %timeit ndiff(positions)
46 µs ± 138 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)