如何检查numpy数组是否相等

How to check numpy arrays are equal

我正在用 numpy 做一些练习,特别是广播,但我卡住了..
有人可以解释一下应该如何使用 assert 吗?

def fill_0(n):
    return np.zeros(n) -1

def fill_1(n):
    return np.zeros(n) *(-1)

def fill_2(n):
    return - np.ones(n)

def fill_3(n):
    return - np.ones(n) -2

assert fill_0(4) == fill_1(4) == fill_2(4) == fill_3(4)

我会这样做:

np.testing.assert_array_equal(fill_0(4), fill_1(4))
np.testing.assert_array_equal(fill_0(4), fill_2(4))
np.testing.assert_array_equal(fill_0(4), fill_3(4))

这使得失败的地方更加清晰(因为每一对都是单独的一行),即使数据中有 NaN 也能正常工作,而常规相等比较会失败(因为 NaN==NaN是假的)。