如何索引 numpy 数组给定轴上每次出现的最大值?

How can I index each occurrence of a max value along a given axis of a numpy array?

假设我有以下 numpy 数组。

Q = np.array([[0,1,1],[1,0,1],[0,2,0]) 

问题:如何确定每个最大值沿轴 1 的位置?所以所需的输出将类似于:

array([[1,2],[0,2],[1]]) # The dtype of the output is not required to be a np array.

使用 np.argmax 我可以确定沿轴第一次出现的最大值,但不能确定后续值。

In: np.argmax(Q, axis =1) 
Out: array([1, 0, 1])    

我还看到依赖于使用 np.argwhere 的答案,这些答案使用了这样的术语。

np.argwhere(Q == np.amax(Q)) 

这在这里也不起作用,因为我无法限制 argwhere 沿单个轴工作。我也不能只是将 np 数组展平到一个轴上,因为每行中的最大值会有所不同。我需要识别每行最大值的每个实例。

是否有一种 pythonic 方法可以实现这一点而无需遍历整个数组的每一行,或者是否有类似于 np.argwhere 的函数接受轴参数?

任何见解将不胜感激!

试试 np.where:

np.where(Q == Q.max(axis=1)[:,None])

输出:

(array([0, 0, 1, 1, 2]), array([1, 2, 0, 2, 1]))

不完全是您想要的输出,但包含等效信息。

您还可以使用 np.argwhere 为您提供 zip 数据:

np.argwhere(Q==Q.max(axis=1)[:,None])

输出:

array([[0, 1],
       [0, 2],
       [1, 0],
       [1, 2],
       [2, 1]])