在 Numpy 数组中找到所有成对接近数字的最快方法

Fastest way to find all pairs of close numbers in a Numpy array

假设我有一个 N = 10 随机浮点数的 Numpy 数组:

import numpy as np
np.random.seed(99)
N = 10
arr = np.random.uniform(0., 10., size=(N,))
print(arr)

out[1]: [6.72278559 4.88078399 8.25495174 0.31446388 8.08049963 
         5.6561742 2.97622499 0.46695721 9.90627399 0.06825733]

我想找到所有 unique 对数字,彼此之间的差异不超过公差 tol = 1.(即绝对差异 <= 1)。具体来说,我想获得所有唯一的索引对。每个接近对的索引都应该排序,所有接近对都应该按第一个索引排序。我设法编写了以下工作代码:

def all_close_pairs(arr, tol=1.):
    res = set()
    for i, x1 in enumerate(arr):
        for j, x2 in enumerate(arr):
            if i == j:
                continue
            if np.isclose(x1, x2, rtol=0., atol=tol):
                res.add(tuple(sorted([i, j])))
    res = np.array(list(res))
    return res[res[:,0].argsort()]

print(all_close_pairs(arr, tol=1.))

out[2]: [[1 5]
         [2 4]
         [3 7]
         [3 9]
         [7 9]]

然而,实际上我有一个 N = 1000 数字数组,由于嵌套的 for 循环,我的代码变得非常慢。我相信使用 Numpy 向量化有更有效的方法来做到这一点。有谁知道在 Numpy 中最快的方法吗?

问题是您的代码具有 O(n*n)(二次)复杂性。 为了降低复杂性,您可以先尝试对项目进行排序:

def all_close_pairs(arr, tol=1.):
    res = set()
    arr = sorted(enumerate(arr), key=lambda x: x[1])
    for (idx1, (i, x1)) in enumerate(arr):
        for idx2 in range(idx1-1, -1, -1):
            j, x2 = arr[idx2]
            if not np.isclose(x1, x2, rtol=0., atol=tol):
                break
            indices = sorted([i, j])
            res.add(tuple(indices))
    return np.array(sorted(res))

但是,这仅在您的值范围远大于公差时才有效。

您可以使用 2 pointers 策略进一步改进这一点,但总体复杂性将保持不变。

这是一个纯numpy操作的解决方案。在我的机器上看起来相当快,但我不知道我们在寻找什么样的速度。

def all_close_pairs(arr, tol=1.):
    N = arr.shape[0]
    # get indices in the array to consider using meshgrid
    pair_coords = np.array(np.meshgrid(np.arange(N), np.arange(N))).T
    # filter out pairs so we get indices in increasing order
    pair_coords = pair_coords[pair_coords[:, :, 0] < pair_coords[:, :, 1]]
    # compare indices in your array for closeness
    is_close = np.isclose(arr[pair_coords[:, 0]], arr[pair_coords[:, 1]], rtol=0, atol=tol)
    return pair_coords[is_close, :]

一个有效的解决方案是首先使用index = np.argsort()对输入值进行排序。然后,您可以使用 arr[index] 生成排序数组,然后在 准线性时间 中迭代接近值,如果对的数量很少 [=40] =]连续数组。如果对的数量很大,那么复杂度是 quadratic 由于生成的对的二次数。由此产生的复杂度为:O(n log n + m) 其中 n 是输入数组的大小,m 是生成的对数。

要找到彼此接近的值,一种有效的方法是使用 Numba 迭代值。事实上,虽然在 Numpy 中这可能是可能的,但由于要比较的值的数量可变,它可能效率不高。这是一个实现:

import numba as nb

@nb.njit('int32[:,::1](float64[::1], float64)')
def findCloseValues(arr, tol):
    res = []
    for i in range(arr.size):
        val = arr[i]
        # Iterate over the close numbers (only once)
        for j in range(i+1, arr.size):
            # Sadly neither np.isclose or np.abs are implemented in Numba so far
            if max(val, arr[j]) - min(val, arr[j]) >= tol:
                break
            res.append((i, j))
    if len(res) == 0: # No pairs: we need to help Numpy to know the shape
        return np.empty((0, 2), dtype=np.int32)
    return np.array(res, dtype=np.int32)

最后,需要更新索引以引用未排序数组中的索引,而不是已排序数组中的索引。您可以使用 index[result].

这是结果代码:

index = arr.argsort()
result = findCloseValues(arr[index], 1.0)
print(index[result])

这是结果(顺序与问题中的顺序不同,但如果需要可以排序):

array([[9, 3],
       [9, 7],
       [3, 7],
       [1, 5],
       [4, 2]])

提高算法的复杂度

如果您需要更快的算法,那么您可以使用另一种输出格式:您可以为每个输入值提供接近目标输入值的 min/max 值范围。要查找范围,您可以对排序后的数组使用二进制搜索(请参阅:np.searchsorted)。生成的算法在 O(n log n) 中运行。但是,您无法获取未排序数组中的索引,因为范围是不连续的。

基准

以下是在我的机器上随机输入 1000 项且容差为 1.0 的性能结果:

Reference implementation:   ~17000 ms             (x 1)
Angelicos' implementation:    1773 ms           (x ~10)
Rivers' implementation:        122 ms           (x 139)
Rchome's implementation:        20 ms           (x 850)
Chris' implementation:           4.57 ms       (x 3720)
This implementation:             0.67 ms      (x 25373)

您可以先使用 itertools.combinations 创建组合:

def all_close_pairs(arr, tolerance):
    pairs = list(combinations(arr, 2))
    indexes = list(combinations(range(len(arr)), 2))
    all_close_pairs_indexes = [indexes[i] for i,pair in enumerate(pairs) if abs(pair[0] - pair[1]) <=  tolerance]
    return all_close_pairs_indexes

现在,对于 N=1000,您只需比较 499500 对而不是 100 万对。

工作原理:

  • 我们首先通过 itertools.combinations 创建对。

  • 然后,我们创建它们的索引列表。

  • 出于速度原因,我们使用列表理解而不是 for 循环。

  • 在这个理解中,我们迭代所有对,使用 enumerate 所以我们可以获得对的索引,我们计算对中数字的绝对差,如果检查如果它小于或等于 tolerance.

  • 如果绝对差小于或等于tolerance,我们通过索引列表获取对数的索引,并将它们添加到我们的最终列表中。

有点晚了,但是一个完全麻木的解决方案:

import numpy as np

def close_enough( arr, tol = 1 ): 
    result = np.where( np.triu(np.isclose( arr[ :, None ], arr[ None, : ], rtol = 0.0, atol = tol ), 1)) 
    return np.swapaxes( result, 0, 1 ) 

展开以解释正在发生的事情

def close_enough( arr, tol = 1 ):
    bool_arr = np.isclose( arr[ :, None ], arr[ None, : ], rtol = 0.0, atol = tol )
    # is_close generates a square array after comparing all elements with all elements.  

    bool_arr = np.triu( bool_arr, 1 ) 
    # Keep the upper right triangle, offset by 1 column. i.e. zero the main diagonal 
    # and all elements below and to the left.

    result = np.where( bool_arr )  # Return the row and column indices for Trues
    return np.swapaxes( result, 0, 1 ) # Return the pairs in rows rather than columns 

N = 1000,arr = 浮点数组

%timeit close_enough( arr, tol = 1 )                                                                              
14.1 ms ± 28.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [19]: %timeit all_close_pairs( arr, tol = 1 )                                                                           
54.3 ms ± 268 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

(close_enough( arr, tol = 1) == all_close_pairs( arr, tol = 1 )).all()                                            
# True