在 Numpy 数组中找到所有成对接近数字的最快方法
Fastest way to find all pairs of close numbers in a Numpy array
假设我有一个 N = 10
随机浮点数的 Numpy 数组:
import numpy as np
np.random.seed(99)
N = 10
arr = np.random.uniform(0., 10., size=(N,))
print(arr)
out[1]: [6.72278559 4.88078399 8.25495174 0.31446388 8.08049963
5.6561742 2.97622499 0.46695721 9.90627399 0.06825733]
我想找到所有 unique 对数字,彼此之间的差异不超过公差 tol = 1.
(即绝对差异 <= 1)。具体来说,我想获得所有唯一的索引对。每个接近对的索引都应该排序,所有接近对都应该按第一个索引排序。我设法编写了以下工作代码:
def all_close_pairs(arr, tol=1.):
res = set()
for i, x1 in enumerate(arr):
for j, x2 in enumerate(arr):
if i == j:
continue
if np.isclose(x1, x2, rtol=0., atol=tol):
res.add(tuple(sorted([i, j])))
res = np.array(list(res))
return res[res[:,0].argsort()]
print(all_close_pairs(arr, tol=1.))
out[2]: [[1 5]
[2 4]
[3 7]
[3 9]
[7 9]]
然而,实际上我有一个 N = 1000
数字数组,由于嵌套的 for 循环,我的代码变得非常慢。我相信使用 Numpy 向量化有更有效的方法来做到这一点。有谁知道在 Numpy 中最快的方法吗?
问题是您的代码具有 O(n*n)(二次)复杂性。
为了降低复杂性,您可以先尝试对项目进行排序:
def all_close_pairs(arr, tol=1.):
res = set()
arr = sorted(enumerate(arr), key=lambda x: x[1])
for (idx1, (i, x1)) in enumerate(arr):
for idx2 in range(idx1-1, -1, -1):
j, x2 = arr[idx2]
if not np.isclose(x1, x2, rtol=0., atol=tol):
break
indices = sorted([i, j])
res.add(tuple(indices))
return np.array(sorted(res))
但是,这仅在您的值范围远大于公差时才有效。
您可以使用 2 pointers
策略进一步改进这一点,但总体复杂性将保持不变。
这是一个纯numpy操作的解决方案。在我的机器上看起来相当快,但我不知道我们在寻找什么样的速度。
def all_close_pairs(arr, tol=1.):
N = arr.shape[0]
# get indices in the array to consider using meshgrid
pair_coords = np.array(np.meshgrid(np.arange(N), np.arange(N))).T
# filter out pairs so we get indices in increasing order
pair_coords = pair_coords[pair_coords[:, :, 0] < pair_coords[:, :, 1]]
# compare indices in your array for closeness
is_close = np.isclose(arr[pair_coords[:, 0]], arr[pair_coords[:, 1]], rtol=0, atol=tol)
return pair_coords[is_close, :]
一个有效的解决方案是首先使用index = np.argsort()
对输入值进行排序。然后,您可以使用 arr[index]
生成排序数组,然后在 准线性时间 中迭代接近值,如果对的数量很少 [=40] =]连续数组。如果对的数量很大,那么复杂度是 quadratic 由于生成的对的二次数。由此产生的复杂度为:O(n log n + m)
其中 n
是输入数组的大小,m
是生成的对数。
要找到彼此接近的值,一种有效的方法是使用 Numba 迭代值。事实上,虽然在 Numpy 中这可能是可能的,但由于要比较的值的数量可变,它可能效率不高。这是一个实现:
import numba as nb
@nb.njit('int32[:,::1](float64[::1], float64)')
def findCloseValues(arr, tol):
res = []
for i in range(arr.size):
val = arr[i]
# Iterate over the close numbers (only once)
for j in range(i+1, arr.size):
# Sadly neither np.isclose or np.abs are implemented in Numba so far
if max(val, arr[j]) - min(val, arr[j]) >= tol:
break
res.append((i, j))
if len(res) == 0: # No pairs: we need to help Numpy to know the shape
return np.empty((0, 2), dtype=np.int32)
return np.array(res, dtype=np.int32)
最后,需要更新索引以引用未排序数组中的索引,而不是已排序数组中的索引。您可以使用 index[result]
.
这是结果代码:
index = arr.argsort()
result = findCloseValues(arr[index], 1.0)
print(index[result])
这是结果(顺序与问题中的顺序不同,但如果需要可以排序):
array([[9, 3],
[9, 7],
[3, 7],
[1, 5],
[4, 2]])
提高算法的复杂度
如果您需要更快的算法,那么您可以使用另一种输出格式:您可以为每个输入值提供接近目标输入值的 min/max 值范围。要查找范围,您可以对排序后的数组使用二进制搜索(请参阅:np.searchsorted
)。生成的算法在 O(n log n)
中运行。但是,您无法获取未排序数组中的索引,因为范围是不连续的。
基准
以下是在我的机器上随机输入 1000 项且容差为 1.0 的性能结果:
Reference implementation: ~17000 ms (x 1)
Angelicos' implementation: 1773 ms (x ~10)
Rivers' implementation: 122 ms (x 139)
Rchome's implementation: 20 ms (x 850)
Chris' implementation: 4.57 ms (x 3720)
This implementation: 0.67 ms (x 25373)
您可以先使用 itertools.combinations 创建组合:
def all_close_pairs(arr, tolerance):
pairs = list(combinations(arr, 2))
indexes = list(combinations(range(len(arr)), 2))
all_close_pairs_indexes = [indexes[i] for i,pair in enumerate(pairs) if abs(pair[0] - pair[1]) <= tolerance]
return all_close_pairs_indexes
现在,对于 N=1000,您只需比较 499500 对而不是 100 万对。
工作原理:
我们首先通过 itertools.combinations 创建对。
然后,我们创建它们的索引列表。
出于速度原因,我们使用列表理解而不是 for 循环。
在这个理解中,我们迭代所有对,使用 enumerate
所以我们可以获得对的索引,我们计算对中数字的绝对差,如果检查如果它小于或等于 tolerance
.
如果绝对差小于或等于tolerance
,我们通过索引列表获取对数的索引,并将它们添加到我们的最终列表中。
有点晚了,但是一个完全麻木的解决方案:
import numpy as np
def close_enough( arr, tol = 1 ):
result = np.where( np.triu(np.isclose( arr[ :, None ], arr[ None, : ], rtol = 0.0, atol = tol ), 1))
return np.swapaxes( result, 0, 1 )
展开以解释正在发生的事情
def close_enough( arr, tol = 1 ):
bool_arr = np.isclose( arr[ :, None ], arr[ None, : ], rtol = 0.0, atol = tol )
# is_close generates a square array after comparing all elements with all elements.
bool_arr = np.triu( bool_arr, 1 )
# Keep the upper right triangle, offset by 1 column. i.e. zero the main diagonal
# and all elements below and to the left.
result = np.where( bool_arr ) # Return the row and column indices for Trues
return np.swapaxes( result, 0, 1 ) # Return the pairs in rows rather than columns
N = 1000,arr = 浮点数组
%timeit close_enough( arr, tol = 1 )
14.1 ms ± 28.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [19]: %timeit all_close_pairs( arr, tol = 1 )
54.3 ms ± 268 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
(close_enough( arr, tol = 1) == all_close_pairs( arr, tol = 1 )).all()
# True
假设我有一个 N = 10
随机浮点数的 Numpy 数组:
import numpy as np
np.random.seed(99)
N = 10
arr = np.random.uniform(0., 10., size=(N,))
print(arr)
out[1]: [6.72278559 4.88078399 8.25495174 0.31446388 8.08049963
5.6561742 2.97622499 0.46695721 9.90627399 0.06825733]
我想找到所有 unique 对数字,彼此之间的差异不超过公差 tol = 1.
(即绝对差异 <= 1)。具体来说,我想获得所有唯一的索引对。每个接近对的索引都应该排序,所有接近对都应该按第一个索引排序。我设法编写了以下工作代码:
def all_close_pairs(arr, tol=1.):
res = set()
for i, x1 in enumerate(arr):
for j, x2 in enumerate(arr):
if i == j:
continue
if np.isclose(x1, x2, rtol=0., atol=tol):
res.add(tuple(sorted([i, j])))
res = np.array(list(res))
return res[res[:,0].argsort()]
print(all_close_pairs(arr, tol=1.))
out[2]: [[1 5]
[2 4]
[3 7]
[3 9]
[7 9]]
然而,实际上我有一个 N = 1000
数字数组,由于嵌套的 for 循环,我的代码变得非常慢。我相信使用 Numpy 向量化有更有效的方法来做到这一点。有谁知道在 Numpy 中最快的方法吗?
问题是您的代码具有 O(n*n)(二次)复杂性。 为了降低复杂性,您可以先尝试对项目进行排序:
def all_close_pairs(arr, tol=1.):
res = set()
arr = sorted(enumerate(arr), key=lambda x: x[1])
for (idx1, (i, x1)) in enumerate(arr):
for idx2 in range(idx1-1, -1, -1):
j, x2 = arr[idx2]
if not np.isclose(x1, x2, rtol=0., atol=tol):
break
indices = sorted([i, j])
res.add(tuple(indices))
return np.array(sorted(res))
但是,这仅在您的值范围远大于公差时才有效。
您可以使用 2 pointers
策略进一步改进这一点,但总体复杂性将保持不变。
这是一个纯numpy操作的解决方案。在我的机器上看起来相当快,但我不知道我们在寻找什么样的速度。
def all_close_pairs(arr, tol=1.):
N = arr.shape[0]
# get indices in the array to consider using meshgrid
pair_coords = np.array(np.meshgrid(np.arange(N), np.arange(N))).T
# filter out pairs so we get indices in increasing order
pair_coords = pair_coords[pair_coords[:, :, 0] < pair_coords[:, :, 1]]
# compare indices in your array for closeness
is_close = np.isclose(arr[pair_coords[:, 0]], arr[pair_coords[:, 1]], rtol=0, atol=tol)
return pair_coords[is_close, :]
一个有效的解决方案是首先使用index = np.argsort()
对输入值进行排序。然后,您可以使用 arr[index]
生成排序数组,然后在 准线性时间 中迭代接近值,如果对的数量很少 [=40] =]连续数组。如果对的数量很大,那么复杂度是 quadratic 由于生成的对的二次数。由此产生的复杂度为:O(n log n + m)
其中 n
是输入数组的大小,m
是生成的对数。
要找到彼此接近的值,一种有效的方法是使用 Numba 迭代值。事实上,虽然在 Numpy 中这可能是可能的,但由于要比较的值的数量可变,它可能效率不高。这是一个实现:
import numba as nb
@nb.njit('int32[:,::1](float64[::1], float64)')
def findCloseValues(arr, tol):
res = []
for i in range(arr.size):
val = arr[i]
# Iterate over the close numbers (only once)
for j in range(i+1, arr.size):
# Sadly neither np.isclose or np.abs are implemented in Numba so far
if max(val, arr[j]) - min(val, arr[j]) >= tol:
break
res.append((i, j))
if len(res) == 0: # No pairs: we need to help Numpy to know the shape
return np.empty((0, 2), dtype=np.int32)
return np.array(res, dtype=np.int32)
最后,需要更新索引以引用未排序数组中的索引,而不是已排序数组中的索引。您可以使用 index[result]
.
这是结果代码:
index = arr.argsort()
result = findCloseValues(arr[index], 1.0)
print(index[result])
这是结果(顺序与问题中的顺序不同,但如果需要可以排序):
array([[9, 3],
[9, 7],
[3, 7],
[1, 5],
[4, 2]])
提高算法的复杂度
如果您需要更快的算法,那么您可以使用另一种输出格式:您可以为每个输入值提供接近目标输入值的 min/max 值范围。要查找范围,您可以对排序后的数组使用二进制搜索(请参阅:np.searchsorted
)。生成的算法在 O(n log n)
中运行。但是,您无法获取未排序数组中的索引,因为范围是不连续的。
基准
以下是在我的机器上随机输入 1000 项且容差为 1.0 的性能结果:
Reference implementation: ~17000 ms (x 1)
Angelicos' implementation: 1773 ms (x ~10)
Rivers' implementation: 122 ms (x 139)
Rchome's implementation: 20 ms (x 850)
Chris' implementation: 4.57 ms (x 3720)
This implementation: 0.67 ms (x 25373)
您可以先使用 itertools.combinations 创建组合:
def all_close_pairs(arr, tolerance):
pairs = list(combinations(arr, 2))
indexes = list(combinations(range(len(arr)), 2))
all_close_pairs_indexes = [indexes[i] for i,pair in enumerate(pairs) if abs(pair[0] - pair[1]) <= tolerance]
return all_close_pairs_indexes
现在,对于 N=1000,您只需比较 499500 对而不是 100 万对。
工作原理:
我们首先通过 itertools.combinations 创建对。
然后,我们创建它们的索引列表。
出于速度原因,我们使用列表理解而不是 for 循环。
在这个理解中,我们迭代所有对,使用
enumerate
所以我们可以获得对的索引,我们计算对中数字的绝对差,如果检查如果它小于或等于tolerance
.如果绝对差小于或等于
tolerance
,我们通过索引列表获取对数的索引,并将它们添加到我们的最终列表中。
有点晚了,但是一个完全麻木的解决方案:
import numpy as np
def close_enough( arr, tol = 1 ):
result = np.where( np.triu(np.isclose( arr[ :, None ], arr[ None, : ], rtol = 0.0, atol = tol ), 1))
return np.swapaxes( result, 0, 1 )
展开以解释正在发生的事情
def close_enough( arr, tol = 1 ):
bool_arr = np.isclose( arr[ :, None ], arr[ None, : ], rtol = 0.0, atol = tol )
# is_close generates a square array after comparing all elements with all elements.
bool_arr = np.triu( bool_arr, 1 )
# Keep the upper right triangle, offset by 1 column. i.e. zero the main diagonal
# and all elements below and to the left.
result = np.where( bool_arr ) # Return the row and column indices for Trues
return np.swapaxes( result, 0, 1 ) # Return the pairs in rows rather than columns
N = 1000,arr = 浮点数组
%timeit close_enough( arr, tol = 1 )
14.1 ms ± 28.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [19]: %timeit all_close_pairs( arr, tol = 1 )
54.3 ms ± 268 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
(close_enough( arr, tol = 1) == all_close_pairs( arr, tol = 1 )).all()
# True