如何防止意外分配到空的 NumPy 视图

How to prevent accidental assignment into empty NumPy views

考虑以下 Python + 无错误执行的 NumPy 代码:

a = np.array((1, 2, 3))

a[13:17] = 23

使用超出数组限制的切片会截断切片,如果开始和停止超出限制,甚至 returns 会截断空视图。分配给这样的切片只会丢弃输入。

在我的用例中,索引以非常重要的方式计算,并用于操作数组的选定部分。上述行为意味着如果索引计算错误,我可能会默默地跳过部分操作。这可能很难检测到,并可能导致“几乎正确”的结果,即最糟糕的编程错误。

出于这个原因,我想对切片进行严格检查,以便在数组边界之外开始或停止会触发错误。有没有办法在 NumPy 中启用它?

作为附加信息,数组很大并且操作执行得非常频繁,即应该没有性能损失。此外,数组通常是多维的,包括多维切片。

您可以改用 np.put_along_axis,这似乎符合您的需要:

>>> a = np.array((1, 2, 3))
>>> np.put_along_axis(a, indices=np.arange(13, 17), axis=0, values=23)

以上将引发以下错误:

IndexError: index 13 is out of bounds for axis 0 with size 3

参数values可以是标量值或另一个 NumPy 数组。

或更短的形式:

>>> np.put_along_axis(a, np.r_[13:17], 23, 0)

编辑: 或者 np.put 有一个 mode='raise' 选项(默认设置):

np.put(a, ind, v, mode='raise')

  • a: ndarray - Target array.

  • ind: array_like - Target indices, interpreted as integers.

  • v: array_like - Values to place in a at target indices. [...]

  • mode: {'raise', 'wrap', 'clip'} optional - Specifies how out-of-bounds indices will behave.

    • 'raise' – raise an error (default)
    • 'wrap' – wrap around
    • 'clip' – clip to the range

默认行为将是:

>>> np.put(a, np.r_[13:17], 23)

IndexError: index 13 is out of bounds for axis 0 with size 3

mode='clip'同时保持沉默:

 >>> np.put(a, np.r_[13:17], 23, mode='clip')

实现您想要的行为的一种方法是使用范围而不是切片:

a = np.array((1, 2, 3))
a[np.arange(13, 17)] = 23

我认为 NumPy 在这里的行为与纯 Python 列表的行为是一致的,应该是预期的。除了解决方法,显式添加断言对于代码可读性可能更好:

index_1, index_2 = ... # a complex computation
assert index_1 < index_2 and index_2 < a.shape[0]
a[index_1:index_2] = 23

根据您的索引的复杂程度(阅读:预测切片后的形状在背面有多大的痛苦),您可能想直接计算预期的形​​状,然后 reshape 到它。如果实际切片数组的大小不匹配,则会引发错误。开销较小:

import numpy as np
from timeit import timeit


def use_reshape(a,idx,val):
    expected_shape = ((s.stop-s.start-1)//(s.step or 1) + 1 if isinstance(s,slice) else 1 for s in idx)
    a[idx].reshape(*expected_shape)[...] = val

def no_check(a,idx,val):
    a[idx] = val
    
val = 23
idx = np.s_[13:1000:2,14:20]
for f in (no_check,use_reshape):
    a = np.zeros((1000,1000))
    print(f.__name__)
    print(timeit(lambda:f(a,idx,val),number=1000),'ms')
    assert (a[idx] == val).all()
    
# check it works
print("\nThis should raise an exception:\n")
use_reshape(a,np.s_[1000:1001,10],0)

请注意,这是概念验证代码。为了确保安全,您必须检查意外的索引类型、匹配的维数,并且重要的是,检查 select 单个元素的索引。

运行 反正它:

no_check
0.004587646995787509 ms
use_reshape
0.006306983006652445 ms

This should raise an exception:

Traceback (most recent call last):
  File "check.py", line 22, in <module>
    use_reshape(a,np.s_[1000:1001,10],0)
  File "check.py", line 7, in use_reshape
    a[idx].reshape(*expected_shape)[...] = val
ValueError: cannot reshape array of size 0 into shape (1,1)