使用 numpy.where 防止越界

Using numpy.where to prevent out of bound

我正在尝试根据索引数组查找数组中的值。该索引数组可以包含可能超出范围的索引。在那种情况下,我想 return 一个特定的值(这里是 0)。

(我可以使用 for 循环,但那太慢了。)

所以我这样做:

data = np.arange(1000).reshape(10, 10, 10)
i = np.arange(9).reshape(3, 3)
i[0, 0] = 10
condition = (i[:, 0] < 10) & (i[:, 1] < 10) & (i[:, 2] < 10)
values = np.where(condition, data[i[:, 0], i[:, 1], i[:, 2]], 0)

但是我仍然遇到越界错误:

IndexError: index 10 is out of bounds for axis 0 with size 10

我猜是因为第二个参数没有延迟求值,而是在函数调用之前求值。

在numpy中是否有解决方案可以根据条件访问数组但仍保留顺序?通过保留顺序,我的意思是我不能先过滤数组,因为我可能会在最终结果中丢失顺序。最后,在那个特定的例子中,当索引超出范围时,我仍然希望值数组包含一个 0。所以最终结果将是:

array([ 0, 345, 678])

首先建立索引,然后应用修正以修正值。

shp = np.array(data.shape)
j = i % shp 
res = data[j.T.tolist()]
res[(i >= shp).nonzero()[0]] = 0

print(res)
array([  0, 345, 678])

索引数组的每一列都存储了每个维度的索引。我们可以生成有效掩码(通过边界)并将其中的无效掩码设置为零。即超出范围的情况将由 [0,0,0] 索引,然后让数组由这个修改后的版本索引,最后再次使用掩码来重置无效的,就像这样 -

shp = data.shape
valid_mask = (i < shp).all(1)
i[~valid_mask] = 0
out = np.where(valid_mask,data[tuple(i.T)],0)

在不改变 i 的情况下修改的紧凑版本将是 -

out = np.where(valid_mask,data[tuple(np.where(valid_mask,i.T,0))],0)