NumPy - 按频率对大型数组进行快速稳定的 arg 排序

NumPy - fast stable arg-sort of large array by frequency

我有任何可比较的 dtype 的大一维 NumPy 数组 a,它的一些元素可能会重复。

我如何找到排序索引 ix 将稳定排序( 中的稳定性)a 按 descending/ascending 订单中值的频率?

我想找到最快最简单的方法来做到这一点。也许现有的标准 numpy 函数可以做到这一点。

还有另一个相关的 但它专门要求删除数组重复项,即仅输出唯一的排序值,我需要原始数组的所有值,包括重复项。

我已经编写了我的第一个试验来完成任务,但它不是最快的(使用 Python 循环)并且可能不是 shortest/simplest 可能的形式。如果相等元素的重复次数不多且数组很大,此 python 循环可能会非常昂贵。如果在 NumPy 中可用(例如虚构的 np.argsort_by_freq())。

Try it online!

import numpy as np
np.random.seed(1)
hi, n, desc = 7, 24, True
a = np.random.choice(np.arange(hi), (n,), p = (
    lambda p = np.random.random((hi,)): p / p.sum()
)())
us, cs = np.unique(a, return_counts = True)
af = np.zeros(n, dtype = np.int64)
for u, c in zip(us, cs):
    af[a == u] = c
if desc:
    ix = np.argsort(-af, kind = 'stable') # Descending sort
else:
    ix = np.argsort(af, kind = 'stable') # Ascending sort
print('rows: i_col(0) / original_a(1) / freqs(2) / sorted_a(3)')
print('    / sorted_freqs(4) / sorting_ix(5)')
print(np.stack((
    np.arange(n), a, af, a[ix], af[ix], ix,
), 0))

输出:

rows: i_col(0) / original_a(1) / freqs(2) / sorted_a(3)
    / sorted_freqs(4) / sorting_ix(5)
[[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23]
 [ 1  1  1  1  3  0  5  0  3  1  1  0  0  4  6  1  3  5  5  0  0  0  5  0]
 [ 7  7  7  7  3  8  4  8  3  7  7  8  8  1  1  7  3  4  4  8  8  8  4  8]
 [ 0  0  0  0  0  0  0  0  1  1  1  1  1  1  1  5  5  5  5  3  3  3  4  6]
 [ 8  8  8  8  8  8  8  8  7  7  7  7  7  7  7  4  4  4  4  3  3  3  1  1]
 [ 5  7 11 12 19 20 21 23  0  1  2  3  9 10 15  6 17 18 22  4  8 16 13 14]]

我可能遗漏了一些东西,但似乎使用 Counter 你可以根据元素值的计数对每个元素的索引进行排序,使用元素值然后使用索引来打破关系。例如:

from collections import Counter

a = [ 1,  1,  1,  1,  3,  0,  5,  0,  3,  1,  1,  0,  0,  4,  6,  1,  3,  5,  5,  0,  0,  0,  5,  0]
counts = Counter(a)

t = [(counts[v], v, i) for i, v in enumerate(a)]
t.sort()
print([v[2] for v in t])
t.sort(reverse=True)
print([v[2] for v in t])

输出:

[13, 14, 4, 8, 16, 6, 17, 18, 22, 0, 1, 2, 3, 9, 10, 15, 5, 7, 11, 12, 19, 20, 21, 23]
[23, 21, 20, 19, 12, 11, 7, 5, 15, 10, 9, 3, 2, 1, 0, 22, 18, 17, 6, 16, 8, 4, 14, 13]

如果你想保持索引的升序与计数相等的组,你可以只使用 lambda 函数进行降序排序:

t.sort(key = lambda x:(-x[0],-x[1],x[2]))
print([v[2] for v in t])

输出:

[5, 7, 11, 12, 19, 20, 21, 23, 0, 1, 2, 3, 9, 10, 15, 6, 17, 18, 22, 4, 8, 16, 14, 13]

如果您想按照元素最初出现在数组中的顺序维护元素的顺序如果它们的计数相同,那么不要对值进行排序,而是对元素进行排序在它们在数组中第一次出现的索引上:

a = [ 1,  1,  1,  1,  3,  0,  5,  0,  3,  1,  1,  0,  0,  4,  6,  1,  3,  5,  5,  0,  0,  0,  5,  0]
counts = Counter(a)

idxs = {}
t = []
for i, v in enumerate(a):
    if not v in idxs:
        idxs[v] = i
    t.append((counts[v], idxs[v], i))

t.sort()
print([v[2] for v in t])
t.sort(key = lambda x:(-x[0],x[1],x[2]))
print([v[2] for v in t])

输出:

[13, 14, 4, 8, 16, 6, 17, 18, 22, 0, 1, 2, 3, 9, 10, 15, 5, 7, 11, 12, 19, 20, 21, 23]
[5, 7, 11, 12, 19, 20, 21, 23, 0, 1, 2, 3, 9, 10, 15, 6, 17, 18, 22, 4, 8, 16, 13, 14]

按照count排序,然后在数组中定位,根本不需要value和第一个索引:

from collections import Counter

a = [ 1,  1,  1,  1,  3,  0,  5,  0,  3,  1,  1,  0,  0,  4,  6,  1,  3,  5,  5,  0,  0,  0,  5,  0]
counts = Counter(a)

t = [(counts[v], i) for i, v in enumerate(a)]
t.sort()
print([v[1] for v in t])
t.sort(key = lambda x:(-x[0],x[1]))
print([v[1] for v in t])

这会为您的字符串数组生成与示例数据的先前代码相同的输出:

a = ['g',  'g',  'c',  'f',  'd',  'd',  'g',  'a',  'a',  'a',  'f',  'f',  'f',
     'g',  'f',  'c',  'f',  'a',  'e',  'b',  'g',  'd',  'c',  'b',  'f' ]

这会产生输出:

[18, 19, 23, 2, 4, 5, 15, 21, 22, 7, 8, 9, 17, 0, 1, 6, 13, 20, 3, 10, 11, 12, 14, 16, 24]
[3, 10, 11, 12, 14, 16, 24, 0, 1, 6, 13, 20, 7, 8, 9, 17, 2, 4, 5, 15, 21, 22, 19, 23, 18]

我只是认为自己可能非常快速地解决任何 dtype,只使用 numpy 函数而不使用 python 循环,它在 O(N log N) 时间内工作。使用的 numpy 函数:np.uniquenp.argsort 和数组索引。

虽然在原始问题中没有被问到,但我实现了额外的标志 equal_order_by_val 如果它是 False 那么具有相同频率的数组元素被排序为相等的稳定范围,这意味着可能 c d d c d c输出类似于下面的输出转储,因为这是元素进入原始数组的相同频率的顺序。当 flag 为 True 时,这些元素还按原始数组的值排序,结果为 c c c d d d。换句话说,在 False 的情况下,我们仅按键 freq 进行稳定排序,而当它为 True 时,我们按 (freq, value) 进行升序排序,按 (-freq, value) 进行降序排序。

Try it online!

import string, math
import numpy as np
np.random.seed(0)

# Generating input data

hi, n, desc = 7, 25, True
letters = np.array(list(string.ascii_letters), dtype = np.object_)[:hi]
a = np.random.choice(letters, (n,), p = (
    lambda p = np.random.random((letters.size,)): p / p.sum()
)())

for equal_order_by_val in [False, True]:
    # Solving task

    us, ui, cs = np.unique(a, return_inverse = True, return_counts = True)
    af = cs[ui]
    sort_key = -af if desc else af
    if equal_order_by_val:
        shift_bits = max(1, math.ceil(math.log(us.size) / math.log(2)))
        sort_key = ((sort_key.astype(np.int64) << shift_bits) +
            np.arange(us.size, dtype = np.int64)[ui])
    ix = np.argsort(sort_key, kind = 'stable') # Do sorting itself

    # Printing results

    print('\nequal_order_by_val:', equal_order_by_val)
    for name, val in [
        ('i_col', np.arange(n)),  ('original_a', a),
        ('freqs', af),            ('sorted_a', a[ix]),
        ('sorted_freqs', af[ix]), ('sorting_ix', ix),
    ]:
        print(name.rjust(12), ' '.join([str(e).rjust(2) for e in val]))

输出:

equal_order_by_val: False
       i_col  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
  original_a  g  g  c  f  d  d  g  a  a  a  f  f  f  g  f  c  f  a  e  b  g  d  c  b  f
       freqs  5  5  3  7  3  3  5  4  4  4  7  7  7  5  7  3  7  4  1  2  5  3  3  2  7
    sorted_a  f  f  f  f  f  f  f  g  g  g  g  g  a  a  a  a  c  d  d  c  d  c  b  b  e
sorted_freqs  7  7  7  7  7  7  7  5  5  5  5  5  4  4  4  4  3  3  3  3  3  3  2  2  1
  sorting_ix  3 10 11 12 14 16 24  0  1  6 13 20  7  8  9 17  2  4  5 15 21 22 19 23 18

equal_order_by_val: True
       i_col  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
  original_a  g  g  c  f  d  d  g  a  a  a  f  f  f  g  f  c  f  a  e  b  g  d  c  b  f
       freqs  5  5  3  7  3  3  5  4  4  4  7  7  7  5  7  3  7  4  1  2  5  3  3  2  7
    sorted_a  f  f  f  f  f  f  f  g  g  g  g  g  a  a  a  a  c  c  c  d  d  d  b  b  e
sorted_freqs  7  7  7  7  7  7  7  5  5  5  5  5  4  4  4  4  3  3  3  3  3  3  2  2  1
  sorting_ix  3 10 11 12 14 16 24  0  1  6 13 20  7  8  9 17  2 15 22  4  5 21 19 23 18