numpy.median.reduceat 的快速替代方案

Fast alternative for numpy.median.reduceat

关于 ,有没有一种快速的方法来计算具有 不等 的数组的中位数元素数量?

例如:

data =  [1.00, 1.05, 1.30, 1.20, 1.06, 1.54, 1.33, 1.87, 1.67, ... ]
index = [0,    0,    1,    1,    1,    1,    2,    3,    3,    ... ]

然后我想计算每个组的数量和中位数之间的差异(例如,组 0 的中位数是 1.025,所以第一个结果是 1.00 - 1.025 = -0.025)。所以对于上面的数组,结果将显示为:

result = [-0.025, 0.025, 0.05, -0.05, -0.19, 0.29, 0.00, 0.10, -0.10, ...]

由于 np.median.reduceat 还不存在,是否有另一种快速的方法来实现这一点?我的数组将包含数百万行,因此速度至关重要!

可以假定索引是连续且有序的(如果不是,则很容易转换它们)。


性能比较示例数据:

import numpy as np

np.random.seed(0)
rows = 10000
cols = 500
ngroup = 100

# Create random data and groups (unique per column)
data = np.random.rand(rows,cols)
groups = np.random.randint(ngroup, size=(rows,cols)) + 10*np.tile(np.arange(cols),(rows,1))

# Flatten
data = data.ravel()
groups = groups.ravel()

# Sort by group
idx_sort = groups.argsort()
data = data[idx_sort]
groups = groups[idx_sort]

也许你已经这样做了,但如果没有,看看它是否足够快:

median_dict = {i: np.median(data[index == i]) for i in np.unique(index)}
def myFunc(my_dict, a): 
    return my_dict[a]
vect_func = np.vectorize(myFunc)
median_diff = data - vect_func(median_dict, index)
median_diff

输出:

array([-0.025,  0.025,  0.05 , -0.05 , -0.19 ,  0.29 ,  0.   ,  0.1  ,
   -0.1  ])

一种方法是在这里使用 Pandas 纯粹是为了利用 groupby。我稍微夸大了输入大小,以便更好地理解时间(因为创建 DF 会产生开销)。

import numpy as np
import pandas as pd

data =  [1.00, 1.05, 1.30, 1.20, 1.06, 1.54, 1.33, 1.87, 1.67]
index = [0,    0,    1,    1,    1,    1,    2,    3,    3]

data = data * 500
index = np.sort(np.random.randint(0, 30, 4500))

def df_approach(data, index):
    df = pd.DataFrame({'data': data, 'label': index})
    df['median'] = df.groupby('label')['data'].transform('median')
    df['result'] = df['data'] - df['median']

给出以下内容timeit

%timeit df_approach(data, index)
5.38 ms ± 50.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

对于相同的样本量,我得到 为:

%timeit dict_approach(data, index)
8.12 ms ± 3.47 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

但是,如果我们将输入再增加 10 倍,时间将变为:

%timeit df_approach(data, index)
7.72 ms ± 85 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%timeit dict_approach(data, index)
30.2 ms ± 10.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

然而,以牺牲一些可靠性为代价, 使用纯 numpy 的答案出现在:

%timeit bin_median_subtract(data, index)
573 µs ± 7.48 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

鉴于新数据集(确实应该在开始时设置):

%timeit df_approach(data, groups)
472 ms ± 2.52 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit bin_median_subtract(data, groups) #
3.02 s ± 31.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit dict_approach(data, groups) #
<I gave up after 1 minute>

# jitted (using @numba.njit('f8[:](f8[:], i4[:]') on Windows) from  
%timeit diffmedian_jit(data, groups)
132 ms ± 3.12 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

这是一种基于 NumPy 的方法,用于获取正 bins/index 值的 binned-median -

def bin_median(a, i):
    sidx = np.lexsort((a,i))

    a = a[sidx]
    i = i[sidx]

    c = np.bincount(i)
    c = c[c!=0]

    s1 = c//2

    e = c.cumsum()
    s1[1:] += e[:-1]

    firstval = a[s1-1]
    secondval = a[s1]
    out = np.where(c%2,secondval,(firstval+secondval)/2.0)
    return out

解决我们减法的具体情况-

def bin_median_subtract(a, i):
    sidx = np.lexsort((a,i))

    c = np.bincount(i)

    valid_mask = c!=0
    c = c[valid_mask]    

    e = c.cumsum()
    s1 = c//2
    s1[1:] += e[:-1]
    ssidx = sidx.argsort()
    starts = c%2+s1-1
    ends = s1

    starts_orgindx = sidx[np.searchsorted(sidx,starts,sorter=ssidx)]
    ends_orgindx  = sidx[np.searchsorted(sidx,ends,sorter=ssidx)]
    val = (a[starts_orgindx] + a[ends_orgindx])/2.
    out = a-np.repeat(val,c)
    return out

如果您真的想要加快您的计算速度,有时您需要编写非惯用的 numpy 代码,而您无法使用原生 numpy。

numba 将您的 python 代码编译为低级 C。由于很多 numpy 本身通常与 C 一样快,如果您的问题不适用,这最终很有用本身到 numpy 的本机矢量化。这是一个例子(我假设索引是连续的和排序的,这也反映在示例数据中):

import numpy as np
import numba

# use the inflated example of roganjosh 
data =  [1.00, 1.05, 1.30, 1.20, 1.06, 1.54, 1.33, 1.87, 1.67]
index = [0,    0,    1,    1,    1,    1,    2,    3,    3] 

data = np.array(data * 500) # using arrays is important for numba!
index = np.sort(np.random.randint(0, 30, 4500))               

# jit-decorate; original is available as .py_func attribute
@numba.njit('f8[:](f8[:], i8[:])') # explicit signature implies ahead-of-time compile
def diffmedian_jit(data, index): 
    res = np.empty_like(data) 
    i_start = 0 
    for i in range(1, index.size): 
        if index[i] == index[i_start]: 
            continue 

        # here: i is the first _next_ index 
        inds = slice(i_start, i)  # i_start:i slice 
        res[inds] = data[inds] - np.median(data[inds]) 

        i_start = i 

    # also fix last label 
    res[i_start:] = data[i_start:] - np.median(data[i_start:])

    return res

下面是一些使用 IPython 的 %timeit 魔法的时间:

>>> %timeit diffmedian_jit.py_func(data, index)  # non-jitted function
... %timeit diffmedian_jit(data, index)  # jitted function
...
4.27 ms ± 109 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
65.2 µs ± 1.01 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

使用问题中更新的示例数据,这些数字(即 python 函数的 运行 时间与 JIT 加速函数的 运行 时间)是

>>> %timeit diffmedian_jit.py_func(data, groups) 
... %timeit diffmedian_jit(data, groups)
2.45 s ± 34.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
93.6 ms ± 518 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

这相当于使用加速代码在较小的情况下加速 65 倍,在较大的情况下加速 26 倍(当然,与慢速循环代码相比)。另一个好处是(与使用原生 numpy 的典型矢量化不同)我们不需要额外的内存来达到这个速度,它都是关于优化和编译的低级代码,最终成为 运行.


上面的函数假设numpy int数组默认是int64,实际上在Windows上并不是这样。因此,另一种方法是从对 numba.njit 的调用中删除签名,从而触发适当的即时编译。但这意味着该函数将在第一次执行时被编译,这会干扰计时结果(我们可以手动执行一次函数,使用代表性数据类型,或者只是接受第一次计时执行会慢得多,这应该被忽略)。这正是我试图通过指定签名来阻止的,签名会触发提前编译。

无论如何,在正确的 JIT 情况下,我们需要的装饰器只是

@numba.njit
def diffmedian_jit(...):

请注意,我为 jit 编译函数显示的上述时间仅适用于编译函数后。这要么发生在定义时(使用急切编译,当显式签名传递给 numba.njit 时),要么发生在第一次函数调用期间(使用惰性编译,当没有签名传递给 numba.njit 时)。如果函数只执行一次,那么编译时间也应该考虑到这个方法的速度。通常只有当编译 + 执行的总时间小于未编译的 运行 时间时才值得编译函数(在上述情况下确实如此,其中原生 python 函数非常慢)。这主要发生在您多次调用已编译函数时。

作为 max9111 noted in a comment, one important feature of numba is the cache keywordjit。将 cache=True 传递给 numba.jit 会将编译后的函数存储到磁盘,这样在给定的 python 模块的下一次执行期间,函数将从那里加载而不是重新编译,这又可以节省你运行时间长运行.