将函数及其所有子函数传递给 njit
Pass a function and all its subfunctions into njit
所以我最近发现了 Numba,我对它感到非常惊讶。尝试时,我使用了一个 bubblesort 函数作为测试函数,但由于我的 bubblesort 函数调用了另一个函数,因此在调用 njit 时出现错误。
我解决了这个问题,首先在我的 bubblesort 子函数上调用 njit,然后让我的 bubblesort 调用 njit 子函数,它起作用了,但它迫使我在尝试比较时定义两个 bubblesort 函数。我想知道是否有另一种方法可以做到这一点。
这就是我正在做的事情:
def bytaintill(l):
changed = False
for i in range(len(l) - 1):
if l[i] > l[i+1]:
l[i], l[i+1] = l[i+1], l[i]
changed = True
return changed
bytaintill_njit = njit()(bytaintill)
def bubblesort(l):
not_done = True
while not_done:
not_done = bytaintill_njit(l)
return
def bubble(l):
not_done = True
while not_done:
not_done = bytaintill(l)
return
bubblesort_njit = njit()(bubblesort)
为了扩展我的评论,您不需要定义新函数,但也可以将 jit-ed 版本映射到相同的名称。通常,最方便的方法是使用 @jit
装饰器(或 @njit
是 @jit(nopython=True)
的缩写)。
from numba import njit
@njit
def bytaintill(l):
changed = False
for i in range(len(l) - 1):
if l[i] > l[i+1]:
l[i], l[i+1] = l[i+1], l[i]
changed = True
return changed
@njit
def bubble(l):
not_done = True
while not_done:
not_done = bytaintill(l)
return
出于基准测试目的,您可以简单地注释掉装饰器。如果您希望能够在 jit-ed 和 python 版本之间来回切换,您可以尝试这样的事情:
from numba import njit
do_jit = True # set to True or False
def bytaintill(l):
changed = False
for i in range(len(l) - 1):
if l[i] > l[i+1]:
l[i], l[i+1] = l[i+1], l[i]
changed = True
return changed
def bubble(l):
not_done = True
while not_done:
not_done = bytaintill(l)
return
if do_jit:
bytaintill = njit()(bytaintill)
bubble = njit()(bubble)
所以我最近发现了 Numba,我对它感到非常惊讶。尝试时,我使用了一个 bubblesort 函数作为测试函数,但由于我的 bubblesort 函数调用了另一个函数,因此在调用 njit 时出现错误。
我解决了这个问题,首先在我的 bubblesort 子函数上调用 njit,然后让我的 bubblesort 调用 njit 子函数,它起作用了,但它迫使我在尝试比较时定义两个 bubblesort 函数。我想知道是否有另一种方法可以做到这一点。
这就是我正在做的事情:
def bytaintill(l):
changed = False
for i in range(len(l) - 1):
if l[i] > l[i+1]:
l[i], l[i+1] = l[i+1], l[i]
changed = True
return changed
bytaintill_njit = njit()(bytaintill)
def bubblesort(l):
not_done = True
while not_done:
not_done = bytaintill_njit(l)
return
def bubble(l):
not_done = True
while not_done:
not_done = bytaintill(l)
return
bubblesort_njit = njit()(bubblesort)
为了扩展我的评论,您不需要定义新函数,但也可以将 jit-ed 版本映射到相同的名称。通常,最方便的方法是使用 @jit
装饰器(或 @njit
是 @jit(nopython=True)
的缩写)。
from numba import njit
@njit
def bytaintill(l):
changed = False
for i in range(len(l) - 1):
if l[i] > l[i+1]:
l[i], l[i+1] = l[i+1], l[i]
changed = True
return changed
@njit
def bubble(l):
not_done = True
while not_done:
not_done = bytaintill(l)
return
出于基准测试目的,您可以简单地注释掉装饰器。如果您希望能够在 jit-ed 和 python 版本之间来回切换,您可以尝试这样的事情:
from numba import njit
do_jit = True # set to True or False
def bytaintill(l):
changed = False
for i in range(len(l) - 1):
if l[i] > l[i+1]:
l[i], l[i+1] = l[i+1], l[i]
changed = True
return changed
def bubble(l):
not_done = True
while not_done:
not_done = bytaintill(l)
return
if do_jit:
bytaintill = njit()(bytaintill)
bubble = njit()(bubble)