Numba - 无法推断 List() 的类型
Numba - cannot infer type for List()
我正在尝试使用 numba 来加速某个 Python 包的模糊搜索功能。我的计划是先以串行方式使用 njit,如果达不到目标再改用并行。为此,我把库中的原始函数改写成了 numba 支持的类型,并使用类型化列表(typed List)代替普通的 Python 列表。但 numba 抛出了错误:“无法推断变量的类型 'candidates',类型不精确:ListType[undefined]”。我不明白为什么会出现这个错误——这难道不是声明类型化列表变量的方式吗?
我是 numba 新手,因此也欢迎任何其他有效的加速方法建议。
@njit
def make_char2first_subseq_index(subsequence, max_l_dist):
    """Map each character to the index of its first occurrence in a prefix.

    Only the first ``max_l_dist + 1`` characters of ``subsequence`` are
    examined.

    Args:
        subsequence: The string whose leading characters are indexed.
        max_l_dist: Distance bound; limits the examined prefix to
            ``subsequence[:max_l_dist + 1]``.

    Returns:
        A dictionary created with ``Dict.empty`` (unicode keys, int64
        values) mapping every distinct character of the prefix to the
        smallest index at which it appears.
    """
    d = Dict.empty(
        key_type=types.unicode_type,
        value_type=numba.int64,
    )
    # Iterate enumerate() directly; materializing it with list() only built
    # a throwaway copy.
    for index, char in enumerate(subsequence[:max_l_dist + 1]):
        # Keep the earliest index. Assigning unconditionally let a later
        # duplicate overwrite it, so the map held the *last* occurrence,
        # contradicting the "first" in the function's name.
        if char not in d:
            d[char] = index
    return d
@njit
def find_near_matches_levenshtein_linear_programming(subsequence, sequence,
                                                     max_l_dist):
    """Search ``sequence`` for a near-match of ``subsequence``.

    The function scans ``sequence`` one character at a time while keeping a
    list of candidates. Each candidate holds three integers, read below as
    ``cand[0]`` (start index in ``sequence``), ``cand[1]`` (next index
    expected in ``subsequence``) and ``cand[2]`` (distance accumulated so
    far). Every match site uses ``return``, so the function stops at the
    first match it encounters.

    Args:
        subsequence: The string to look for; must be non-empty.
        sequence: The string to search in.
        max_l_dist: Maximum distance allowed for a match.

    Returns:
        The string built by ``make_match``: start, end, distance and the
        matched slice of ``sequence`` joined by single spaces. Falls off
        the end (implicitly returning ``None``) when no match is found.

    Raises:
        ValueError: If ``subsequence`` is empty.
    """
    if not subsequence:
        raise ValueError('Given subsequence is empty!')
    subseq_len = len(subsequence)
    def make_match(start, end, dist):
        # Encodes a match as a plain string in place of the Match object
        # shown in the commented-out line below.
        # return Match(start, end, dist, matched=sequence[start:end])
        return str(start) + " " + str(end) + " " + str(dist) + " " + str(sequence[start:end])
    if max_l_dist >= subseq_len:
        # NOTE(review): the return sits inside the loop body, so only the
        # first iteration (index 0) ever runs.
        for index in range(len(sequence) + 1):
            return make_match(index, index, subseq_len)
    # optimization: prepare some often used things in advance
    char2first_subseq_index = make_char2first_subseq_index(subsequence,
                                                           max_l_dist)
    # NOTE(review): this list is created empty with no element type; it is
    # the variable named in the "cannot infer type" error this post asks
    # about.
    candidates = List()
    for index, char in enumerate(sequence):
        # print('/n new loop and the character is ', char)
        new_candidates = List()
        idx_in_subseq = char2first_subseq_index.get(char, None)
        # print("idx_in_subseq ", idx_in_subseq)
        if idx_in_subseq is not None:
            if idx_in_subseq + 1 == subseq_len:
                return make_match(index, index + 1, idx_in_subseq)
            else:
                # NOTE(review): List(...) is called with three positional
                # arguments here and below — confirm this constructor form
                # is supported for typed lists.
                new_candidates.append(List(index, idx_in_subseq + 1, idx_in_subseq))
        # print(candidates, " new candidates ", new_candidates)
        for cand in candidates:
            # if this sequence char is the candidate's next expected char
            if subsequence[cand[1]] == char:
                # if reached the end of the subsequence, return a match
                if cand[1] + 1 == subseq_len:
                    return make_match(cand[0], index + 1, cand[2])
                # otherwise, update the candidate's subseq_index and keep it
                else:
                    new_candidates.append(List(cand[0], cand[1] + 1, cand[2]))
            # if this sequence char is *not* the candidate's next expected char
            else:
                # we can try skipping a sequence or sub-sequence char (or both),
                # unless this candidate has already skipped the maximum allowed
                # number of characters
                if cand[2] == max_l_dist:
                    continue
                # add a candidate skipping a sequence char
                new_candidates.append(List(cand[0], cand[1], cand[2] + 1))
                if index + 1 < len(sequence) and cand[1] + 1 < subseq_len:
                    # add a candidate skipping both a sequence char and a
                    # subsequence char
                    new_candidates.append(List(cand[0], cand[1] + 1, cand[2] + 1))
                # try skipping subsequence chars
                for n_skipped in range(1, max_l_dist - cand[2] + 1):
                    # if skipping n_skipped sub-sequence chars reaches the end
                    # of the sub-sequence, yield a match
                    if cand[1] + n_skipped == subseq_len:
                        return make_match(cand[0], index + 1, cand[2] + n_skipped)
                        # NOTE(review): unreachable, it follows a return.
                        break
                    # otherwise, if skipping n_skipped sub-sequence chars
                    # reaches a sub-sequence char identical to this sequence
                    # char, add a candidate skipping n_skipped sub-sequence
                    # chars
                    elif subsequence[cand[1] + n_skipped] == char:
                        # if this is the last char of the sub-sequence, yield
                        # a match
                        if cand[1] + n_skipped + 1 == subseq_len:
                            return make_match(cand[0], index + 1,
                                              cand[2] + n_skipped)
                        # otherwise add a candidate skipping n_skipped
                        # subsequence chars
                        else:
                            new_candidates.append(List(cand[0], cand[1] + 1 + n_skipped, cand[2] + n_skipped))
                        break
                # note: if the above loop ends without a break, that means that
                # no candidate could be added / yielded by skipping sub-sequence
                # chars
        candidates = new_candidates
    # End of sequence reached: a surviving candidate still matches if the
    # remaining sub-sequence chars can be skipped within the distance budget.
    for cand in candidates:
        dist = cand[2] + subseq_len - cand[1]
        if dist <= max_l_dist:
            return make_match(cand[0], len(sequence), dist)
错误信息非常准确,直接指出了问题所在。Numba 的 typed.List 是同质容器,所有元素必须是同一种数据类型,因此它需要知道元素的类型。
您可以用已有元素进行初始化来创建类型化列表(元素类型由初始值推断):
list_of_ints = nb.typed.List([1,2,3])
或者使用 empty_list() 工厂方法创建一个空列表,并同时声明其元素类型:
empty_list_of_floats = nb.typed.List.empty_list(nb.f8)
或者先创建一个空列表,然后立即追加一个元素(元素类型由该元素推断):
another_list_of_ints = nb.typed.List()
another_list_of_ints.append(1)
或者将上述方式任意组合:
list_of_lists_of_floats = nb.typed.List()
list_of_lists_of_floats.append(nb.typed.List.empty_list(nb.f8))
list_of_lists_of_floats[0].append(1)
我正在尝试使用 numba 来加速某个 Python 包的模糊搜索功能。我的计划是先以串行方式使用 njit,如果达不到目标再改用并行。为此,我把库中的原始函数改写成了 numba 支持的类型,并使用类型化列表(typed List)代替普通的 Python 列表。但 numba 抛出了错误:“无法推断变量的类型 'candidates',类型不精确:ListType[undefined]”。我不明白为什么会出现这个错误——这难道不是声明类型化列表变量的方式吗?
我是 numba 新手,因此也欢迎任何其他有效的加速方法建议。
@njit
def make_char2first_subseq_index(subsequence, max_l_dist):
    """Map each character to the index of its first occurrence in a prefix.

    Only the first ``max_l_dist + 1`` characters of ``subsequence`` are
    examined.

    Args:
        subsequence: The string whose leading characters are indexed.
        max_l_dist: Distance bound; limits the examined prefix to
            ``subsequence[:max_l_dist + 1]``.

    Returns:
        A dictionary created with ``Dict.empty`` (unicode keys, int64
        values) mapping every distinct character of the prefix to the
        smallest index at which it appears.
    """
    d = Dict.empty(
        key_type=types.unicode_type,
        value_type=numba.int64,
    )
    # Iterate enumerate() directly; materializing it with list() only built
    # a throwaway copy.
    for index, char in enumerate(subsequence[:max_l_dist + 1]):
        # Keep the earliest index. Assigning unconditionally let a later
        # duplicate overwrite it, so the map held the *last* occurrence,
        # contradicting the "first" in the function's name.
        if char not in d:
            d[char] = index
    return d
@njit
def find_near_matches_levenshtein_linear_programming(subsequence, sequence,
                                                     max_l_dist):
    """Search ``sequence`` for a near-match of ``subsequence``.

    The function scans ``sequence`` one character at a time while keeping a
    list of candidates. Each candidate holds three integers, read below as
    ``cand[0]`` (start index in ``sequence``), ``cand[1]`` (next index
    expected in ``subsequence``) and ``cand[2]`` (distance accumulated so
    far). Every match site uses ``return``, so the function stops at the
    first match it encounters.

    Args:
        subsequence: The string to look for; must be non-empty.
        sequence: The string to search in.
        max_l_dist: Maximum distance allowed for a match.

    Returns:
        The string built by ``make_match``: start, end, distance and the
        matched slice of ``sequence`` joined by single spaces. Falls off
        the end (implicitly returning ``None``) when no match is found.

    Raises:
        ValueError: If ``subsequence`` is empty.
    """
    if not subsequence:
        raise ValueError('Given subsequence is empty!')
    subseq_len = len(subsequence)
    def make_match(start, end, dist):
        # Encodes a match as a plain string in place of the Match object
        # shown in the commented-out line below.
        # return Match(start, end, dist, matched=sequence[start:end])
        return str(start) + " " + str(end) + " " + str(dist) + " " + str(sequence[start:end])
    if max_l_dist >= subseq_len:
        # NOTE(review): the return sits inside the loop body, so only the
        # first iteration (index 0) ever runs.
        for index in range(len(sequence) + 1):
            return make_match(index, index, subseq_len)
    # optimization: prepare some often used things in advance
    char2first_subseq_index = make_char2first_subseq_index(subsequence,
                                                           max_l_dist)
    # NOTE(review): this list is created empty with no element type; it is
    # the variable named in the "cannot infer type" error this post asks
    # about.
    candidates = List()
    for index, char in enumerate(sequence):
        # print('/n new loop and the character is ', char)
        new_candidates = List()
        idx_in_subseq = char2first_subseq_index.get(char, None)
        # print("idx_in_subseq ", idx_in_subseq)
        if idx_in_subseq is not None:
            if idx_in_subseq + 1 == subseq_len:
                return make_match(index, index + 1, idx_in_subseq)
            else:
                # NOTE(review): List(...) is called with three positional
                # arguments here and below — confirm this constructor form
                # is supported for typed lists.
                new_candidates.append(List(index, idx_in_subseq + 1, idx_in_subseq))
        # print(candidates, " new candidates ", new_candidates)
        for cand in candidates:
            # if this sequence char is the candidate's next expected char
            if subsequence[cand[1]] == char:
                # if reached the end of the subsequence, return a match
                if cand[1] + 1 == subseq_len:
                    return make_match(cand[0], index + 1, cand[2])
                # otherwise, update the candidate's subseq_index and keep it
                else:
                    new_candidates.append(List(cand[0], cand[1] + 1, cand[2]))
            # if this sequence char is *not* the candidate's next expected char
            else:
                # we can try skipping a sequence or sub-sequence char (or both),
                # unless this candidate has already skipped the maximum allowed
                # number of characters
                if cand[2] == max_l_dist:
                    continue
                # add a candidate skipping a sequence char
                new_candidates.append(List(cand[0], cand[1], cand[2] + 1))
                if index + 1 < len(sequence) and cand[1] + 1 < subseq_len:
                    # add a candidate skipping both a sequence char and a
                    # subsequence char
                    new_candidates.append(List(cand[0], cand[1] + 1, cand[2] + 1))
                # try skipping subsequence chars
                for n_skipped in range(1, max_l_dist - cand[2] + 1):
                    # if skipping n_skipped sub-sequence chars reaches the end
                    # of the sub-sequence, yield a match
                    if cand[1] + n_skipped == subseq_len:
                        return make_match(cand[0], index + 1, cand[2] + n_skipped)
                        # NOTE(review): unreachable, it follows a return.
                        break
                    # otherwise, if skipping n_skipped sub-sequence chars
                    # reaches a sub-sequence char identical to this sequence
                    # char, add a candidate skipping n_skipped sub-sequence
                    # chars
                    elif subsequence[cand[1] + n_skipped] == char:
                        # if this is the last char of the sub-sequence, yield
                        # a match
                        if cand[1] + n_skipped + 1 == subseq_len:
                            return make_match(cand[0], index + 1,
                                              cand[2] + n_skipped)
                        # otherwise add a candidate skipping n_skipped
                        # subsequence chars
                        else:
                            new_candidates.append(List(cand[0], cand[1] + 1 + n_skipped, cand[2] + n_skipped))
                        break
                # note: if the above loop ends without a break, that means that
                # no candidate could be added / yielded by skipping sub-sequence
                # chars
        candidates = new_candidates
    # End of sequence reached: a surviving candidate still matches if the
    # remaining sub-sequence chars can be skipped within the distance budget.
    for cand in candidates:
        dist = cand[2] + subseq_len - cand[1]
        if dist <= max_l_dist:
            return make_match(cand[0], len(sequence), dist)
错误信息非常准确,直接指出了问题所在。Numba 的 typed.List 是同质容器,所有元素必须是同一种数据类型,因此它需要知道元素的类型。
您可以用已有元素进行初始化来创建类型化列表(元素类型由初始值推断):
list_of_ints = nb.typed.List([1,2,3])
或者使用 empty_list() 工厂方法创建一个空列表,并同时声明其元素类型:
empty_list_of_floats = nb.typed.List.empty_list(nb.f8)
或者先创建一个空列表,然后立即追加一个元素(元素类型由该元素推断):
another_list_of_ints = nb.typed.List()
another_list_of_ints.append(1)
或者将上述方式任意组合:
list_of_lists_of_floats = nb.typed.List()
list_of_lists_of_floats.append(nb.typed.List.empty_list(nb.f8))
list_of_lists_of_floats[0].append(1)