从 numpy 数组中解复用值

Demuxing values from numpy array

设备发送一组多路复用错误代码。您可以将多路复用错误代码视为某种 FIFO 环形缓冲区,其长度可以根据同时激活的错误代码的数量而定。我希望将错误代码多路分解为单独的布尔数组。

我正在寻找一种有效的方法(即删除 for 循环)来实现以下代码:

import numpy as np

def get_error_vector(error_mux_vector, error_id, period):
    index = np.where(error_mux_vector == error_id)[0]

    error_vector = np.zeros(np.size(error_mux_vector))

    for i in range(0, np.size(index) - 1):
        if (index[i + 1] - index[i]) <= 1 / period:
            error_vector[index[i]:index[i + 1] + 1] = 1

    return error_vector

这里有模拟值来说明问题。 0 表示没有错误; 1、2、3 是错误代码。我们假设误差信号的频率为 5 Hz(周期为 0.2s):

import matplotlib.pyplot as plt

error_signal = np.array([0,0,0,0,0,1,2,3,1,2,3,2,3,2,3,0,0,0,0,0,1,1,1,2,3,1,3,1,3,1,3,0,0,0,0,0,2,0,2,2])

error_vector_1 = get_error_vector(error_signal, 1, 0.2)
error_vector_2 = get_error_vector(error_signal, 2, 0.2)
error_vector_3 = get_error_vector(error_signal, 3, 0.2)

plt.plot(error_signal)
plt.plot(error_vector_1)
plt.plot(error_vector_2)
plt.plot(error_vector_3)
plt.legend(['array', 'error 1', 'error 2', 'error 3'])
plt.show()

实际设备数据可能有50k到10M点,大约有100个可能的错误代码。这意味着 for 循环对于用例来说确实效率低下。我想改进这段代码,但到目前为止我还没有找到有效的解决方案。

这是一种一次性创建所有向量的向量化方法。它有两种口味。在我的随机测试用例中,第二个更快,但这可能取决于您信号的确切统计数据。

import numpy as np

# dense strat
def demultiplex(signal,maxdist):
    n = signal.max()
    aux = np.zeros((n,len(signal)+1),np.int16)
    nz = signal.nonzero()[0]
    signal = signal[nz]
    idx = signal.argsort(kind="stable")
    valid = ((nz[idx[1:]]<=nz[idx[:-1]]+maxdist)&
             (signal[idx[1:]]==signal[idx[:-1]])).nonzero()[0]
    aux[signal[idx[valid]]-1,nz[idx[valid]]] = 1
    aux[signal[idx[valid+1]]-1,nz[idx[valid+1]]+1] -= 1
    out = (aux[:,:-1].cumsum(1) > 0).view(np.int8)
    return out

# sparse strat
def demultiplex2(signal,maxdist):
    n = signal.max()
    m = signal.size
    nz = signal.nonzero()[0]
    signal = signal[nz]
    idx = signal.argsort(kind="stable")
    delta = nz[idx[1:]] - nz[idx[:-1]]
    valid = ((delta<=maxdist)&(signal[idx[1:]]==signal[idx[:-1]])).nonzero()[0]
    delta = delta[valid]
    nz = nz[idx[valid]]
    nz[1:] -= nz[:-1] + delta[:-1]
    offsets = (delta+1).cumsum()
    x = np.ones(offsets[-1],int)
    x[0] = nz[0]
    x[offsets[:-1]] = nz[1:]
    out = np.zeros((n,m),np.uint8)
    out[(signal[idx[valid]]-1).repeat(delta+1),x.cumsum()] = 1
    return out

# OP
def get_error_vector(error_mux_vector, error_id, period):
    index = np.where(error_mux_vector == error_id)[0]
    error_vector = np.zeros(np.size(error_mux_vector),np.int8)
    for i in range(0, np.size(index) - 1):
        if (index[i + 1] - index[i]) <= 1 / period:
            error_vector[index[i]:index[i + 1] + 1] = 1
    return error_vector


#error_signal = np.array([0,0,0,0,0,1,2,3,1,2,3,2,3,2,3,0,0,0,0,0,1,1,1,2,3,1,3,1,3,1,3,0,0,0,0,0,2,0,2,2])
error_signal = np.random.randint(0,101,1000000)

import time

t=[]
t.append(time.time())
error_vector_1 = get_error_vector(error_signal, 1, 0.02)
error_vector_2 = get_error_vector(error_signal, 2, 0.02)
error_vector_3 = get_error_vector(error_signal, 3, 0.02)
t.append(time.time())
sol = demultiplex(error_signal,50)
t.append(time.time())
sol2 = demultiplex2(error_signal,50)
t.append(time.time())
print("time per error id [OP, pp, pp2]",np.diff(t)/(3,100,100))
print("results equal",end=" ")
print((error_vector_1==sol[0]).all(),end=" ")
print((error_vector_2==sol[1]).all(),end=" ")
print((error_vector_3==sol[2]).all(),end=" ")
print((error_vector_1==sol2[0]).all(),end=" ")
print((error_vector_2==sol2[1]).all(),end=" ")
print((error_vector_3==sol2[2]).all())

样本运行:

time per error id [OP, pp, pp2] [0.02730425 0.00912964 0.00440736]
results equal True True True True True True

一点解释:

  • 我们使用 argsort 信号来轻松识别那些在足够短的时间内跟在自己后面的错误代码。
  • 我们将误差向量排列在一个堆栈中,这样应该设置的点就可以通过坐标信号[t],t
  • 为了设置时间点的延伸,我们将第一个设置为 1,将最后一个设置为 -1 并形成 cumsum - 为了补救重叠的延伸,我们检查 >0 并将结果布尔值转换为 int