提取包含另一个数组元素的 end-point 数组行的矢量化方法

Vectorized approach to extract the rows of the end-point array that contains the elements of another array

标题说的是什么。我正在寻找一种快速的 pythonic 方法来提取包含另一个数组 v

元素的 end-point 数组 A 的行

我想实现的一个简单例子如下:

输入:

A = [[ 4  9]
     [15 19]
     [20 28]
     [31 37]
     [43 43]]    
v =  [ 0  1  2  3 11 12 13 14 26 29 30 31 43]

因为A是一个end-pint数组,也就是说每一行的第一个元素和第二个元素代表一个区间的开始和结束。因为只有[20 28][31 37][43 43]的区间包含v中的元素(本例26,31 and 43包含在端点数组[=14创建的区间中=]),所需的输出是:

[[20 28]
 [31 37]
 [43 43]]

以下是生成实际输入数组的代码:

import numpy as np
np.random.seed(0)

size = 32000
base_arr = np.arange(size)*10

t1 = np.random.randint(0,6, size)+base_arr
t2 = np.random.randint(5,10, size)+base_arr

A = np.vstack((t1,t2)).T
v = np.sort(np.random.randint(0,10,3*size)+np.repeat(base_arr,3))

提前致谢


编辑:在解释中添加了更多细节

沿三维比较

import numpy as np
a = np.array([[ 4,  9],
              [15, 19],
              [20, 28],
              [31, 37],
              [43, 43]])    
v =  np.array([ 0,  1,  2,  3, 11, 12, 13, 14, 26, 29, 30, 31, 43])
between = np.logical_and(v >= a[:,0,None], v <= a[:,1,None])
print(a[between.any(-1)])

>>>
[[20 28]
 [31 37]
 [43 43]]
>>> 

方法 #1

我们可以使用 np.searchsorted 根据 v 值获取每行开始和结束元素的左右位置索引,并查找 non-matching 的值,这将指示特定行在这些范围内至少有一个元素。因此,我们可以简单地做 -

A[np.searchsorted(v,A[:,0],'left')!=np.searchsorted(v,A[:,1],'right')]

方法 #2

另一种方法是使用 left-positioned 索引索引到 v,然后查看它们是否小于正确的 end-points。因此,它将是 -

idx = np.searchsorted(v,A[:,0],'left')
out = A[(idx<len(v)) & (v[idx.clip(max=len(v)-1)]<=A[:,1])]

请注意,这假定 v 已排序并作为数组输入。如果 v 尚未排序,我们需要对其进行排序,然后将其输入。

我这边更大数据集的时间 -

In [65]: %timeit A[np.searchsorted(v,A[:,0],'left')!=np.searchsorted(v,A[:,1],'right')]
2 ms ± 10.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [66]: %%timeit
    ...: idx = np.searchsorted(v,A[:,0],'left')
    ...: out = A[(idx<len(v)) & (v[idx.clip(max=len(v)-1)]<=A[:,1])]
1.32 ms ± 7.87 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

我不认为这完全是 Pythonic,但它至少是 O(n)。

def find_bounding_intervals(A, v):
    rows = []
    i = 0
    for row in A:
        while all(v[i] < row):
            i += 1
        if row[0] <= v[i] <= row[1]:
            rows.append(row)
    return np.array(rows)

A = np.array([[ 4,  9],
              [15, 19],
              [20, 28],
              [31, 37],
              [43, 43]])
v =  np.array([ 0,  1,  2,  3, 11, 12, 13, 14, 26, 29, 30, 31, 43])
print(find_bounding_intervals(A, v))

我的 low-end 笔记本电脑在 ~0.28 秒内针对您问题中定义的更大数据制定了解决方案。

from numba import njit
@njit
def find_bounding_intervals(A, v):
    rows_L = []
    rows_R = []

    i = 0
    for row in range(A.shape[0]):
        while v[i] < A[row,0] and v[i] < A[row,1]:
            i += 1
        if A[row,0] <= v[i] <= A[row,1]:
            rows_L.append(A[row,0])
            rows_R.append(A[row,1])
    return np.array([rows_L, rows_R]).T

尽管此实现在技术上不是矢量化函数,但它确实是几乎所有大小 n 中最快的。

我要说清楚算法来自@brentertainer