在 jitted 函数中两次反转 numpy 数组的视图使函数 运行 更快

Reversing the view of a numpy array twice in a jitted function makes the function run faster

所以我正在测试同一功能的两个版本的速度;一种是两次反转 numpy 数组的视图,另一种是没有。代码如下:

import numpy as np
from numba import njit

@njit
def min_getter(arr):

    if len(arr) > 1:
        result = np.empty(len(arr), dtype = arr.dtype)
        local_min = arr[0]
        result[0] = local_min

        for i in range(1,len(arr)):
            if arr[i] < local_min:
                local_min = arr[i]
            result[i] = local_min
        return result

    else:
        return arr

@njit
def min_getter_rev1(arr1):

    if len(arr1) > 1:
        arr = arr1[::-1][::-1]
        result = np.empty(len(arr), dtype = arr.dtype)
        local_min = arr[0]
        result[0] = local_min

        for i in range(1,len(arr)):
            if arr[i] < local_min:
                local_min = arr[i]
            result[i] = local_min
        return result

    else:
        return arr1
size = 500000
x = np.arange(size)   
y = np.hstack((x[::-1], x))

y_min = min_getter(y)
yrev_min = min_getter_rev1(y)

令人惊讶的是,有一个额外操作的那个在多个场合运行得稍微快一些。我在这两个函数上使用了大约 10 次 %timeit;尝试了不同大小的数组,差异很明显(至少在我的电脑上是这样)。 min_getter的运行时间约为:

2.35 ms ± 58.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)(有时是2.33,有时是2.37,但绝不会低于2.30)

min_getter_rev1 的运行时间约为:

2.22 ms ± 23.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)(有时2.25有时2.23,但很少超过2.30)


关于为什么以及如何发生的任何想法?速度差异大约增加了 4-6%,这在某些应用程序中可能是一个大问题。加速的底层机制可能有助于加速一些 jitted 代码


注1:我试过size=5000000,每个函数都测试了5-10次,差别就更明显了。跑得快的 23.2 ms ± 51.7 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) 跑得慢的在 24.4 ms ± 234 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

注2:numpynumba测试时的版本分别为1.16.50.45.1; python 版本为 3.7.4IPython 版本为 7.8.0; PythonIDE使用的是spyder。不同版本测试结果可能不同

TL;DR:第二个代码更快可能只是一个幸运的巧合。


检查生成的类型揭示了一个重要的区别:

  • 在第一个示例中,您的 arr 被键入为 array(int32, 1d, C) C 连续数组。
min_getter.inspect_types()

min_getter (array(int32, 1d, C),)  <--- THIS IS THE IMPORTANT LINE
--------------------------------------------------------------------------------
# File: <>
# --- LINE 4 --- 
# label 0

@njit

# --- LINE 5 --- 

def min_getter(arr):

[...]
  • 在第二个示例中,arr 被键入为 array(int32, 1d, A),一个不知道它是否连续的数组。那是因为 [::-1] returns 一个没有连续性信息的数组,一旦丢失,它就无法在一秒钟内恢复 [::-1].
>>> min_getter_rev1.inspect_types()

[...]

    # --- LINE 18 --- 
    #   arr1 = arg(0, name=arr1)  :: array(int32, 1d, C)
    #   $const0.2 = const(NoneType, None)  :: none
    #   $const0.3 = const(NoneType, None)  :: none
    #   $const0.4 = const(int, -1)  :: Literal[int](-1)
    #   [=11=].5 = global(slice: <class 'slice'>)  :: Function(<class 'slice'>)
    #   [=11=].6 = call [=11=].5($const0.2, $const0.3, $const0.4, func=[=11=].5, args=(Var($const0.2, <> (18)), Var($const0.3, <> (18)), Var($const0.4, <> (18))), kws=(), vararg=None)  :: (none, none, int64) -> slice<a:b:c>
    #   del $const0.4
    #   del $const0.3
    #   del $const0.2
    #   del [=11=].5
    #   [=11=].7 = static_getitem(value=arr1, index=slice(None, None, -1), index_var=[=11=].6)  :: array(int32, 1d, A)
    #   del arr1
    #   del [=11=].6
    #   $const0.8 = const(NoneType, None)  :: none
    #   $const0.9 = const(NoneType, None)  :: none
    #   $const0.10 = const(int, -1)  :: Literal[int](-1)
    #   [=11=].11 = global(slice: <class 'slice'>)  :: Function(<class 'slice'>)
    #   [=11=].12 = call [=11=].11($const0.8, $const0.9, $const0.10, func=[=11=].11, args=(Var($const0.8, <> (18)), Var($const0.9, <> (18)), Var($const0.10, <> (18))), kws=(), vararg=None)  :: (none, none, int64) -> slice<a:b:c>
    #   del $const0.9
    #   del $const0.8
    #   del $const0.10
    #   del [=11=].11
    #   [=11=].13 = static_getitem(value=[=11=].7, index=slice(None, None, -1), index_var=[=11=].12)  :: array(int32, 1d, A)
    #   del [=11=].7
    #   del [=11=].12
    #   arr = [=11=].13  :: array(int32, 1d, A)  <---- THIS IS THE IMPORTANT LINE
    #   del [=11=].13

    arr = arr1[::-1][::-1]

[...]

(其余生成的代码几乎相同)

如果已知数组是连续的,索引和迭代应该会更快。但这不是我们在这种情况下观察到的 - 恰恰相反。

那么可能是什么原因呢?

Numba 本身使用 LLVM 来 "compile" jitted 代码。所以有一个实际的编译器参与,编译器可以进行优化。尽管 inspect_types() 检查的代码几乎相同,但实际的 LLVM/ASM 代码完全不同 inspect_llvm()inspect_asm()。因此,编译器(或 numba)能够在第二种情况下进行某种优化,这在第一种情况下是不可能的。或者应用于第一种情况的某些优化实际上使代码变得更糟。

然而,这意味着在第二种情况下我们只是 "got lucky"。这可能不是可以控制的,因为它取决于:

  • numba 根据您的来源创建的类型,
  • numba 内部使用的对这些类型进行操作的源代码
  • 从这些类型和 numba 源代码生成的 LLVM 和
  • 从该 LLVM 生成的 ASM。

有太多可以应用优化(或不应用优化)的活动部分。


有趣的事实:如果你扔掉外面的 ifs:

import numpy as np
from numba import njit

@njit
def min_getter(arr):
    result = np.empty(len(arr), dtype = arr.dtype)
    local_min = arr[0]
    result[0] = local_min

    for i in range(1,len(arr)):
        if arr[i] < local_min:
            local_min = arr[i]
        result[i] = local_min
    return result

@njit
def min_getter_rev1(arr1):
    arr = arr1[::-1][::-1]
    result = np.empty(len(arr), dtype = arr.dtype)
    local_min = arr[0]
    result[0] = local_min

    for i in range(1,len(arr)):
        if arr[i] < local_min:
            local_min = arr[i]
        result[i] = local_min
    return result

size = 500000
x = np.arange(size)   
y = np.hstack((x[::-1], x))

y_min = min_getter(y)
yrev_min = min_getter_rev1(y)

%timeit min_getter(y)      # 2.29 ms ± 86.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit min_getter_rev1(y) # 2.37 ms ± 212 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

在那种情况下,没有 [::-1][::-1] 的速度更快。

所以如果你想让它可靠地更快:将 if len(arr) > 1 检查移到函数之外并且不要使用 [::-1][::-1] 因为在大多数情况下这会使函数 运行较慢(且可读性较差)!