在 jitted 函数中两次反转 numpy 数组的视图使函数 运行 更快
Reversing the view of a numpy array twice in a jitted function makes the function run faster
所以我正在测试同一功能的两个版本的速度;一种是两次反转 numpy 数组的视图,另一种是没有。代码如下:
import numpy as np
from numba import njit
@njit
def min_getter(arr):
if len(arr) > 1:
result = np.empty(len(arr), dtype = arr.dtype)
local_min = arr[0]
result[0] = local_min
for i in range(1,len(arr)):
if arr[i] < local_min:
local_min = arr[i]
result[i] = local_min
return result
else:
return arr
@njit
def min_getter_rev1(arr1):
if len(arr1) > 1:
arr = arr1[::-1][::-1]
result = np.empty(len(arr), dtype = arr.dtype)
local_min = arr[0]
result[0] = local_min
for i in range(1,len(arr)):
if arr[i] < local_min:
local_min = arr[i]
result[i] = local_min
return result
else:
return arr1
size = 500000
x = np.arange(size)
y = np.hstack((x[::-1], x))
y_min = min_getter(y)
yrev_min = min_getter_rev1(y)
令人惊讶的是,有一个额外操作的那个在多个场合运行得稍微快一些。我在这两个函数上使用了大约 10 次 %timeit
;尝试了不同大小的数组,差异很明显(至少在我的电脑上是这样)。 min_getter
的运行时间约为:
2.35 ms ± 58.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
(有时是2.33,有时是2.37,但绝不会低于2.30)
min_getter_rev1
的运行时间约为:
2.22 ms ± 23.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
(有时2.25有时2.23,但很少超过2.30)
关于为什么以及如何发生的任何想法?速度差异大约增加了 4-6%,这在某些应用程序中可能是一个大问题。加速的底层机制可能有助于加速一些 jitted 代码
注1:我试过size=5000000,每个函数都测试了5-10次,差别就更明显了。跑得快的 23.2 ms ± 51.7 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
跑得慢的在 24.4 ms ± 234 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
注2:numpy
和numba
测试时的版本分别为1.16.5
和0.45.1
; python 版本为 3.7.4
; IPython
版本为 7.8.0
; PythonIDE使用的是spyder
。不同版本测试结果可能不同
TL;DR:第二个代码更快可能只是一个幸运的巧合。
检查生成的类型揭示了一个重要的区别:
- 在第一个示例中,您的
arr
被键入为 array(int32, 1d, C)
C 连续数组。
min_getter.inspect_types()
min_getter (array(int32, 1d, C),) <--- THIS IS THE IMPORTANT LINE
--------------------------------------------------------------------------------
# File: <>
# --- LINE 4 ---
# label 0
@njit
# --- LINE 5 ---
def min_getter(arr):
[...]
- 在第二个示例中,
arr
被键入为 array(int32, 1d, A)
,一个不知道它是否连续的数组。那是因为 [::-1]
returns 一个没有连续性信息的数组,一旦丢失,它就无法在一秒钟内恢复 [::-1]
.
>>> min_getter_rev1.inspect_types()
[...]
# --- LINE 18 ---
# arr1 = arg(0, name=arr1) :: array(int32, 1d, C)
# $const0.2 = const(NoneType, None) :: none
# $const0.3 = const(NoneType, None) :: none
# $const0.4 = const(int, -1) :: Literal[int](-1)
# [=11=].5 = global(slice: <class 'slice'>) :: Function(<class 'slice'>)
# [=11=].6 = call [=11=].5($const0.2, $const0.3, $const0.4, func=[=11=].5, args=(Var($const0.2, <> (18)), Var($const0.3, <> (18)), Var($const0.4, <> (18))), kws=(), vararg=None) :: (none, none, int64) -> slice<a:b:c>
# del $const0.4
# del $const0.3
# del $const0.2
# del [=11=].5
# [=11=].7 = static_getitem(value=arr1, index=slice(None, None, -1), index_var=[=11=].6) :: array(int32, 1d, A)
# del arr1
# del [=11=].6
# $const0.8 = const(NoneType, None) :: none
# $const0.9 = const(NoneType, None) :: none
# $const0.10 = const(int, -1) :: Literal[int](-1)
# [=11=].11 = global(slice: <class 'slice'>) :: Function(<class 'slice'>)
# [=11=].12 = call [=11=].11($const0.8, $const0.9, $const0.10, func=[=11=].11, args=(Var($const0.8, <> (18)), Var($const0.9, <> (18)), Var($const0.10, <> (18))), kws=(), vararg=None) :: (none, none, int64) -> slice<a:b:c>
# del $const0.9
# del $const0.8
# del $const0.10
# del [=11=].11
# [=11=].13 = static_getitem(value=[=11=].7, index=slice(None, None, -1), index_var=[=11=].12) :: array(int32, 1d, A)
# del [=11=].7
# del [=11=].12
# arr = [=11=].13 :: array(int32, 1d, A) <---- THIS IS THE IMPORTANT LINE
# del [=11=].13
arr = arr1[::-1][::-1]
[...]
(其余生成的代码几乎相同)
如果已知数组是连续的,索引和迭代应该会更快。但这不是我们在这种情况下观察到的 - 恰恰相反。
那么可能是什么原因呢?
Numba 本身使用 LLVM 来 "compile" jitted 代码。所以有一个实际的编译器参与,编译器可以进行优化。尽管 inspect_types()
检查的代码几乎相同,但实际的 LLVM/ASM 代码完全不同 inspect_llvm()
和 inspect_asm()
。因此,编译器(或 numba)能够在第二种情况下进行某种优化,这在第一种情况下是不可能的。或者应用于第一种情况的某些优化实际上使代码变得更糟。
然而,这意味着在第二种情况下我们只是 "got lucky"。这可能不是可以控制的,因为它取决于:
- numba 根据您的来源创建的类型,
- numba 内部使用的对这些类型进行操作的源代码
- 从这些类型和 numba 源代码生成的 LLVM 和
- 从该 LLVM 生成的 ASM。
有太多可以应用优化(或不应用优化)的活动部分。
有趣的事实:如果你扔掉外面的 if
s:
import numpy as np
from numba import njit
@njit
def min_getter(arr):
result = np.empty(len(arr), dtype = arr.dtype)
local_min = arr[0]
result[0] = local_min
for i in range(1,len(arr)):
if arr[i] < local_min:
local_min = arr[i]
result[i] = local_min
return result
@njit
def min_getter_rev1(arr1):
arr = arr1[::-1][::-1]
result = np.empty(len(arr), dtype = arr.dtype)
local_min = arr[0]
result[0] = local_min
for i in range(1,len(arr)):
if arr[i] < local_min:
local_min = arr[i]
result[i] = local_min
return result
size = 500000
x = np.arange(size)
y = np.hstack((x[::-1], x))
y_min = min_getter(y)
yrev_min = min_getter_rev1(y)
%timeit min_getter(y) # 2.29 ms ± 86.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit min_getter_rev1(y) # 2.37 ms ± 212 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
在那种情况下,没有 [::-1][::-1]
的速度更快。
所以如果你想让它可靠地更快:将 if len(arr) > 1
检查移到函数之外并且不要使用 [::-1][::-1]
因为在大多数情况下这会使函数 运行较慢(且可读性较差)!
所以我正在测试同一功能的两个版本的速度;一种是两次反转 numpy 数组的视图,另一种是没有。代码如下:
import numpy as np
from numba import njit
@njit
def min_getter(arr):
if len(arr) > 1:
result = np.empty(len(arr), dtype = arr.dtype)
local_min = arr[0]
result[0] = local_min
for i in range(1,len(arr)):
if arr[i] < local_min:
local_min = arr[i]
result[i] = local_min
return result
else:
return arr
@njit
def min_getter_rev1(arr1):
if len(arr1) > 1:
arr = arr1[::-1][::-1]
result = np.empty(len(arr), dtype = arr.dtype)
local_min = arr[0]
result[0] = local_min
for i in range(1,len(arr)):
if arr[i] < local_min:
local_min = arr[i]
result[i] = local_min
return result
else:
return arr1
size = 500000
x = np.arange(size)
y = np.hstack((x[::-1], x))
y_min = min_getter(y)
yrev_min = min_getter_rev1(y)
令人惊讶的是,有一个额外操作的那个在多个场合运行得稍微快一些。我在这两个函数上使用了大约 10 次 %timeit
;尝试了不同大小的数组,差异很明显(至少在我的电脑上是这样)。 min_getter
的运行时间约为:
2.35 ms ± 58.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
(有时是2.33,有时是2.37,但绝不会低于2.30)
min_getter_rev1
的运行时间约为:
2.22 ms ± 23.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
(有时2.25有时2.23,但很少超过2.30)
关于为什么以及如何发生的任何想法?速度差异大约增加了 4-6%,这在某些应用程序中可能是一个大问题。加速的底层机制可能有助于加速一些 jitted 代码
注1:我试过size=5000000,每个函数都测试了5-10次,差别就更明显了。跑得快的 23.2 ms ± 51.7 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
跑得慢的在 24.4 ms ± 234 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
注2:numpy
和numba
测试时的版本分别为1.16.5
和0.45.1
; python 版本为 3.7.4
; IPython
版本为 7.8.0
; PythonIDE使用的是spyder
。不同版本测试结果可能不同
TL;DR:第二个代码更快可能只是一个幸运的巧合。
检查生成的类型揭示了一个重要的区别:
- 在第一个示例中,您的
arr
被键入为array(int32, 1d, C)
C 连续数组。
min_getter.inspect_types()
min_getter (array(int32, 1d, C),) <--- THIS IS THE IMPORTANT LINE
--------------------------------------------------------------------------------
# File: <>
# --- LINE 4 ---
# label 0
@njit
# --- LINE 5 ---
def min_getter(arr):
[...]
- 在第二个示例中,
arr
被键入为array(int32, 1d, A)
,一个不知道它是否连续的数组。那是因为[::-1]
returns 一个没有连续性信息的数组,一旦丢失,它就无法在一秒钟内恢复[::-1]
.
>>> min_getter_rev1.inspect_types()
[...]
# --- LINE 18 ---
# arr1 = arg(0, name=arr1) :: array(int32, 1d, C)
# $const0.2 = const(NoneType, None) :: none
# $const0.3 = const(NoneType, None) :: none
# $const0.4 = const(int, -1) :: Literal[int](-1)
# [=11=].5 = global(slice: <class 'slice'>) :: Function(<class 'slice'>)
# [=11=].6 = call [=11=].5($const0.2, $const0.3, $const0.4, func=[=11=].5, args=(Var($const0.2, <> (18)), Var($const0.3, <> (18)), Var($const0.4, <> (18))), kws=(), vararg=None) :: (none, none, int64) -> slice<a:b:c>
# del $const0.4
# del $const0.3
# del $const0.2
# del [=11=].5
# [=11=].7 = static_getitem(value=arr1, index=slice(None, None, -1), index_var=[=11=].6) :: array(int32, 1d, A)
# del arr1
# del [=11=].6
# $const0.8 = const(NoneType, None) :: none
# $const0.9 = const(NoneType, None) :: none
# $const0.10 = const(int, -1) :: Literal[int](-1)
# [=11=].11 = global(slice: <class 'slice'>) :: Function(<class 'slice'>)
# [=11=].12 = call [=11=].11($const0.8, $const0.9, $const0.10, func=[=11=].11, args=(Var($const0.8, <> (18)), Var($const0.9, <> (18)), Var($const0.10, <> (18))), kws=(), vararg=None) :: (none, none, int64) -> slice<a:b:c>
# del $const0.9
# del $const0.8
# del $const0.10
# del [=11=].11
# [=11=].13 = static_getitem(value=[=11=].7, index=slice(None, None, -1), index_var=[=11=].12) :: array(int32, 1d, A)
# del [=11=].7
# del [=11=].12
# arr = [=11=].13 :: array(int32, 1d, A) <---- THIS IS THE IMPORTANT LINE
# del [=11=].13
arr = arr1[::-1][::-1]
[...]
(其余生成的代码几乎相同)
如果已知数组是连续的,索引和迭代应该会更快。但这不是我们在这种情况下观察到的 - 恰恰相反。
那么可能是什么原因呢?
Numba 本身使用 LLVM 来 "compile" jitted 代码。所以有一个实际的编译器参与,编译器可以进行优化。尽管 inspect_types()
检查的代码几乎相同,但实际的 LLVM/ASM 代码完全不同 inspect_llvm()
和 inspect_asm()
。因此,编译器(或 numba)能够在第二种情况下进行某种优化,这在第一种情况下是不可能的。或者应用于第一种情况的某些优化实际上使代码变得更糟。
然而,这意味着在第二种情况下我们只是 "got lucky"。这可能不是可以控制的,因为它取决于:
- numba 根据您的来源创建的类型,
- numba 内部使用的对这些类型进行操作的源代码
- 从这些类型和 numba 源代码生成的 LLVM 和
- 从该 LLVM 生成的 ASM。
有太多可以应用优化(或不应用优化)的活动部分。
有趣的事实:如果你扔掉外面的 if
s:
import numpy as np
from numba import njit
@njit
def min_getter(arr):
result = np.empty(len(arr), dtype = arr.dtype)
local_min = arr[0]
result[0] = local_min
for i in range(1,len(arr)):
if arr[i] < local_min:
local_min = arr[i]
result[i] = local_min
return result
@njit
def min_getter_rev1(arr1):
arr = arr1[::-1][::-1]
result = np.empty(len(arr), dtype = arr.dtype)
local_min = arr[0]
result[0] = local_min
for i in range(1,len(arr)):
if arr[i] < local_min:
local_min = arr[i]
result[i] = local_min
return result
size = 500000
x = np.arange(size)
y = np.hstack((x[::-1], x))
y_min = min_getter(y)
yrev_min = min_getter_rev1(y)
%timeit min_getter(y) # 2.29 ms ± 86.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit min_getter_rev1(y) # 2.37 ms ± 212 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
在那种情况下,没有 [::-1][::-1]
的速度更快。
所以如果你想让它可靠地更快:将 if len(arr) > 1
检查移到函数之外并且不要使用 [::-1][::-1]
因为在大多数情况下这会使函数 运行较慢(且可读性较差)!