为什么 numpy 不会在非连续数组上短路?
Why does numpy not short-circuit on non-contiguous arrays?
考虑以下简单测试:
import numpy as np
from timeit import timeit
a = np.random.randint(0,2,1000000,bool)
让我们找到第一个True
的索引
timeit(lambda:a.argmax(), number=1000)
# 0.000451055821031332
这相当快,因为 numpy
短路。
它也适用于连续的切片,
timeit(lambda:a[1:-1].argmax(), number=1000)
# 0.0006490410305559635
但似乎不是在非连续的。我主要是想找到最后一个 True
:
timeit(lambda:a[::-1].argmax(), number=1000)
# 0.3737605109345168
UPDATE: My assumption that the observed slowdown was due to not short circuiting is inaccurate (thanks @Victor Ruiz). Indeed, in the
worst-case scenario of an all False
array
b=np.zeros_like(a)
timeit(lambda:b.argmax(), number=1000)
# 0.04321779008023441
we are still an order of magnitude faster than in the non-contiguous
case. I'm ready to accept Victor's explanation that the actual culprit
is a copy being made (timings of forcing a copy with .copy()
are
suggestive). Afterwards it doesn't really matter anymore whether
short-circuiting happens or not.
但其他步长 != 1 会产生类似的行为。
timeit(lambda:a[::2].argmax(), number=1000)
# 0.19192566303536296
问题:为什么在最后两个例子中numpy
不短路UPDATE没有复制?
而且,更重要的是:是否有解决方法,即某种方法可以强制 numpy
短路 UPDATE 而无需制作副本 也在非连续数组上?
问题与使用步长时数组的内存对齐有关。
a[1:-1]
、a[::-1]
被认为在内存中对齐,但 a[::2]
不要:
a = np.random.randint(0,2,1000000,bool)
print(a[1:-1].flags.c_contiguous) # True
print(a[::-1].flags.c_contiguous) # False
print(a[::2].flags.c_contiguous) # False
这解释了为什么 np.argmax
在 a[::2]
上很慢(来自 ndarrays 上的文档):
Several algorithms in NumPy work on arbitrarily strided arrays. However, some algorithms require single-segment arrays. When an irregularly strided array is passed in to such algorithms, a copy is automatically made.
np.argmax(a[::2])
正在复制数组。因此,如果您执行 timeit(lambda: np.argmax(a[::2]), number=5000)
,您将对数组 a
的 5000 个副本进行计时
执行此操作并比较这两个计时调用的结果:
print(timeit(lambda: np.argmax(a[::2]), number=5000))
b = a[::2].copy()
print(timeit(lambda: np.argmax(b), number=5000))
编辑:
深入研究 numpy C 中的源代码,我发现 argmax
函数的下划线实现,PyArray_ArgMax which calls at some point to PyArray_ContiguousFromAny 以确保给定的输入数组在内存中对齐(C 风格)
然后,如果数组的 dtype 是 bool,它委托给 BOOL_argmax 函数。
查看它的代码,似乎短路总是应用。
总结
- 为了避免被
np.argmax
复制,确保输入数组在内存中是连续的
- 当数据类型为布尔值时,始终应用短路。
我对解决这个问题很感兴趣。因此,我提出了下一个解决方案,该解决方案设法避免了由于 np.argmax
:
的内部 ndarray 副本而导致的“a[::-1]
”问题
我创建了一个实现函数 argmax
的小型库,该函数是 np.argmax
的包装器,但当输入参数是一个 stride 值设置为 - 的一维布尔数组时,它提高了性能1:
https://github.com/Vykstorm/numpy-bool-argmax-ext
对于这些情况,它使用低级 C 例程从数组的末尾到开头查找具有最大值 (True
) 的项的索引 k
a
。
然后你可以用 len(a)-k-1
计算 argmax(a[::-1])
低级方法不执行任何内部 ndarray 副本,因为它使用的数组 a
已经是 C 连续的并且在内存中对齐。它还适用短路
编辑:
我扩展了库以提高性能 argmax
在处理不同于 -1 的步幅值(使用一维布尔数组)时也取得了良好的效果:a[::2]
、a[::-3]
、e.t.c .
试一试。
考虑以下简单测试:
import numpy as np
from timeit import timeit
a = np.random.randint(0,2,1000000,bool)
让我们找到第一个True
timeit(lambda:a.argmax(), number=1000)
# 0.000451055821031332
这相当快,因为 numpy
短路。
它也适用于连续的切片,
timeit(lambda:a[1:-1].argmax(), number=1000)
# 0.0006490410305559635
但似乎不是在非连续的。我主要是想找到最后一个 True
:
timeit(lambda:a[::-1].argmax(), number=1000)
# 0.3737605109345168
UPDATE: My assumption that the observed slowdown was due to not short circuiting is inaccurate (thanks @Victor Ruiz). Indeed, in the worst-case scenario of an all
False
array
b=np.zeros_like(a)
timeit(lambda:b.argmax(), number=1000)
# 0.04321779008023441
we are still an order of magnitude faster than in the non-contiguous case. I'm ready to accept Victor's explanation that the actual culprit is a copy being made (timings of forcing a copy with
.copy()
are suggestive). Afterwards it doesn't really matter anymore whether short-circuiting happens or not.
但其他步长 != 1 会产生类似的行为。
timeit(lambda:a[::2].argmax(), number=1000)
# 0.19192566303536296
问题:为什么在最后两个例子中numpy
不短路UPDATE没有复制?
而且,更重要的是:是否有解决方法,即某种方法可以强制 numpy
短路 UPDATE 而无需制作副本 也在非连续数组上?
问题与使用步长时数组的内存对齐有关。
a[1:-1]
、a[::-1]
被认为在内存中对齐,但 a[::2]
不要:
a = np.random.randint(0,2,1000000,bool)
print(a[1:-1].flags.c_contiguous) # True
print(a[::-1].flags.c_contiguous) # False
print(a[::2].flags.c_contiguous) # False
这解释了为什么 np.argmax
在 a[::2]
上很慢(来自 ndarrays 上的文档):
Several algorithms in NumPy work on arbitrarily strided arrays. However, some algorithms require single-segment arrays. When an irregularly strided array is passed in to such algorithms, a copy is automatically made.
np.argmax(a[::2])
正在复制数组。因此,如果您执行 timeit(lambda: np.argmax(a[::2]), number=5000)
,您将对数组 a
执行此操作并比较这两个计时调用的结果:
print(timeit(lambda: np.argmax(a[::2]), number=5000))
b = a[::2].copy()
print(timeit(lambda: np.argmax(b), number=5000))
编辑:
深入研究 numpy C 中的源代码,我发现 argmax
函数的下划线实现,PyArray_ArgMax which calls at some point to PyArray_ContiguousFromAny 以确保给定的输入数组在内存中对齐(C 风格)
然后,如果数组的 dtype 是 bool,它委托给 BOOL_argmax 函数。 查看它的代码,似乎短路总是应用。
总结
- 为了避免被
np.argmax
复制,确保输入数组在内存中是连续的 - 当数据类型为布尔值时,始终应用短路。
我对解决这个问题很感兴趣。因此,我提出了下一个解决方案,该解决方案设法避免了由于 np.argmax
:
a[::-1]
”问题
我创建了一个实现函数 argmax
的小型库,该函数是 np.argmax
的包装器,但当输入参数是一个 stride 值设置为 - 的一维布尔数组时,它提高了性能1:
https://github.com/Vykstorm/numpy-bool-argmax-ext
对于这些情况,它使用低级 C 例程从数组的末尾到开头查找具有最大值 (True
) 的项的索引 k
a
。
然后你可以用 len(a)-k-1
argmax(a[::-1])
低级方法不执行任何内部 ndarray 副本,因为它使用的数组 a
已经是 C 连续的并且在内存中对齐。它还适用短路
编辑:
我扩展了库以提高性能 argmax
在处理不同于 -1 的步幅值(使用一维布尔数组)时也取得了良好的效果:a[::2]
、a[::-3]
、e.t.c .
试一试。