Numba 字典:JIT() 装饰器中的签名
Numba dictionary: signature in JIT() decorator
我的函数将一个 numpy 数组列表和一个字典(或一个字典列表)作为输入参数,returns 一个值列表。 numpy 数组的列表很长,并且数组可能具有不同的形状。虽然我可以单独传递 numpy 数组,但出于管理目的,我真的很想形成一个 numpy 数组元组并将它们原样传递到我的函数中。
没有字典(根据 numba >=0.43 专门形成)整个设置工作正常 - 请参见下面的脚本。因为输入和输出的结构是元组形式,JIT 需要签名——没有签名它无法判断数据结构的类型。但是,无论我如何尝试将我的字典 'd' 声明到 JIT 装饰器中,我都无法使脚本正常工作。
如果有想法或解决方案,请提供帮助。
非常感谢
'''python:
import numpy as np
from numba import njit
from numba import types
from numba.typed import Dict
@njit( 'Tuple( (f8,f8) )(Tuple( (f8[:],f8[:]) ))' )
def somefunction(lst_arr):
arr1, arr2 = lst_arr
summ = 0
prod = 1
for i in arr2:
summ += i
for j in arr1:
prod *= j
result = (summ,prod)
return result
a = np.arange(5)+1.0
b = np.arange(5)+11.0
arg = (a,b)
print(a,b)
print(somefunction(arg))
# ~~ The Dict.empty() constructs a typed dictionary.
d = Dict.empty(
key_type=types.unicode_type,
value_type=types.float64,)
d['k1'] = 1.5
d['k2'] = 0.5
'''
我希望将 'd'-dictionary 传递给 'somefunction' 并在内部使用字典键...表单示例如下:result = (summ * d['k1'], prod * d['k2'])
import numpy as np
from numba import njit
from numba import types
from numba.typed import Dict
@njit( 'Tuple( (f8,f8) )(Tuple( (f8[:],f8[:]) ), Dict)' )
def somefunction(lst_arr, mydict):
arr1, arr2 = lst_arr
summ = 0
prod = 1
for i in arr2:
summ += i
for j in arr1:
prod *= j
result = (summ*mydict['k1'],prod*mydict['k2'])
return result
# ~~ Input numpy arrays
a = np.arange(5)+1.0
b = np.arange(5)+11.0
arg = (a,b)
# ~~ Input dictionary for the function
d = Dict.empty(
key_type=types.unicode_type,
value_type=types.float64)
d['k1'] = 1.5
d['k2'] = 0.5
# ~~ Run function and print results
print(somefunction(arg, d))
我使用的是 0.45.1
版本。你可以简单地传递字典而不必在字典中声明类型:
d = Dict.empty(
key_type=types.unicode_type,
value_type=types.float64[:],
)
d['k1'] = np.arange(5) + 1.0
d['k2'] = np.arange(5) + 11.0
# Numba will infer the type on it's own.
@njit
def somefunction2(d):
prod = 1
# I am assuming you want sum of second array and product of second
result = (d['k2'].sum(), d['k1'].prod())
return result
print(somefunction(d))
# Output : (65.0, 120.0)
作为参考,您查看官方文档this example。
更新:
在你的情况下,你可以简单地让 jit
自己推断类型并且它应该工作,以下代码对我有用:
import numpy as np
from numba import njit
from numba import types
from numba.typed import Dict
from numba.types import DictType
# Let jit infer the types on it's own
@njit
def somefunction(lst_arr, mydict):
arr1, arr2 = lst_arr
summ = 0
prod = 1
for i in arr2:
summ += i
for j in arr1:
prod *= j
result = (summ*mydict['k1'],prod*mydict['k2'])
return result
# ~~ Input numpy arrays
a = np.arange(5)+1.0
b = np.arange(10)+11.0 #<--------------- This is of different shape
arg = (a,b)
# ~~ Input dictionary for the function
d = Dict.empty(
key_type=types.unicode_type,
value_type=types.float64)
d['k1'] = 1.5
d['k2'] = 0.5
# This works now
print(somefunction(arg, d))
可以看官方文档here:
Unless necessary, it is recommended to let Numba infer argument types by using the signature-less variant of @jit.
我尝试了多种方法,但这是唯一可以解决您指定的问题的方法。
我的函数将一个 numpy 数组列表和一个字典(或一个字典列表)作为输入参数,returns 一个值列表。 numpy 数组的列表很长,并且数组可能具有不同的形状。虽然我可以单独传递 numpy 数组,但出于管理目的,我真的很想形成一个 numpy 数组元组并将它们原样传递到我的函数中。 没有字典(根据 numba >=0.43 专门形成)整个设置工作正常 - 请参见下面的脚本。因为输入和输出的结构是元组形式,JIT 需要签名——没有签名它无法判断数据结构的类型。但是,无论我如何尝试将我的字典 'd' 声明到 JIT 装饰器中,我都无法使脚本正常工作。 如果有想法或解决方案,请提供帮助。
非常感谢
'''python:
import numpy as np
from numba import njit
from numba import types
from numba.typed import Dict
@njit( 'Tuple( (f8,f8) )(Tuple( (f8[:],f8[:]) ))' )
def somefunction(lst_arr):
arr1, arr2 = lst_arr
summ = 0
prod = 1
for i in arr2:
summ += i
for j in arr1:
prod *= j
result = (summ,prod)
return result
a = np.arange(5)+1.0
b = np.arange(5)+11.0
arg = (a,b)
print(a,b)
print(somefunction(arg))
# ~~ The Dict.empty() constructs a typed dictionary.
d = Dict.empty(
key_type=types.unicode_type,
value_type=types.float64,)
d['k1'] = 1.5
d['k2'] = 0.5
'''
我希望将 'd'-dictionary 传递给 'somefunction' 并在内部使用字典键...表单示例如下:result = (summ * d['k1'], prod * d['k2'])
import numpy as np
from numba import njit
from numba import types
from numba.typed import Dict
@njit( 'Tuple( (f8,f8) )(Tuple( (f8[:],f8[:]) ), Dict)' )
def somefunction(lst_arr, mydict):
arr1, arr2 = lst_arr
summ = 0
prod = 1
for i in arr2:
summ += i
for j in arr1:
prod *= j
result = (summ*mydict['k1'],prod*mydict['k2'])
return result
# ~~ Input numpy arrays
a = np.arange(5)+1.0
b = np.arange(5)+11.0
arg = (a,b)
# ~~ Input dictionary for the function
d = Dict.empty(
key_type=types.unicode_type,
value_type=types.float64)
d['k1'] = 1.5
d['k2'] = 0.5
# ~~ Run function and print results
print(somefunction(arg, d))
我使用的是 0.45.1
版本。你可以简单地传递字典而不必在字典中声明类型:
d = Dict.empty(
key_type=types.unicode_type,
value_type=types.float64[:],
)
d['k1'] = np.arange(5) + 1.0
d['k2'] = np.arange(5) + 11.0
# Numba will infer the type on it's own.
@njit
def somefunction2(d):
prod = 1
# I am assuming you want sum of second array and product of second
result = (d['k2'].sum(), d['k1'].prod())
return result
print(somefunction(d))
# Output : (65.0, 120.0)
作为参考,您查看官方文档this example。
更新:
在你的情况下,你可以简单地让 jit
自己推断类型并且它应该工作,以下代码对我有用:
import numpy as np
from numba import njit
from numba import types
from numba.typed import Dict
from numba.types import DictType
# Let jit infer the types on it's own
@njit
def somefunction(lst_arr, mydict):
arr1, arr2 = lst_arr
summ = 0
prod = 1
for i in arr2:
summ += i
for j in arr1:
prod *= j
result = (summ*mydict['k1'],prod*mydict['k2'])
return result
# ~~ Input numpy arrays
a = np.arange(5)+1.0
b = np.arange(10)+11.0 #<--------------- This is of different shape
arg = (a,b)
# ~~ Input dictionary for the function
d = Dict.empty(
key_type=types.unicode_type,
value_type=types.float64)
d['k1'] = 1.5
d['k2'] = 0.5
# This works now
print(somefunction(arg, d))
可以看官方文档here:
Unless necessary, it is recommended to let Numba infer argument types by using the signature-less variant of @jit.
我尝试了多种方法,但这是唯一可以解决您指定的问题的方法。