Strange behavior of NumPy `where` clause

I am seeing strange behavior from the ufunc `where` clause in NumPy 1.15.3.

In [1]: import numpy as np

In [2]: x = np.array([[1,2],[3,4]])

In [3]: y = np.ones(x.shape) * 2

In [4]: print(x, "\n", y)
[[1 2]
 [3 4]]
 [[2. 2.]
 [2. 2.]]

In [5]: np.add(x, y, where=x==3)
Out[5]:
array([[2., 2.],     #<=========== where do these 2s come from???
       [5., 2.]])

In [6]: np.add(x, y, where=x==3, out=np.zeros(x.shape))
Out[6]:
array([[0., 0.],
       [5., 0.]])

In [7]: np.add(x, y, where=x==3, out=np.ones(x.shape))
Out[7]:
array([[1., 1.],
       [5., 1.]])

In [8]: np.add(x, y, where=x==3)
Out[8]:
array([[1., 1.], # <========= it seems these 1s are remembered from last computation.
       [5., 1.]])

ADD1

It seems I can only get a sensible result by using the `out` parameter.

Here is the version without the `out` parameter:

import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

x = np.linspace(-2,2,60)
y = np.linspace(-2,2,60)

xx, yy = np.meshgrid(x,y)

r= np.ones((60,60), dtype=float) * 2
z = np.sqrt(r**2 - xx**2 - yy**2, where=(r**2 - xx**2 - yy**2)>=0) # <==== HERE!!

surf = ax.plot_surface(xx, yy, z, cmap="viridis")

This produces a nonsensical plot.

If I add the `out` parameter as follows, everything works as expected:

z = np.zeros(xx.shape)
np.sqrt(r**2 - xx**2 - yy**2, where=(r**2 - xx**2 - yy**2)>=0, out=z)

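Because you are using `where`, you end up with junk data in your output. As you say, the fix is to initialize your own output array and pass it to `out`.

From the docs about the `out` arg: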
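As a minimal sketch of that fix, the initialization can be wrapped in a small helper. The name `add_where` and its `fill` parameter are hypothetical (they are not part of NumPy); the only NumPy calls used are `np.full`, `np.broadcast` and `np.add`.

```python
import numpy as np

def add_where(a, b, mask, fill=np.nan):
    """Hypothetical helper: a + b where mask is True, `fill` everywhere else."""
    # Allocate and initialize the output ourselves, so that positions
    # skipped by `where` hold a known value instead of leftover memory.
    out = np.full(np.broadcast(a, b).shape, fill, dtype=float)
    return np.add(a, b, where=mask, out=out)

x = np.array([[1, 2], [3, 4]])
y = np.ones(x.shape) * 2

print(add_where(x, y, x == 3, fill=0.0))
# [[0. 0.]
#  [5. 0.]]
```

Using `np.nan` as the default fill is just one possible choice; it makes the skipped positions easy to spot, whereas `0.0` reproduces the `np.zeros` variant from the question.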

If ‘out’ is None (the default), a uninitialized return array is created. The output array is then filled with the results of the ufunc in the places that the broadcast ‘where’ is True. If ‘where’ is the scalar True (the default), then this corresponds to the entire output being filled. Note that outputs not explicitly filled are left with their uninitialized values.

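So the entries of `out` that you skip (i.e. the indices where `where` is `False`) simply keep whatever value they held before. This is why it looks like NumPy is "remembering" values from previous computations, such as the 1s at the end of your first example code block.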
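To make the "left untouched" behavior visible, here is a small sketch that passes an `out` array filled with easily recognizable values (10, 20, 30, 40 are arbitrary, chosen only for illustration):

```python
import numpy as np

x = np.array([[1, 2], [3, 4]])
y = np.ones(x.shape) * 2

# Pre-existing contents of the output buffer.
out = np.array([[10., 20.], [30., 40.]])

# Only the position where x == 3 is written; the rest is left alone.
result = np.add(x, y, where=(x == 3), out=out)

print(result)
# [[10. 20.]
#  [ 5. 40.]]
print(result is out)  # True: the ufunc returns the array passed as `out`
```

When `out` is omitted, the buffer plays the same role, except that its prior contents are whatever happened to be in the freshly allocated memory.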

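As @WarrenWeckesser points out in his comment, this also means that the same block of memory is being reused for the output when `out` is left blank, at least in some cases. Interestingly, you can change the results you get by assigning each output to a variable: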

x = np.array([[1,2],[3,4]])
y = np.ones(x.shape) * 2

arr0 = np.add(x, y, where=x==3)
arr1 = np.add(x, y, where=x==3, out=np.zeros(x.shape))
arr2 = np.add(x, y, where=x==3, out=np.ones(x.shape))
arr3 = np.add(x, y, where=x==3)
print(arr3)

Now you can clearly see the junk data in the output:

[[-2.68156159e+154 -2.68156159e+154]
 [ 5.00000000e+000  2.82470645e-309]]
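
Applied to the hemisphere example from the question, the same idea might look like the sketch below. The names `radicand` and `mask` are introduced here for illustration only; computing the expression once and reusing it is a convenience, not a requirement.

```python
import numpy as np

x = np.linspace(-2, 2, 60)
y = np.linspace(-2, 2, 60)
xx, yy = np.meshgrid(x, y)

r = 2.0
radicand = r**2 - xx**2 - yy**2  # negative outside the circle of radius r
mask = radicand >= 0

# Initialize the output explicitly; entries where mask is False stay 0.
z = np.zeros_like(radicand)
np.sqrt(radicand, where=mask, out=z)

print(np.isfinite(z).all())  # True
```

Every entry of `z` is now either a computed square root or the explicit zero it was initialized with, so nothing depends on leftover memory.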