Numba - cannot infer type for List()

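I'm trying to use numba to speed up the fuzzy search function of a Python package. My plan is to start with sequential njit and switch to parallel if that doesn't reach my target. So I converted the original function from the library to numba-supported types, using typed lists instead of plain Python lists. Numba throws the error "Cannot infer the type of variable 'candidates', have imprecise type: ListType[undefined]". I'm confused about why this error occurs. Isn't this how you declare a typed list variable?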

I'm new to numba, so suggestions for other effective ways to speed this up are also welcome.

import numba
from numba import njit, types
from numba.typed import Dict, List

@njit
def make_char2first_subseq_index(subsequence, max_l_dist):
    d = Dict.empty(
        key_type=types.unicode_type,
        value_type=numba.int64,
    )
    for (index, char) in list(enumerate(subsequence[:max_l_dist + 1])):
        d[char] = index
    return d


@njit
def find_near_matches_levenshtein_linear_programming(subsequence, sequence,
                                                     max_l_dist):
    if not subsequence:
        raise ValueError('Given subsequence is empty!')

    subseq_len = len(subsequence)

    def make_match(start, end, dist):
        # return Match(start, end, dist, matched=sequence[start:end])
        return str(start) + " " + str(end) + " " + str(dist) + " " + str(sequence[start:end])

    if max_l_dist >= subseq_len:
        for index in range(len(sequence) + 1):
            return make_match(index, index, subseq_len)

    # optimization: prepare some often used things in advance
    char2first_subseq_index = make_char2first_subseq_index(subsequence,
                                                           max_l_dist)

    candidates = List()
    for index, char in enumerate(sequence):
        # print('/n new loop and the character is ', char)
        new_candidates = List()

        idx_in_subseq = char2first_subseq_index.get(char, None)
        # print("idx_in_subseq ", idx_in_subseq)
        if idx_in_subseq is not None:
            if idx_in_subseq + 1 == subseq_len:
                return make_match(index, index + 1, idx_in_subseq)
            else:
                new_candidates.append(List(index, idx_in_subseq + 1, idx_in_subseq))

        # print(candidates, " new candidates ", new_candidates)
        for cand in candidates:
            # if this sequence char is the candidate's next expected char
            if subsequence[cand[1]] == char:
                # if reached the end of the subsequence, return a match
                if cand[1] + 1 == subseq_len:
                    return make_match(cand[0], index + 1, cand[2])
                # otherwise, update the candidate's subseq_index and keep it
                else:
                    new_candidates.append(List(cand[0], cand[1] + 1, cand[2]))

            # if this sequence char is *not* the candidate's next expected char
            else:
                # we can try skipping a sequence or sub-sequence char (or both),
                # unless this candidate has already skipped the maximum allowed
                # number of characters
                if cand[2] == max_l_dist:
                    continue

                # add a candidate skipping a sequence char
                new_candidates.append(List(cand[0], cand[1], cand[2] + 1))

                if index + 1 < len(sequence) and cand[1] + 1 < subseq_len:
                    # add a candidate skipping both a sequence char and a
                    # subsequence char
                    new_candidates.append(List(cand[0], cand[1] + 1, cand[2] + 1))

                # try skipping subsequence chars
                for n_skipped in range(1, max_l_dist - cand[2] + 1):
                    # if skipping n_skipped sub-sequence chars reaches the end
                    # of the sub-sequence, yield a match
                    if cand[1] + n_skipped == subseq_len:
                        return make_match(cand[0], index + 1, cand[2] + n_skipped)
                        break
                    # otherwise, if skipping n_skipped sub-sequence chars
                    # reaches a sub-sequence char identical to this sequence
                    # char, add a candidate skipping n_skipped sub-sequence
                    # chars
                    elif subsequence[cand[1] + n_skipped] == char:
                        # if this is the last char of the sub-sequence, yield
                        # a match
                        if cand[1] + n_skipped + 1 == subseq_len:
                            return make_match(cand[0], index + 1,
                                             cand[2] + n_skipped)
                        # otherwise add a candidate skipping n_skipped
                        # subsequence chars
                        else:
                            new_candidates.append(List(cand[0], cand[1] + 1 + n_skipped, cand[2] + n_skipped))
                        break
                # note: if the above loop ends without a break, that means that
                # no candidate could be added / yielded by skipping sub-sequence
                # chars

        candidates = new_candidates

    for cand in candidates:
        dist = cand[2] + subseq_len - cand[1]
        if dist <= max_l_dist:
            return make_match(cand[0], len(sequence), dist)

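The error message is accurate and points to the exact problem. A Numba typed.List is homogeneous: every element has the same type, and Numba must know that type at compile time. An empty List() created without an element type only gets one if Numba can infer it from how the variable is used (for example, from a later append); when it cannot, the type stays ListType[undefined] and compilation fails.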

You can create a typed list by initializing it from an iterable (here nb is import numba as nb):

list_of_ints = nb.typed.List([1,2,3])

Or create an empty list with the empty_list() factory, which takes the element type as its argument:

empty_list_of_floats = nb.typed.List.empty_list(nb.f8)

Or create an empty list and immediately append an element, so Numba can infer the type from it:

another_list_of_ints = nb.typed.List()
another_list_of_ints.append(1)

Or any combination of these, for example a list of lists of floats:

list_of_lists_of_floats = nb.typed.List()
list_of_lists_of_floats.append(nb.typed.List.empty_list(nb.f8))
list_of_lists_of_floats[0].append(1)