numpy.astype 与 numba 的非常奇怪的结果
very strange results of numpy.astype with numba
为什么?很奇怪...
在 python 中,如果我们用 numba
测试 np.astype(),下面将打印一些结果为
x: [-6. -5. -4. -3. -2. -1. 0. 1. 2. 3. 4. 5.]
x-int: [-6 -5 -4 -3 -2 -1 0 1 2 3 4 5]
@numba.njit
def tt():
nn = 3
x = np.linspace(0, 4*nn-1, 4*nn)-2*nn
print(x)
print(x.astype(np.int32))
但是,如果我将 x 的行更改为 x = np.linspace(0, 8*nn-1, 8*nn)-4*nn
,结果会很奇怪,因为
x: [-12. -11. -10. -9. -8. -7. -6. -5. -4. -3. -2. -1. 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11.]
x-int: [-12 -11 -10 -9 -8 -7 -6 -5 -4 -3 -2 -1 0 0 2 3 4 5 6 7 8 9 10 11]
x-int中有两个0
?为什么?
tl;dr: 这是已报告的 Numba 错误。
此问题是由于与浮点舍入相关的 Numba linspace
函数略有不准确。这是一个突出问题的示例:
def tt_classic():
nn = 3
return np.linspace(0, 8*nn-1, 8*nn)-4*nn
@numba.njit
def tt_numba():
nn = 3
return np.linspace(0, 8*nn-1, 8*nn)-4*nn
print(tt_classic()[13])
print(tt_numba()[13])
结果如下:
1.0
0.9999999999999982
如您所见,Numba 实现没有 return 精确值。虽然对于大值无法避免此问题,但对于如此小的值,它可以被视为 bug,因为它们可以 精确表示 (没有任何精度损失)在任何 IEEE-754 平台上。
因此,转换会将浮点数 0.9999999999999982 截断为 0(并且 不是最接近的整数 )。如果您想要 安全转换 (即解决方法),您可以明确告诉 Numpy/Numba 这样做。这是一个例子:
@numba.njit
def tt():
nn = 3
x = np.linspace(0, 4*nn-1, 4*nn)-2*nn
np.round(x, 0, x)
print(x)
print(x.astype(np.int32))
为什么?很奇怪...
在 python 中,如果我们用 numba
测试 np.astype(),下面将打印一些结果为
x: [-6. -5. -4. -3. -2. -1. 0. 1. 2. 3. 4. 5.]
x-int: [-6 -5 -4 -3 -2 -1 0 1 2 3 4 5]
@numba.njit
def tt():
nn = 3
x = np.linspace(0, 4*nn-1, 4*nn)-2*nn
print(x)
print(x.astype(np.int32))
但是,如果我将 x 的行更改为 x = np.linspace(0, 8*nn-1, 8*nn)-4*nn
,结果会很奇怪,因为
x: [-12. -11. -10. -9. -8. -7. -6. -5. -4. -3. -2. -1. 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11.]
x-int: [-12 -11 -10 -9 -8 -7 -6 -5 -4 -3 -2 -1 0 0 2 3 4 5 6 7 8 9 10 11]
x-int中有两个0
?为什么?
tl;dr: 这是已报告的 Numba 错误。
此问题是由于与浮点舍入相关的 Numba linspace
函数略有不准确。这是一个突出问题的示例:
def tt_classic():
nn = 3
return np.linspace(0, 8*nn-1, 8*nn)-4*nn
@numba.njit
def tt_numba():
nn = 3
return np.linspace(0, 8*nn-1, 8*nn)-4*nn
print(tt_classic()[13])
print(tt_numba()[13])
结果如下:
1.0
0.9999999999999982
如您所见,Numba 实现没有 return 精确值。虽然对于大值无法避免此问题,但对于如此小的值,它可以被视为 bug,因为它们可以 精确表示 (没有任何精度损失)在任何 IEEE-754 平台上。
因此,转换会将浮点数 0.9999999999999982 截断为 0(并且 不是最接近的整数 )。如果您想要 安全转换 (即解决方法),您可以明确告诉 Numpy/Numba 这样做。这是一个例子:
@numba.njit
def tt():
nn = 3
x = np.linspace(0, 4*nn-1, 4*nn)-2*nn
np.round(x, 0, x)
print(x)
print(x.astype(np.int32))