NumPy 中有多重排列吗?

Is there multi arange in NumPy?

Numpy 的排列只接受 start/stop/step 的单个标量值。这个功能有多个版本吗?哪个可以接受 start/stop/step 的数组输入?例如。输入二维数组如:

[[1 5 1], # start/stop/step first
 [3 8 2]] # start/stop/step second

应该为输入的每一行(每个 start/stop/step)创建由排列串联组成的数组,上面的输入应该创建一维数组

1 2 3 4 3 5 7

即我们需要设计它接下来要做的功能:

print(np.multi_arange(np.array([[1,5,1],[3,8,2]])))
# prints:
# array([1, 2, 3, 4, 3, 5, 7])

并且此函数应该是高效的(纯 numpy),即非常快速地处理形状为 (10000, 3) 的输入数组,而无需纯 Python 循环。

当然可以创建纯 Python 的循环(或 listcomp)来为每一行创建排列并连接此循环的结果。但是我有很多包含三元组的行 start/stop/step 并且需要高效快速的代码,因此寻找纯 numpy 函数。

为什么我需要它。我需要这个来完成几个任务。其中之一用于索引 - 假设我有一个一维数组 a 并且我需要提取该数组的许多(可能相交的)子范围。如果我有那个多版本的 arange 我会这样做:

values = a[np.multi_arange(starts_stops_steps)]

也许可以使用 numpy 函数的某些组合来创建多排列函数?你能推荐一下吗?

对于提取一维数组子范围(参见上面最后一行代码)而不使用 multi_arange?

创建所有索引的特定情况,也许还有一些更有效的解决方案
In [1]: np.r_[1:5:1, 3:8:2]
Out[1]: array([1, 2, 3, 4, 3, 5, 7])

In [2]: np.hstack((np.arange(1,5,1),np.arange(3,8,2)))
Out[2]: array([1, 2, 3, 4, 3, 5, 7])

r_ 版本很好而且紧凑,但速度不快:

In [3]: timeit np.r_[1:5:1, 3:8:2]
23.9 µs ± 34.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
In [4]: timeit np.hstack((np.arange(1,5,1),np.arange(3,8,2)))
11.2 µs ± 19.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

这是一个向量化的cumsum,它考虑了正步长和负步长 -

def multi_arange(a):
    steps = a[:,2]
    lens = ((a[:,1]-a[:,0]) + steps-np.sign(steps))//steps
    b = np.repeat(steps, lens)
    ends = (lens-1)*steps + a[:,0]
    b[0] = a[0,0]
    b[lens[:-1].cumsum()] = a[1:,0] - ends[:-1]
    return b.cumsum()

如果您需要验证有效范围:(start < stop when step > 0)(start > stop when step < 0),请使用 pre-processing 步骤:

a = a[((a[:,1] > a[:,0]) & (a[:,2]>0) | (a[:,1] < a[:,0]) & (a[:,2]<0))]

样本运行-

In [17]: a
Out[17]: 
array([[ 1,  5,  1],
       [ 3,  8,  2],
       [18,  6, -2]])

In [18]: multi_arange(a)
Out[18]: array([ 1,  2,  3,  4,  3,  5,  7, 18, 16, 14, 12, 10,  8])

我刚刚使用 numba 提出了我的解决方案。我仍然更喜欢 numpy-only 解决方案,如果我们找到最好的解决方案不携带沉重的 numba JIT 编译器。

我还在我的代码中测试了@Divakar 解决方案。

下一个代码输出是:

naive_multi_arange 0.76601 sec
arty_multi_arange 0.01801 sec 42.52 speedup
divakar_multi_arange 0.05504 sec 13.92 speedup

意思是我的 numba 解决方案有 42 倍的加速,@Divakar 的 numpy 解决方案有 14 倍的加速。

下一个代码也可以是run online here.

import time, random
import numpy as np, numba

@numba.jit(nopython = True)
def arty_multi_arange(a):
    starts, stops, steps = a[:, 0], a[:, 1], a[:, 2]
    pos = 0
    cnt = np.sum((stops - starts + steps - np.sign(steps)) // steps, dtype = np.int64)
    res = np.zeros((cnt,), dtype = np.int64)
    for i in range(starts.size):
        v, stop, step = starts[i], stops[i], steps[i]
        if step > 0:
            while v < stop:
                res[pos] = v
                pos += 1
                v += step
        elif step < 0:
            while v > stop:
                res[pos] = v
                pos += 1
                v += step
    assert pos == cnt
    return res
    
def divakar_multi_arange(a):
    steps = a[:,2]
    lens = ((a[:,1]-a[:,0]) + steps-np.sign(steps))//steps
    b = np.repeat(steps, lens)
    ends = (lens-1)*steps + a[:,0]
    b[0] = a[0,0]
    b[lens[:-1].cumsum()] = a[1:,0] - ends[:-1]
    return b.cumsum()
    
random.seed(0)
neg_prob = 0.5
N = 100000
minv, maxv, maxstep = -100, 300, 15
steps = [random.randrange(1, maxstep + 1) * ((1, -1)[random.random() < neg_prob]) for i in range(N)]
starts = [random.randrange(minv + 1, maxv) for i in range(N)]
stops = [random.randrange(*(((starts[i] + 1, maxv + 1), (minv, starts[i]))[steps[i] < 0])) for i in range(N)]
joined = np.array([starts, stops, steps], dtype = np.int64).T

tb = time.time()
aref = np.concatenate([np.arange(joined[i, 0], joined[i, 1], joined[i, 2], dtype = np.int64) for i in range(N)])
npt = time.time() - tb
print('naive_multi_arange', round(npt, 5), 'sec')

for func in ['arty_multi_arange', 'divakar_multi_arange']:
    globals()[func](joined)
    tb = time.time()
    a = globals()[func](joined)
    myt = time.time() - tb
    print(func, round(myt, 5), 'sec', round(npt / myt, 2), 'speedup')
    assert a.size == aref.size, (a.size, aref.size)
    assert np.all(a == aref), np.vstack((np.flatnonzero(a != aref)[:5], a[a != aref][:5], aref[a != aref][:5])).T