Python : 为什么 numba 需要更多时间?

Python : why numba takes more time?

我正在使用 list comprehensiontupleslist 收集数据。代码如下:

data = [result[0] for result in results] #results is a list of tuples and i take first element from each tuple.

这一切都很好。

最近我遇到了 numba 可以提高循环执行速度的模块?

所以我试了这个来测试时间:

import numba
from numba import literal_unroll
from datetime import datetime
import logging

numba_logger = logging.getLogger('numba')
numba_logger.setLevel(logging.WARNING)

@numba.jit(nopython=True)
def loop_faster(results):
    for result in literal_unroll(results):
        print(result)
    
tuples = (1.1, "Hello", 1, "World", "Tuple-1")

print(datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3])

loop_faster(tuples)

print(datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3])

for result in tuples:
        print(result)

print(datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3])

我为 literal_unroll 推荐了这个 link : https://numba.pydata.org/numba-doc/dev/reference/pysupported.html

然而,for 循环似乎比 numba 方法执行得更好。

以上程序的结果:

2021-03-04 10:51:36.385
1.1
Hello
1
World
Tuple-1
2021-03-04 10:51:47.234
1.1
Hello
1
World
Tuple-1
2021-03-04 10:51:47.236

为什么会出现这种行为? numba 几乎用了 10 多秒

对于我的案例,要从元组的第 n 个元素中形成一个列表,我该如何使用 numba 模块来实现?

原因很简单,在function的第一个运行上花了大量时间将Ayour代码编译成C++代码和机器码,也就是在做numba的JIT。

所以你必须给你的@numba.jit装饰器添加参数cache = True来预缓存编译版本。您还必须在测量时间之前调用一次 运行 以确保编译。此外,您还必须 运行 更多循环迭代以更精确地测量时间,仅 10 毫秒 运行 是不够的。

下面的代码做了上面提到的三件事。您可以看到 numba 提供了 5.5x 倍的加速。

我还修改了您的代码以获得稍微不同的逻辑,因为打印逻辑无法正确测量时间。 Numba 适用于计算量很大的代码,而不是用于打印到控制台。所以我只是创建了随机整数数组并计算了这个数组+1。这足以作为示例代码,您可以看到 Numba 运行 的速度要快得多。

Try it online!

import numba, logging, random
from datetime import datetime

numba_logger = logging.getLogger('numba')
numba_logger.setLevel(logging.WARNING)

@numba.njit(cache = True)
def loop_faster(results, n):
    for i in range(n):
        res = []
        for result in numba.literal_unroll(results):
            res.append(result + 1)
    
t = tuple(random.randrange(1 << 20) for i in range(100))

loop_faster(t, 10) # pre-compile numba

print(datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3])

loop_faster(t, 1 << 16)

print(datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3])

for i in range(1 << 16):
    res = []
    for result in t:
        res.append(result + 1)

print(datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3])

输出:

2021-03-04 12:18:04.491
2021-03-04 12:18:04.840
2021-03-04 12:18:06.774