如何使用 numpy.where?

How to work with numpy.where?

我想找到数组的索引,如 x = np.array([[1, 1, 1], [2, 2, 2]]),其中元素等于 y = np.array([1, 1, 1])。所以我这样做了:

In: np.where(x == y)
Out: (array([0, 0, 0]), array([0, 1, 2]))

正确答案。但我希望只得到索引 0,因为 x 的零元素等于 y

你需要先用(x == y).all(axis=1)减少axis=1的比较结果,即所有元素都相等:

np.where((x == y).all(axis=1))[0]
# array([0])