为什么numba破坏了这个功能?

Why does numba break this function?

我最近听说了 numba,今天想测试一下。 我做了一个简单的程序,它在计时时获取数字的阶乘。

import time
from numba import jit

@jit()
def fact(args):
    x = 1
    for i in range(2, args + 1):
        x *= i
    return x

st = time.time()
x = fact(100000)
print(x)
et = time.time()

elapsed_time = et - st

print("Time elapsed: ", elapsed_time)

出于某种原因,当存在 @jit 装饰器时,此代码仅输出“0”,但如果没有 numba,则代码可以正常工作。

为什么会发生这种情况,我该如何解决?

我相信您可能遇到了使用本机类型的 numba 问题。具体来说,本机 int32 值(最有可能的罪魁祸首)的范围有限,而 Python 整数将根据需要变大。

如果我在你的循环中添加一个打印语句来显示 x 和 运行 直到 x 为 0,输出为:

2
6
24
120
720
5040
40320
362880
3628800
39916800
479001600
1932053504
1278945280
2004310016
2004189184
-288522240
-898433024
109641728
-2102132736
-1195114496
-522715136
862453760
-775946240
2076180480
-1853882368
1484783616
-1375731712
-1241513984
1409286144
738197504
-2147483648
-2147483648
0

如您所见,这些都放在一个 int32 中,并且在值上到处跳来跳去。

另一方面,没有numba的输出是:

2
6
24
120
720
5040
40320
362880
3628800
39916800
479001600
6227020800
87178291200
1307674368000
20922789888000
355687428096000
6402373705728000
121645100408832000
2432902008176640000
51090942171709440000
1124000727777607680000
25852016738884976640000
620448401733239439360000
15511210043330985984000000
403291461126605635584000000
10888869450418352160768000000
304888344611713860501504000000
8841761993739701954543616000000
265252859812191058636308480000000
8222838654177922817725562880000000
263130836933693530167218012160000000
8683317618811886495518194401280000000
295232799039604140847618609643520000000
10333147966386144929666651337523200000000

从Python输出中取出6227020800行,即

1 0111 0011 0010 1000 1100 1100 0000 0000

二进制。修剪掉不适合 int32 的额外位会给你 1932053504,这正是你在 numba 输出中看到的。

import numpy as np
import numba
from numba import jit
from numpy import prod
import time


def factorial(n):
    print( prod(range(1,n+1)))
        
factorial(1)
@jit(nopython=True)
def fact(args):
    x = 1
    for i in range(2, args + 1):
        x *= i
    return x
x = 10
st = time.time()
y = fact(x)
print(y)
et = time.time()

elapsed_time = et - st

print("Time elapsed: ", elapsed_time)

1

3628800

已用时间:0.15069580078125

它非常适用于像 10 这样的小值 x = 10 对于大值,您必须使用现有的 numpy 函数

@jit()
def factorial1(n):
   return(np.math.factorial(n));

st = time.time()
x = factorial1(100)
print(x)
et = time.time()

elapsed_time = et - st

print("Time elapsed: ", elapsed_time)