为什么 numpy 不会在非连续数组上短路?

Why does numpy not short-circuit on non-contiguous arrays?

考虑以下简单测试:

import numpy as np
from timeit import timeit

a = np.random.randint(0,2,1000000,bool)

让我们找到第一个True

的索引
timeit(lambda:a.argmax(), number=1000)
# 0.000451055821031332

这相当快,因为​​ numpy 短路。

它也适用于连续的切片,

timeit(lambda:a[1:-1].argmax(), number=1000)
# 0.0006490410305559635

但似乎不是在非连续的。我主要是想找到最后一个 True:

timeit(lambda:a[::-1].argmax(), number=1000)
# 0.3737605109345168

UPDATE: My assumption that the observed slowdown was due to not short circuiting is inaccurate (thanks @Victor Ruiz). Indeed, in the worst-case scenario of an all False array

b=np.zeros_like(a)
timeit(lambda:b.argmax(), number=1000)
# 0.04321779008023441

we are still an order of magnitude faster than in the non-contiguous case. I'm ready to accept Victor's explanation that the actual culprit is a copy being made (timings of forcing a copy with .copy() are suggestive). Afterwards it doesn't really matter anymore whether short-circuiting happens or not.

但其他步长 != 1 会产生类似的行为。

timeit(lambda:a[::2].argmax(), number=1000)
# 0.19192566303536296

问题:为什么在最后两个例子中numpy不短路UPDATE没有复制

而且,更重要的是:是否有解决方法,即某种方法可以强制 numpy 短路 UPDATE 而无需制作副本 也在非连续数组上?

问题与使用步长时数组的内存对齐有关。 a[1:-1]a[::-1] 被认为在内存中对齐,但 a[::2] 不要:

a = np.random.randint(0,2,1000000,bool)

print(a[1:-1].flags.c_contiguous) # True
print(a[::-1].flags.c_contiguous) # False
print(a[::2].flags.c_contiguous) # False

这解释了为什么 np.argmaxa[::2] 上很慢(来自 ndarrays 上的文档):

Several algorithms in NumPy work on arbitrarily strided arrays. However, some algorithms require single-segment arrays. When an irregularly strided array is passed in to such algorithms, a copy is automatically made.

np.argmax(a[::2]) 正在复制数组。因此,如果您执行 timeit(lambda: np.argmax(a[::2]), number=5000),您将对数组 a

的 5000 个副本进行计时

执行此操作并比较这两个计时调用的结果:

print(timeit(lambda: np.argmax(a[::2]), number=5000))

b = a[::2].copy()
print(timeit(lambda: np.argmax(b), number=5000))

编辑: 深入研究 numpy C 中的源代码,我发现 argmax 函数的下划线实现,PyArray_ArgMax which calls at some point to PyArray_ContiguousFromAny 以确保给定的输入数组在内存中对齐(C 风格)

然后,如果数组的 dtype 是 bool,它委托给 BOOL_argmax 函数。 查看它的代码,似乎短路总是应用。

总结

  • 为了避免被np.argmax复制,确保输入数组在内存中是连续的
  • 当数据类型为布尔值时,始终应用短路。

我对解决这个问题很感兴趣。因此,我提出了下一个解决方案,该解决方案设法避免了由于 np.argmax:

的内部 ndarray 副本而导致的“a[::-1]”问题

我创建了一个实现函数 argmax 的小型库,该函数是 np.argmax 的包装器,但当输入参数是一个 stride 值设置为 - 的一维布尔数组时,它提高了性能1:

https://github.com/Vykstorm/numpy-bool-argmax-ext

对于这些情况,它使用低级 C 例程从数组的末尾到开头查找具有最大值 (True) 的项的索引 k a
然后你可以用 len(a)-k-1

计算 argmax(a[::-1])

低级方法不执行任何内部 ndarray 副本,因为它使用的数组 a 已经是 C 连续的并且在内存中对齐。它还适用短路


编辑: 我扩展了库以提高性能 argmax 在处理不同于 -1 的步幅值(使用一维布尔数组)时也取得了良好的效果:a[::2]a[::-3]、e.t.c .

试一试。