NumPy 数组下三角区域中 n 个最大值的索引

Indices of n largest values in lower triangular region of a NumPy array

我有一个 numpy 余弦相似度矩阵。我想找到 n 个最大值的索引,但排除对角线上的 [​​=12=],并且只针对其中的下三角区域。

similarities = [[ 1.          0.18898224  0.16903085]
 [ 0.18898224  1.          0.67082039]
 [ 0.16903085  0.67082039  1.        ]]

在这种情况下,如果我想要两个最高值,我希望它是 return [1, 0][2, 1].

我已经尝试使用 argpartition 但这不是 return 我正在寻找的

n_select = 1
most_similar = (-similarities).argpartition(n_select, axis=None)[:n_select]

如何得到除对角线1之外的n个最高值,并排除上三角元素?

方法 #1

一种方法 np.tril_indices -

def n_largest_indices_tril(a, n=2):
    m = a.shape[0]
    r,c = np.tril_indices(m,-1)
    idx = a[r,c].argpartition(-n)[-n:]
    return zip(r[idx], c[idx])

样本运行-

In [39]: a
Out[39]: 
array([[ 1.  ,  0.4 ,  0.59,  0.15,  0.29],
       [ 0.4 ,  1.  ,  0.03,  0.57,  0.57],
       [ 0.59,  0.03,  1.  ,  0.9 ,  0.52],
       [ 0.15,  0.57,  0.9 ,  1.  ,  0.37],
       [ 0.29,  0.57,  0.52,  0.37,  1.  ]])

In [40]: n_largest_indices_tril(a, n=2)
Out[40]: [(2, 0), (3, 2)]

In [41]: n_largest_indices_tril(a, n=3)
Out[41]: [(4, 1), (2, 0), (3, 2)]

方法 #2

为了性能,我们可能希望避免生成所有下三角索引,而是使用掩码,为我们提供第二种方法来解决我们的问题,就像这样 -

def n_largest_indices_tril_v2(a, n=2):
    m = a.shape[0]
    r = np.arange(m)
    mask = r[:,None] > r
    idx = a[mask].argpartition(-n)[-n:]

    clens = np.arange(m).cumsum()    
    grp_start = clens[:-1]
    grp_stop = clens[1:]-1    

    rows = np.searchsorted(grp_stop, idx)+1    
    cols  = idx - grp_start[rows-1]
    return zip(rows, cols)

运行时测试

In [143]: # Setup symmetric array 
     ...: N = 1000
     ...: a = np.random.rand(N,N)*0.9
     ...: np.fill_diagonal(a,1)
     ...: m = a.shape[0]
     ...: r,c = np.tril_indices(m,-1)
     ...: a[r,c] = a[c,r]

In [144]: %timeit n_largest_indices_tril(a, n=2)
100 loops, best of 3: 12.5 ms per loop

In [145]: %timeit n_largest_indices_tril_v2(a, n=2)
100 loops, best of 3: 7.85 ms per loop

对于 n 最小的指数

要获得 n 最小的方法,只需使用 ndarray.argpartition(n)[:n] 代替这两种方法。

请记住,方阵的对角线元素具有唯一性 属性:i+j=n,其中 n 是矩阵维数。 然后,您可以只找到数组的 n + 个(对角线元素)最大元素,然后遍历它们并排除元组 (i,j),其中 i+j=n。 希望对您有所帮助!