有效地计算唯一元素的数量 - NumPy / Python

Efficiently counting number of unique elements - NumPy / Python

当运行 np.unique()时,它首先展平数组,对数组进行排序,然后找到唯一值。当我有形状为 (10, 3000, 3000) 的数组时,找到唯一元素大约需要一秒钟,但这很快就会加起来,因为我需要多次调用 np.unique() 。由于我只关心数组中唯一数字的总数,因此排序似乎是在浪费时间。

除了 np.unique() 之外,是否有更快的方法来查找大型数组中唯一值的总数?

这是一个适用于数据类型为 np.uint8 的数组的方法,它比 np.unique.

更快

首先,创建一个要使用的数组:

In [128]: a = np.random.randint(1, 128, size=(10, 3000, 3000)).astype(np.uint8)

为了稍后比较,使用 np.unique:

查找唯一值
In [129]: u = np.unique(a)

这是更快的方法; v 将包含结果:

In [130]: q = np.zeros(256, dtype=int)

In [131]: q[a.ravel()] = 1

In [132]: v = np.nonzero(q)[0]

验证我们得到了相同的结果:

In [133]: np.array_equal(u, v)
Out[133]: True

时间:

In [134]: %timeit u = np.unique(a)
2.86 s ± 9.02 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [135]: %timeit q = np.zeros(256, dtype=int); q[a.ravel()] = 1; v = np.nonzero(q)
300 ms ± 5.52 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

因此 np.unique() 为 2.86 秒,替代方法为 0.3 秒。

我们可以利用元素被限制在 uint8 范围内的事实,通过使用 np.bincount 进行分箱计数,然后简单地计算其中非零的数量。因为,np.bincount 需要一个 1D 数组,我们将用 np.ravel() 展平输入,然后将其提供给 bincount.

因此,实施将是 -

(np.bincount(a.ravel())!=0).sum()

运行时测试

创建具有不同数量唯一数字的输入数组的辅助函数 -

def create_input(n_unique):
    unq_nums = np.random.choice(np.arange(256), n_unique,replace=0)
    return np.random.choice(unq_nums, (10,3000,3000)).astype(np.uint8)

其他方法:

# @Warren Weckesser's soln
def assign_method(a):
    q = np.zeros(256, dtype=int)
    q[a.ravel()] = 1
    return len(np.nonzero(q)[0])

提议方法的验证-

In [141]: a = create_input(n_unique=120)

In [142]: len(np.unique(a))
Out[142]: 120

In [143]: (np.bincount(a.ravel())!=0).sum()
Out[143]: 120

计时 -

In [124]: a = create_input(n_unique=128)

In [125]: %timeit len(np.unique(a)) # Original soln
     ...: %timeit assign_method(a)  # @Warren Weckesser's soln
     ...: %timeit (np.bincount(a.ravel())!=0).sum()
     ...: 
1 loop, best of 3: 3.09 s per loop
1 loop, best of 3: 394 ms per loop
1 loop, best of 3: 209 ms per loop

In [126]: a = create_input(n_unique=256)

In [127]: %timeit len(np.unique(a)) # Original soln
     ...: %timeit assign_method(a)  # @Warren Weckesser's soln
     ...: %timeit (np.bincount(a.ravel())!=0).sum()
     ...: 
1 loop, best of 3: 3.46 s per loop
1 loop, best of 3: 378 ms per loop
1 loop, best of 3: 212 ms per loop

如果您不介意使用 Numba 来 JIT 编译您的代码,并修改您的代码以使 Numba 更容易施展它的魔力,那么有可能对已经提出的建议进行一些很好的改进在其他答案中列出。

使用@Divakar 的post命名:

from numba import jit
import numpy as np

def create_input(n_unique):
    unq_nums = np.random.choice(np.arange(256), n_unique, replace=0)
    return np.random.choice(unq_nums, (10, 3000, 3000)).astype(np.uint8)

def unique(a):
    return len(np.unique(a))

def assign_method(a):
    q = np.zeros(256, dtype=int)
    q[a.ravel()] = 1
    return len(np.nonzero(q)[0])

def bincount(a):
    return (np.bincount(a.ravel())!=0).sum()

def numba_friendly(a):
    q = np.zeros(256, dtype=int)
    count = 0
    for x in a.ravel():
        if q[x] == 0:
            q[x] = 1
            count += 1
    return count

unique_jit = jit(unique)
assign_method_jit = jit(assign_method)
bincount_jit = jit(bincount)
numba_friendly_jit = jit(numba_friendly)

基准:

a = create_input(n_unique=128)

%timeit unique(a)
%timeit unique_jit(a)
%timeit assign_method(a)
%timeit assign_method_jit(a)
%timeit bincount(a)
%timeit bincount_jit(a)
%timeit numba_friendly_jit(a)

结果:

unique:               7.5 s ± 1.14 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
unique_jit:          13.4 s ± 2.03 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
assign_method:       388 ms ± 84.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
assign_method_jit:   341 ms ± 27.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
bincount:            2.71 s ± 218 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
bincount_jit:        138 ms ± 40.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
numba_friendly_jit: 56.4 ms ± 8.96 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)