numpy.astype 与 numba 的非常奇怪的结果

very strange results of numpy.astype with numba

为什么?很奇怪... 在 python 中,如果我们用 numba 测试 np.astype(),下面将打印一些结果为

x:     [-6. -5. -4. -3. -2. -1.  0.  1.  2.  3.  4.  5.]
x-int: [-6 -5 -4 -3 -2 -1  0  1  2  3  4  5]

@numba.njit
def tt():
    nn = 3
    x = np.linspace(0, 4*nn-1, 4*nn)-2*nn
    print(x)
    print(x.astype(np.int32))

但是,如果我将 x 的行更改为 x = np.linspace(0, 8*nn-1, 8*nn)-4*nn,结果会很奇怪,因为

x: [-12. -11. -10.  -9.  -8.  -7.  -6.  -5.  -4.  -3.  -2.  -1.   0.   1.   2.   3.   4.   5.   6.   7.   8.   9.  10.  11.]
x-int: [-12 -11 -10  -9  -8  -7  -6  -5  -4  -3  -2  -1   0   0   2   3   4   5   6   7   8   9  10  11]

x-int中有两个0?为什么?

tl;dr: 这是已报告的 Numba 错误。

此问题是由于与浮点舍入相关的 Numba linspace 函数略有不准确。这是一个突出问题的示例:

def tt_classic():
    nn = 3
    return np.linspace(0, 8*nn-1, 8*nn)-4*nn

@numba.njit
def tt_numba():
    nn = 3
    return np.linspace(0, 8*nn-1, 8*nn)-4*nn

print(tt_classic()[13])
print(tt_numba()[13])

结果如下:

1.0
0.9999999999999982

如您所见,Numba 实现没有 return 精确值。虽然对于大值无法避免此问题,但对于如此小的值,它可以被视为 bug,因为它们可以 精确表示 (没有任何精度损失)在任何 IEEE-754 平台上。

因此,转换会将浮点数 0.9999999999999982 截断为 0(并且 不是最接近的整数 )。如果您想要 安全转换 (即解决方法),您可以明确告诉 Numpy/Numba 这样做。这是一个例子:

@numba.njit
def tt():
    nn = 3
    x = np.linspace(0, 4*nn-1, 4*nn)-2*nn
    np.round(x, 0, x)
    print(x)
    print(x.astype(np.int32))

Numba 错误跟踪器 here 上报告了此错误。 您可能还对 this 相关的 Numba 问题感兴趣。