对 numpy 数组中的元素求平方时的奇怪行为
Weird behavior when squaring elements in numpy array
我有两个形状为 (1, 250000) 的 numpy 数组:
a = [[ 0 254 1 ..., 255 0 1]]
b = [[ 1 0 252 ..., 0 255 255]]
我想创建一个新的 numpy 数组,其元素是数组 a
和 b
中元素平方和的平方根,但我没有得到正确的结果:
>>> c = np.sqrt(np.square(a)+np.square(b))
>>> print c
[[ 1. 2. 4.12310553 ..., 1. 1. 1.41421354]]
我是不是漏掉了一些简单的东西?
对于这个简单的案例,它工作得很好:
In [1]: a = np.array([[ 0, 2, 4, 6, 8]])
In [2]: b = np.array([[ 1, 3, 5, 7, 9]])
In [3]: c = np.sqrt(np.square(a) + np.square(b))
In [4]: print(c)
[[ 1. 3.60555128 6.40312424 9.21954446 12.04159458]]
你肯定做错了。
推测您的数组 a
和 b
是无符号 8 位整数数组——您可以通过检查属性 a.dtype
来检查。当你对它们进行平方时,数据类型被保留,8 位值溢出,这意味着值 "wrap around"(即平方值以 256 为模):
In [7]: a = np.array([[0, 254, 1, 255, 0, 1]], dtype=np.uint8)
In [8]: np.square(a)
Out[8]: array([[0, 4, 1, 1, 0, 1]], dtype=uint8)
In [9]: b = np.array([[1, 0, 252, 0, 255, 255]], dtype=np.uint8)
In [10]: np.square(a) + np.square(b)
Out[10]: array([[ 1, 4, 17, 1, 1, 2]], dtype=uint8)
In [11]: np.sqrt(np.square(a) + np.square(b))
Out[11]:
array([[ 1. , 2. , 4.12310553, 1. , 1. ,
1.41421354]], dtype=float32)
为避免此问题,您可以告诉 np.square
使用浮点数据类型:
In [15]: np.sqrt(np.square(a, dtype=np.float64) + np.square(b, dtype=np.float64))
Out[15]:
array([[ 1. , 254. , 252.00198412, 255. ,
255. , 255.00196078]])
您也可以使用函数 numpy.hypot
,但您可能仍想使用 dtype
参数,否则默认数据类型为 np.float16
:
In [16]: np.hypot(a, b)
Out[16]: array([[ 1., 254., 252., 255., 255., 255.]], dtype=float16)
In [17]: np.hypot(a, b, dtype=np.float64)
Out[17]:
array([[ 1. , 254. , 252.00198412, 255. ,
255. , 255.00196078]])
您可能想知道为什么我在 numpy.square
和 numpy.hypot
中使用的 dtype
参数没有显示在函数的文档字符串中。这两个函数都是numpy "ufuncs", and the authors of numpy decided that it was better to show only the main arguments in the docstring. The optional arguments are documented in the reference manual.
我有两个形状为 (1, 250000) 的 numpy 数组:
a = [[ 0 254 1 ..., 255 0 1]]
b = [[ 1 0 252 ..., 0 255 255]]
我想创建一个新的 numpy 数组,其元素是数组 a
和 b
中元素平方和的平方根,但我没有得到正确的结果:
>>> c = np.sqrt(np.square(a)+np.square(b))
>>> print c
[[ 1. 2. 4.12310553 ..., 1. 1. 1.41421354]]
我是不是漏掉了一些简单的东西?
对于这个简单的案例,它工作得很好:
In [1]: a = np.array([[ 0, 2, 4, 6, 8]])
In [2]: b = np.array([[ 1, 3, 5, 7, 9]])
In [3]: c = np.sqrt(np.square(a) + np.square(b))
In [4]: print(c)
[[ 1. 3.60555128 6.40312424 9.21954446 12.04159458]]
你肯定做错了。
推测您的数组 a
和 b
是无符号 8 位整数数组——您可以通过检查属性 a.dtype
来检查。当你对它们进行平方时,数据类型被保留,8 位值溢出,这意味着值 "wrap around"(即平方值以 256 为模):
In [7]: a = np.array([[0, 254, 1, 255, 0, 1]], dtype=np.uint8)
In [8]: np.square(a)
Out[8]: array([[0, 4, 1, 1, 0, 1]], dtype=uint8)
In [9]: b = np.array([[1, 0, 252, 0, 255, 255]], dtype=np.uint8)
In [10]: np.square(a) + np.square(b)
Out[10]: array([[ 1, 4, 17, 1, 1, 2]], dtype=uint8)
In [11]: np.sqrt(np.square(a) + np.square(b))
Out[11]:
array([[ 1. , 2. , 4.12310553, 1. , 1. ,
1.41421354]], dtype=float32)
为避免此问题,您可以告诉 np.square
使用浮点数据类型:
In [15]: np.sqrt(np.square(a, dtype=np.float64) + np.square(b, dtype=np.float64))
Out[15]:
array([[ 1. , 254. , 252.00198412, 255. ,
255. , 255.00196078]])
您也可以使用函数 numpy.hypot
,但您可能仍想使用 dtype
参数,否则默认数据类型为 np.float16
:
In [16]: np.hypot(a, b)
Out[16]: array([[ 1., 254., 252., 255., 255., 255.]], dtype=float16)
In [17]: np.hypot(a, b, dtype=np.float64)
Out[17]:
array([[ 1. , 254. , 252.00198412, 255. ,
255. , 255.00196078]])
您可能想知道为什么我在 numpy.square
和 numpy.hypot
中使用的 dtype
参数没有显示在函数的文档字符串中。这两个函数都是numpy "ufuncs", and the authors of numpy decided that it was better to show only the main arguments in the docstring. The optional arguments are documented in the reference manual.