How to determine if numba's prange actually works correctly?

In another Q&A I made a comment regarding the correctness of using prange with this code (from another answer):

from numba import njit, prange

@njit
def dynamic_cumsum(seq, index, max_value):
    cumsum = []
    running = 0
    for i in prange(len(seq)):
        if running > max_value:
            cumsum.append([index[i], running])
            running = 0
        running += seq[i] 
    cumsum.append([index[-1], running])

    return cumsum

The comment was:

I wouldn't recommend parallelizing a loop that isn't pure. In this case the running variable makes it impure. There are 4 possible outcomes: (1) numba decides that it cannot parallelize it and just process the loop as if it was range instead of prange (2) it can lift the variable outside the loop and use parallelization on the remainder (3) numba incorrectly inserts synchronization between the parallel executions and the result may be bogus (4) numba inserts the necessary synchronizations around running which may impose more overhead than you gain by parallelizing it in the first place

And the later addition:

Of course both the running and cumsum variable make the loop "impure", not just the running variable as stated in the previous comment

Then someone asked me:

This might sound like a silly question, but how can I figure out which of the 4 things it did and improve it? I would really like to become better with numba!

Since this might be useful for future readers, I decided to create a self-answered Q&A here. Spoiler: I cannot really answer the question which of the 4 outcomes is produced (or whether numba produces a completely different outcome), so I highly encourage other answers.

TL;DR: First of all: prange is identical to range, unless you add parallel to the jit, for example njit(parallel=True). If you try that, you'll see an exception about an "unsupported reduction" - that's because Numba limits the scope of prange to "pure" loops and "impure" loops with numba-supported reductions, and puts the responsibility of making sure the loop falls into one of these categories on the user.

This is clearly stated in the numba documentation for prange (version 0.42):

1.10.2. Explicit Parallel Loops

Another feature of this code transformation pass is support for explicit parallel loops. One can use Numba’s prange instead of range to specify that a loop can be parallelized. The user is required to make sure that the loop does not have cross iteration dependencies except for supported reductions.

What the comments call "impure" is called "cross iteration dependencies" in that documentation. Such a "cross-iteration dependency" is a variable that changes between iterations. A simple example would be:

def func(n):
    a = 0
    for i in range(n):
        a += 1
    return a

The variable a here depends both on the value it had before the loop started and on how many loop iterations have already been executed. That's what is meant by a "cross iteration dependency" or an "impure" loop.

The problem with explicitly parallelizing such a loop is that the iterations are executed in parallel, but each iteration needs to know what the other iterations are doing. Failing to do so leads to wrong results.

Let's assume for a moment that prange would spawn 4 workers and we pass 4 as n to the function. What would a completely naive implementation do?

Worker 1 starts, gets i = 1 from `prange`, and reads a = 0
Worker 2 starts, gets i = 2 from `prange`, and reads a = 0
Worker 3 starts, gets i = 3 from `prange`, and reads a = 0
Worker 1 executes the loop body and sets `a = a + 1` (=> 1)
Worker 3 executes the loop body and sets `a = a + 1` (=> 1)
Worker 4 starts, gets i = 4 from `prange`, and reads a = 1
Worker 2 executes the loop body and sets `a = a + 1` (=> 1)
Worker 4 executes the loop body and sets `a = a + 1` (=> 2)

=> Loop ended, the function returns 2 instead of the expected 4

The order in which the different workers read, execute, and write a can be arbitrary - this was just one example. It could also (by accident) produce the correct result! This is generally called a race condition.
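This kind of interleaving can be replayed deterministically in plain Python (no threads involved), which makes the lost updates visible. The `replay` helper and the schedule below are my own illustration, not anything numba does internally:

```python
# Deterministic replay of a naive parallel schedule: every worker first
# reads the shared variable, and only later writes back (read value + 1).
def replay(schedule):
    a = 0        # the shared variable
    local = {}   # the value each worker read privately
    for worker, action in schedule:
        if action == "read":
            local[worker] = a
        else:  # "write": the worker stores its (possibly stale) read value + 1
            a = local[worker] + 1
    return a

# Workers 1-3 read before anyone has written, so their updates collide.
schedule = [
    (1, "read"), (2, "read"), (3, "read"),
    (1, "write"), (3, "write"),
    (4, "read"),
    (2, "write"), (4, "write"),
]
print(replay(schedule))  # 2 - two of the four increments were lost
```

Replaying a fully sequential schedule (each worker reads and writes before the next starts) returns the correct 4, which shows that only the interleaving is to blame.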

What could a more sophisticated prange do when it recognizes that such a cross-iteration dependency exists?

There are three options:

  • Simply not parallelize it.
  • Implement a mechanism by which the workers share the variable. Typical examples here are locks (which incur a high overhead).
  • Recognize that it is a reduction that can be parallelized.
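The second option can be sketched in plain Python with the `threading` module. This is not how numba implements anything internally - it is just an illustration of the locking idea and of why it is costly: every single update serializes the workers.

```python
import threading

def locked_count(n, n_workers=4):
    """Count to n with several workers sharing one variable behind a lock."""
    total = 0
    lock = threading.Lock()

    def work(start, stop):
        nonlocal total
        for _ in range(start, stop):
            with lock:       # every iteration synchronizes on the lock,
                total += 1   # so the result is correct but workers mostly wait

    chunk = (n + n_workers - 1) // n_workers
    threads = [
        threading.Thread(target=work, args=(w * chunk, min((w + 1) * chunk, n)))
        for w in range(n_workers)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return total

print(locked_count(10_000))  # 10000 - correct, unlike the naive schedule
```

The lock makes the read-modify-write atomic, which removes the race condition, but the per-iteration synchronization overhead is exactly what the comment's outcome (4) warns about.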

Given my understanding of the numba documentation (quoted again):

The user is required to make sure that the loop does not have cross iteration dependencies except for supported reductions.

Numba will either:

  • parallelize it using a reduction pattern, if it is a known reduction, or
  • throw an exception, if it is not a known reduction.

Unfortunately it is not entirely clear what "supported reductions" are. But the documentation hints that they are binary operators that operate on the previous value in the loop body:

A reduction is inferred automatically if a variable is updated by a binary function/operator using its previous value in the loop body. The initial value of the reduction is inferred automatically for += and *= operators. For other functions/operators, the reduction variable should hold the identity value right before entering the prange loop. Reductions in this manner are supported for scalars and for arrays of arbitrary dimensions.
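The reason a reduction like += can be parallelized at all is that each worker can accumulate into a private variable (initialized with the identity value) and the partial results are combined at the end. A plain-Python sketch of that pattern - my own illustration, not numba's actual implementation:

```python
def parallel_sum_sketch(seq, n_workers=4):
    chunk = (len(seq) + n_workers - 1) // n_workers
    partials = []
    for w in range(n_workers):          # conceptually, these run in parallel
        local = 0                       # identity value for +=
        for x in seq[w * chunk:(w + 1) * chunk]:
            local += x                  # no cross-iteration dependency left
        partials.append(local)
    return sum(partials)                # sequential combine step at the end

print(parallel_sum_sketch(list(range(101))))  # 5050, same as sum(range(101))
```

Because addition is associative, the order in which the chunks are processed does not matter, which is exactly what makes the pattern safe without locks.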

The code in the OP uses a list as cross-iteration dependency and calls list.append in the loop body. Personally I wouldn't call list.append a reduction, and it doesn't use a binary operator, so my assumption is that it is very likely not supported. As for the other cross-iteration dependency, running: it uses addition on the result of the previous iteration (which would be fine), but it also conditionally resets it to zero whenever it exceeds a threshold (which is probably not fine).
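That suspicion can be made concrete in plain Python: because of the conditional reset, processing the sequence in independent chunks and concatenating the partial results does not reproduce the sequential result, so there is no obvious combine step for a reduction. The re-implementation and the toy inputs below are mine, just for illustration:

```python
def dynamic_cumsum_py(seq, index, max_value):
    # Plain-Python version of the function from the question.
    cumsum = []
    running = 0
    for i in range(len(seq)):
        if running > max_value:
            cumsum.append([index[i], running])
            running = 0
        running += seq[i]
    cumsum.append([index[-1], running])
    return cumsum

seq, index = [7] * 10, list(range(10))
full = dynamic_cumsum_py(seq, index, 10)
# Split the work in two halves, as a parallel schedule might:
halves = (dynamic_cumsum_py(seq[:5], index[:5], 10)
          + dynamic_cumsum_py(seq[5:], index[5:], 10))
print(full)    # [[2, 14], [4, 14], [6, 14], [8, 14], [9, 14]]
print(halves)  # [[2, 14], [4, 14], [4, 7], [7, 14], [9, 14], [9, 7]]
```

The second half restarts running at zero instead of carrying the remainder over from the first half, so the reset points (and even the number of emitted rows) differ. The result genuinely depends on sequential order.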

Numba provides ways to inspect the typing and the intermediate code (LLVM and ASM):

dynamic_cumsum.inspect_types()
dynamic_cumsum.inspect_llvm()
dynamic_cumsum.inspect_asm()

But even if I had the understanding of that output necessary to make any statement about the correctness of the emitted code - in general it is highly nontrivial to "prove" that code works correctly. And since I even lack the LLVM and ASM knowledge to see whether it tries to parallelize it at all, I cannot actually answer your specific question which outcome it produced.

Back to the code: as mentioned, it throws an exception (unsupported reduction) when I use parallel=True, so I assume numba doesn't parallelize anything in the example:

import numpy as np
from numba import njit, prange

@njit(parallel=True)
def dynamic_cumsum(seq, index, max_value):
    cumsum = []
    running = 0
    for i in prange(len(seq)):
        if running > max_value:
            cumsum.append([index[i], running])
            running = 0
        running += seq[i] 
    cumsum.append([index[-1], running])

    return cumsum

dynamic_cumsum(np.ones(100), np.arange(100), 10)
AssertionError: Invalid reduction format

During handling of the above exception, another exception occurred:

LoweringError: Failed in nopython mode pipeline (step: nopython mode backend)
Invalid reduction format

File "<>", line 7:
def dynamic_cumsum(seq, index, max_value):
    <source elided>
    running = 0
    for i in prange(len(seq)):
    ^

[1] During: lowering "id=2[LoopNest(index_variable = parfor_index.192, range = (0, seq_size0.189, 1))]{56: <ir.Block at <> (10)>, 24: <ir.Block at <> (7)>, 34: <ir.Block at <> (8)>}Var(parfor_index.192, <> (7))" at <> (7)

So all that is left to say: prange provides no speed advantage over a normal range in this case (because it isn't executing in parallel). So in that case I would not "risk" the potential problems and/or confuse readers - given that it isn't supported according to the numba documentation:

from numba import njit, prange

@njit
def p_dynamic_cumsum(seq, index, max_value):
    cumsum = []
    running = 0
    for i in prange(len(seq)):
        if running > max_value:
            cumsum.append([index[i], running])
            running = 0
        running += seq[i] 
    cumsum.append([index[-1], running])

    return cumsum

@njit
def dynamic_cumsum(seq, index, max_value):
    cumsum = []
    running = 0
    for i in range(len(seq)):  # <-- here is the only change
        if running > max_value:
            cumsum.append([index[i], running])
            running = 0
        running += seq[i] 
    cumsum.append([index[-1], running])

    return cumsum

Just a quick timing to back up the "not faster than" claim made earlier:

import numpy as np
seq = np.random.randint(0, 100, 10_000_000)
index = np.arange(10_000_000)
max_ = 500
# Correctness and warm-up
assert p_dynamic_cumsum(seq, index, max_) == dynamic_cumsum(seq, index, max_)
%timeit p_dynamic_cumsum(seq, index, max_)
# 468 ms ± 12.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit dynamic_cumsum(seq, index, max_)
# 470 ms ± 9.49 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)