Python : 为什么 numba 需要更多时间?
Python : why numba takes more time?
我正在使用 list comprehension
从 tuples
的 list
收集数据。代码如下:
data = [result[0] for result in results] #results is a list of tuples and i take first element from each tuple.
这一切都很好。
最近我遇到了 numba
可以提高循环执行速度的模块?
所以我试了这个来测试时间:
import numba
from numba import literal_unroll
from datetime import datetime
import logging
numba_logger = logging.getLogger('numba')
numba_logger.setLevel(logging.WARNING)
@numba.jit(nopython=True)
def loop_faster(results):
for result in literal_unroll(results):
print(result)
tuples = (1.1, "Hello", 1, "World", "Tuple-1")
print(datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3])
loop_faster(tuples)
print(datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3])
for result in tuples:
print(result)
print(datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3])
我为 literal_unroll
推荐了这个 link : https://numba.pydata.org/numba-doc/dev/reference/pysupported.html
然而,for 循环似乎比 numba
方法执行得更好。
以上程序的结果:
2021-03-04 10:51:36.385
1.1
Hello
1
World
Tuple-1
2021-03-04 10:51:47.234
1.1
Hello
1
World
Tuple-1
2021-03-04 10:51:47.236
为什么会出现这种行为? numba 几乎用了 10 多秒
对于我的案例,要从元组的第 n 个元素中形成一个列表,我该如何使用 numba
模块来实现?
原因很简单,在function的第一个运行上花了大量时间将Ayour代码编译成C++代码和机器码,也就是在做numba的JIT。
所以你必须给你的@numba.jit
装饰器添加参数cache = True
来预缓存编译版本。您还必须在测量时间之前调用一次 运行 以确保编译。此外,您还必须 运行 更多循环迭代以更精确地测量时间,仅 10 毫秒 运行 是不够的。
下面的代码做了上面提到的三件事。您可以看到 numba 提供了 5.5x
倍的加速。
我还修改了您的代码以获得稍微不同的逻辑,因为打印逻辑无法正确测量时间。 Numba 适用于计算量很大的代码,而不是用于打印到控制台。所以我只是创建了随机整数数组并计算了这个数组+1。这足以作为示例代码,您可以看到 Numba 运行 的速度要快得多。
import numba, logging, random
from datetime import datetime
numba_logger = logging.getLogger('numba')
numba_logger.setLevel(logging.WARNING)
@numba.njit(cache = True)
def loop_faster(results, n):
for i in range(n):
res = []
for result in numba.literal_unroll(results):
res.append(result + 1)
t = tuple(random.randrange(1 << 20) for i in range(100))
loop_faster(t, 10) # pre-compile numba
print(datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3])
loop_faster(t, 1 << 16)
print(datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3])
for i in range(1 << 16):
res = []
for result in t:
res.append(result + 1)
print(datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3])
输出:
2021-03-04 12:18:04.491
2021-03-04 12:18:04.840
2021-03-04 12:18:06.774
我正在使用 list comprehension
从 tuples
的 list
收集数据。代码如下:
data = [result[0] for result in results] #results is a list of tuples and i take first element from each tuple.
这一切都很好。
最近我遇到了 numba
可以提高循环执行速度的模块?
所以我试了这个来测试时间:
import numba
from numba import literal_unroll
from datetime import datetime
import logging
numba_logger = logging.getLogger('numba')
numba_logger.setLevel(logging.WARNING)
@numba.jit(nopython=True)
def loop_faster(results):
for result in literal_unroll(results):
print(result)
tuples = (1.1, "Hello", 1, "World", "Tuple-1")
print(datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3])
loop_faster(tuples)
print(datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3])
for result in tuples:
print(result)
print(datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3])
我为 literal_unroll
推荐了这个 link : https://numba.pydata.org/numba-doc/dev/reference/pysupported.html
然而,for 循环似乎比 numba
方法执行得更好。
以上程序的结果:
2021-03-04 10:51:36.385
1.1
Hello
1
World
Tuple-1
2021-03-04 10:51:47.234
1.1
Hello
1
World
Tuple-1
2021-03-04 10:51:47.236
为什么会出现这种行为? numba 几乎用了 10 多秒
对于我的案例,要从元组的第 n 个元素中形成一个列表,我该如何使用 numba
模块来实现?
原因很简单,在function的第一个运行上花了大量时间将Ayour代码编译成C++代码和机器码,也就是在做numba的JIT。
所以你必须给你的@numba.jit
装饰器添加参数cache = True
来预缓存编译版本。您还必须在测量时间之前调用一次 运行 以确保编译。此外,您还必须 运行 更多循环迭代以更精确地测量时间,仅 10 毫秒 运行 是不够的。
下面的代码做了上面提到的三件事。您可以看到 numba 提供了 5.5x
倍的加速。
我还修改了您的代码以获得稍微不同的逻辑,因为打印逻辑无法正确测量时间。 Numba 适用于计算量很大的代码,而不是用于打印到控制台。所以我只是创建了随机整数数组并计算了这个数组+1。这足以作为示例代码,您可以看到 Numba 运行 的速度要快得多。
import numba, logging, random
from datetime import datetime
numba_logger = logging.getLogger('numba')
numba_logger.setLevel(logging.WARNING)
@numba.njit(cache = True)
def loop_faster(results, n):
for i in range(n):
res = []
for result in numba.literal_unroll(results):
res.append(result + 1)
t = tuple(random.randrange(1 << 20) for i in range(100))
loop_faster(t, 10) # pre-compile numba
print(datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3])
loop_faster(t, 1 << 16)
print(datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3])
for i in range(1 << 16):
res = []
for result in t:
res.append(result + 1)
print(datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3])
输出:
2021-03-04 12:18:04.491
2021-03-04 12:18:04.840
2021-03-04 12:18:06.774