Python Numba 使用集合字典优化函数

Python Numba optimize function with dictionary of sets

我正在尝试执行 100K + 蒙特卡洛模拟并多次调用函数 simulate 来执行此操作。我正在调查使用 Nubma 来执行此操作并且可以使用一些帮助,特别是我需要执行集合的交集和并集。

这是我尝试优化的代码的最小工作示例。似乎 Numba 不支持集合字典,如果我要使用 numpy 数组字典,那么我无法访问 np.intersect1dnp.union1d 任何帮助将其转换为 Numba 将不胜感激,或者它可能不适合 Numba 优化?

import random
import numpy as np

from numba import types, jit
from numba.typed import Dict, List

n_positions = 100
position_coverage = 10
min_dist = 5

n_ids = 5

near_dict = {x: set(range(x, x + position_coverage)) 
                if x < n_positions - position_coverage else 
                set(range(x, n_positions)) for x in range(n_positions)}
# near_dict = {0: {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, 1: {1, 2, 3, 4, 5, 6, 7, ...
close_dict = {x: set(range(x, x + min_dist)) 
                if x < n_positions - position_coverage + min_dist else 
                set(range(x, n_positions)) for x in range(n_positions)}
# close_dict = {0: {0, 1, 2, 3, 4}, 1: {1, 2, 3, 4, 5}, 2: {2, 3, 4, 5, ...

near_dict = {k: near_dict[k] for k in random.sample(near_dict.keys(), n_ids)}

@jit(nopython=True)
def simulate(near_dict, close_dict):
    out_dict = {}
    to_pos = None
    to_close = set()
    for id, near in sorted(near_dict.items(), key=lambda x: random.random()):
        if to_pos is None:
            _ixs_to = near
        else:
            to_close |= close_dict[to_pos]
            _ixs_to = near - to_close
        if not _ixs_to:
            return {} # failed no reachable positions
        to_pos = random.choice(tuple(_ixs_to))
        out_dict[id] = to_pos
    return out_dict

if __name__ == '__main__':
    simulate(near_dict, close_dict)

转换为 Numba:

这是一项乏味的工作,因为 Numba 施加了许多限制。 post 这段代码一步一步肯定会更好,但是答案会太长。
简而言之,要将代码转换为 Numba 并使用 nopython 对其进行全面优化,您需要大致遵循这些规则。
免责声明:我不是 numba 专家,这部分只是基于观察:

  • 没有基本对象和 numpy 之外的对象。这意味着没有集合也没有字典。
  • 函数只允许 return 单个值,而不是数组。
  • 一切都需要输入或足够简单,以便 Numba 自动推断类型。

可能还有更多。我不太确定。所以基本上顾名思义,如果你走 nopython 路线,忘记你在 python。代码要么需要非常简单(就像在几个简单计算的循环中一样),要么你将不得不把代码砍掉。 (或者还有其他一些我还没有找到的技巧。)
编辑: 我发现了一个技巧 - 只要您事先声明所有类型,numba 现在即使在 nojit 中也支持字典。它还支持集合。但是它不能将集合存储在字典中,因此在您的情况下,您必须首先将集合存储在数组中,然后使用字典作为从 id 到索引的一种翻译 table 。并非不可能,但绝对不是直截了当的。

最有可能有效的代码。然而,它的一部分已被弃用,因为我放弃了试图摆脱最后一个 [] 以支持 numpy 数组。它也没有优化,我也没有测量性能。

import random
import numpy as np
from numba import jit
n_positions = 100
position_coverage = 10
min_dist = 5
n_ids = 5
near_keys = random.sample(range(n_positions), n_ids)


@jit(nopython=True)
    def simulate(near_keys):
    to_pos = -1
    _ixs_to = np.zeros(1)
    to_return = []
    for id in sorted(near_keys, key=lambda x: random.random()):
        if id < n_positions:
            near = np.arange(id, id + position_coverage)
        else:
            near = np.arange(id, n_positions)

        print(id, near)
        if to_pos == -1:
            _ixs_to = near
        else:
            to_close = []
            a = int(n_positions - position_coverage + min_dist)
            to_pos = int(to_pos)
        if to_pos < a:
            pos_values = np.arange(to_pos, to_pos + min_dist)
        else:
            pos_values = np.arange(to_pos, n_positions)

        for i in pos_values:
            if not i in to_close:
                to_close.append(i)

        _ixs_to = np.zeros(1, dtype=np.int64)
        for i in near:
            if not i in to_close:
                _ixs_to = np.append(_ixs_to, i)
        if _ixs_to[0] == 0:
            _ixs_to = _ixs_to[1:]

        if _ixs_to.ndim == 0:
            return [-1] # failed no reachable positions
        np.random.rand()
        to_pos = np.random.choice(_ixs_to)
        to_return.append(id)
        to_return.append(to_pos)
        return to_return

if __name__ == '__main__':
    print(simulate(near_keys))

那么,您相当不错的代码(如果可能是不必要的复杂代码)怎么了?

首先,我用生成函数替换了 sets
near_dict = {x: set(range(x, x + position_coverage)) if x < n_positions - position_coverage else set(range(x, n_positions)) for x in range(n_positions)}
变成了

def get_near(x):
    if x < n - position_coverage:
        return set(range(x, x + position_coverage))
    else:
        return set(range(x, n_positions)

然后我以不需要它们的方式重写了集合的并集和差集。这是通过遍历数组并检查要删除或添加的值来完成的。请注意,这是最有可能更慢的部分,除非纯 python 中的联合操作比我做的未优化的天真联合更慢。如果您想优化代码,这是开始的地方之一。

不幸的是,虽然 njit numba 允许参数为数组,但它 doesn't allow a return value of a function to be an array。因为我们需要数组,所以下一步是内联任何在原始代码中 returning 数组的函数。

在阅读文档时,我注意到了另外一件事:Numba 不允许在 nopython 中创建数组,在这种情况下会退回到对象模式。这意味着您需要回收数组而不是在上面的代码中重新创建它们才能充分利用此模式。

最后一步是解决所有类型不匹配的问题,并确保一切都能协同工作。结果是 Numba 的 nopython 愿意 运行 的代码,不幸的是,它相当混乱,几乎没有利用 Python 的任何便利。不幸的是,它也没有得到优化。但是,如果您希望应用 numba 并随后进行 cheap 并行化,那么这很可能是必要的。如果有人更熟悉 numba 纠正我并指出一个更用户友好的方向,我很乐意删除这个具有误导性的答案。

总而言之,虽然 Numba 绝对是一个强大的工具,但我不太确定将它用于您的项目是否合理,因为您的主要问题不是简单的重复数学运算,而是重复迭代逻辑比 Numba 更喜欢。

小提示:

  • 我会先分析和优化您的代码。在开始并行恶作剧之前分析代码会更容易。如果这证明不够 将是尝试的时刻并且:
  • 查看 multiprocessing python 库,如果这种加速对您来说足够了。如果这还不够,那么:
  • 尝试重写您已经为 numba 优化的代码,最后,如果这还不够:
  • 试试 C、Rust、C++、Java 或类似的东西。你可能会感到惊讶。

结语: 虽然我相当确定用我随后内联的函数替换集合的方式,但我不完全确定我做对了。建议谨慎。