将此字符串解析为 numpy 数组的最快方法

Fastest way to parse this string to a numpy array

我们必须执行以下操作大约 400,000 次,因此我正在寻找最有效的解决方案。我尝试了几种方法,但我很好奇是否有更好的方法:)


数据示例

我们可以使用下面的代码来生成一个示例测试集
random.seed(10)
np.random.seed(10)
def test_str():
    n = 10000000
    arr  = np.random.randint(10000, size=n)
    sign = np.random.choice(['+','-'], size=n)
    return 'ID1' + '\t' + ' '.join(["{}{}".format(a,b) for a,b in zip(arr, sign)])

看起来像 ID1\t7688+ 737+ 677+ 1508- 9251-......

代码的全部内容:)

google colab 复制代码(P.s。运行 它给了我一个 TypingError 而它 运行 在我的机器上很好),或者看看下面的函数

一般函数
从这个 Numba issue ,但基于 @armamut 的回答,这可能会给 Numba 带来很多开销,使原生 Numpy 显然更快..

@nb.jit(nopython=True)
    def str_to_int(s):
        final_index, result = len(s) - 1, 0
        for i,v in enumerate(s):
            result += (ord(v) - 48) * (10 ** (final_index - i))
        return result

方法一

@nb.jit(nopython=True)
def process_number(numb, identifier, i):
    sign = 1 if numb[-1] == '+' else -1
    return str_to_int(numb[:-1]), sign, i, identifier
    
@nb.jit(nopython=True)
def expand1(data):
    identifier, l = data.split('\t')
    identifier = str_to_int(identifier[-1])
    numbers = l.split()
    # init emtpy numpy array
    arr = np.empty(shape = (len(numbers), 4), dtype = np.int64)
    # Fill array    
    for i, numb in enumerate(numbers):
        arr[i,:] = process_number(numb, identifier, i)
    return arr

方法二

@nb.jit(nopython=True)
def expand2(data):
    identifier, l = data.split('\t')
    
    identifier = str_to_int(identifier[-1])
    numbers = l.split()
    size = len(numbers)
    
    numbs = [ str_to_int(numb[:-1]) for numb in numbers ]
    signs = [ 1 if numb[:-1] =='+' else -1 for numb in numbers ]
    
    arr = np.empty(shape = (size, 4), dtype = np.int64)
    arr[:,0] = numbs
    arr[:,1] = signs
    arr[:,2] = np.arange(0, size)
    arr[:,3] = np.repeat(identifier, size)
    return arr

方法 3

@nb.jit(nopython=True)
def expand3(data):
    identifier, l = data.split('\t')
    identifier = str_to_int(identifier[-1])
    numbers = l.split()
    arr = np.empty(shape = (len(numbers), 4), dtype = np.int64)
    for i, numb in enumerate(numbers):
        arr[i,:] = str_to_int(numb[:-1]), 1 if numb[:-1] =='+' else -1, i, identifier
    return arr

回答方法

def expand4(t):
    identifier, l = t.split('\t')
    identifier = np.int(identifier[-1])
    numbers = np.array([np.int(k[:-1]) for k in l.split(' ')])
    signs = np.array([(k[-1] == '+') for k in l.split(' ')]) * 2 - 1

    N = len(numbers)
    arr = np.empty(shape = (N, 4), dtype = np.int64)
    arr[:, 0] = numbers
    arr[:, 1] = signs
    arr[:, 2] = identifier
    arr[:, 3] = np.arange(N)
    return arr

测试结果:

Expand 1
72.7 ms ± 177 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)
Expand 2
27.9 ms ± 67.1 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)
Expand 3
8.81 ms ± 20.3 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)
Expand 4 ANSWER 1
429 µs ± 63.4 µs per loop (mean ± std. dev. of 7 runs, 5 loops each)

我无法复制您的代码,因为我也遇到了 numba 的“ord”未实现错误。

但是你为什么要使用 numba?您的 str_to_int 操作似乎非常昂贵且未针对矢量操作等进行优化。为什么不(没有 numba):

def expand(t):
    identifier, l = t.split('\t')
    identifier = np.int(identifier[-1])
    numbers = np.array([np.int(k[:-1]) for k in l.split(' ')])
    signs = np.array([(k[-1] == '+') for k in l.split(' ')]) * 2 - 1

    N = len(numbers)
    arr = np.empty(shape = (N, 4), dtype = np.int64)
    arr[:, 0] = numbers
    arr[:, 1] = signs
    arr[:, 2] = identifier
    arr[:, 3] = np.arange(N)
    return arr

t = test_str()
%timeit expand(t)

>>>

1.01 ms ± 121 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)