从 numpy 数组中解复用值
Demuxing values from numpy array
设备发送一组多路复用错误代码。您可以将多路复用错误代码视为某种 FIFO 环形缓冲区,其长度可以根据同时激活的错误代码的数量而定。我希望将错误代码多路分解为单独的布尔数组。
我正在寻找一种有效的方法(即删除 for 循环)来实现以下代码:
import numpy as np
def get_error_vector(error_mux_vector, error_id, period):
index = np.where(error_mux_vector == error_id)[0]
error_vector = np.zeros(np.size(error_mux_vector))
for i in range(0, np.size(index) - 1):
if (index[i + 1] - index[i]) <= 1 / period:
error_vector[index[i]:index[i + 1] + 1] = 1
return error_vector
这里有模拟值来说明问题。 0 表示没有错误; 1、2、3 是错误代码。我们假设误差信号的频率为 5 Hz(周期为 0.2s):
import matplotlib.pyplot as plt
error_signal = np.array([0,0,0,0,0,1,2,3,1,2,3,2,3,2,3,0,0,0,0,0,1,1,1,2,3,1,3,1,3,1,3,0,0,0,0,0,2,0,2,2])
error_vector_1 = get_error_vector(error_signal, 1, 0.2)
error_vector_2 = get_error_vector(error_signal, 2, 0.2)
error_vector_3 = get_error_vector(error_signal, 3, 0.2)
plt.plot(error_signal)
plt.plot(error_vector_1)
plt.plot(error_vector_2)
plt.plot(error_vector_3)
plt.legend(['array', 'error 1', 'error 2', 'error 3'])
plt.show()
实际设备数据可能有50k到10M点,大约有100个可能的错误代码。这意味着 for 循环对于用例来说确实效率低下。我想改进这段代码,但到目前为止我还没有找到有效的解决方案。
这是一种一次性创建所有向量的向量化方法。它有两种口味。在我的随机测试用例中,第二个更快,但这可能取决于您信号的确切统计数据。
import numpy as np
# dense strat
def demultiplex(signal,maxdist):
n = signal.max()
aux = np.zeros((n,len(signal)+1),np.int16)
nz = signal.nonzero()[0]
signal = signal[nz]
idx = signal.argsort(kind="stable")
valid = ((nz[idx[1:]]<=nz[idx[:-1]]+maxdist)&
(signal[idx[1:]]==signal[idx[:-1]])).nonzero()[0]
aux[signal[idx[valid]]-1,nz[idx[valid]]] = 1
aux[signal[idx[valid+1]]-1,nz[idx[valid+1]]+1] -= 1
out = (aux[:,:-1].cumsum(1) > 0).view(np.int8)
return out
# sparse strat
def demultiplex2(signal,maxdist):
n = signal.max()
m = signal.size
nz = signal.nonzero()[0]
signal = signal[nz]
idx = signal.argsort(kind="stable")
delta = nz[idx[1:]] - nz[idx[:-1]]
valid = ((delta<=maxdist)&(signal[idx[1:]]==signal[idx[:-1]])).nonzero()[0]
delta = delta[valid]
nz = nz[idx[valid]]
nz[1:] -= nz[:-1] + delta[:-1]
offsets = (delta+1).cumsum()
x = np.ones(offsets[-1],int)
x[0] = nz[0]
x[offsets[:-1]] = nz[1:]
out = np.zeros((n,m),np.uint8)
out[(signal[idx[valid]]-1).repeat(delta+1),x.cumsum()] = 1
return out
# OP
def get_error_vector(error_mux_vector, error_id, period):
index = np.where(error_mux_vector == error_id)[0]
error_vector = np.zeros(np.size(error_mux_vector),np.int8)
for i in range(0, np.size(index) - 1):
if (index[i + 1] - index[i]) <= 1 / period:
error_vector[index[i]:index[i + 1] + 1] = 1
return error_vector
#error_signal = np.array([0,0,0,0,0,1,2,3,1,2,3,2,3,2,3,0,0,0,0,0,1,1,1,2,3,1,3,1,3,1,3,0,0,0,0,0,2,0,2,2])
error_signal = np.random.randint(0,101,1000000)
import time
t=[]
t.append(time.time())
error_vector_1 = get_error_vector(error_signal, 1, 0.02)
error_vector_2 = get_error_vector(error_signal, 2, 0.02)
error_vector_3 = get_error_vector(error_signal, 3, 0.02)
t.append(time.time())
sol = demultiplex(error_signal,50)
t.append(time.time())
sol2 = demultiplex2(error_signal,50)
t.append(time.time())
print("time per error id [OP, pp, pp2]",np.diff(t)/(3,100,100))
print("results equal",end=" ")
print((error_vector_1==sol[0]).all(),end=" ")
print((error_vector_2==sol[1]).all(),end=" ")
print((error_vector_3==sol[2]).all(),end=" ")
print((error_vector_1==sol2[0]).all(),end=" ")
print((error_vector_2==sol2[1]).all(),end=" ")
print((error_vector_3==sol2[2]).all())
样本运行:
time per error id [OP, pp, pp2] [0.02730425 0.00912964 0.00440736]
results equal True True True True True True
一点解释:
- 我们使用 argsort 信号来轻松识别那些在足够短的时间内跟在自己后面的错误代码。
- 我们将误差向量排列在一个堆栈中,这样应该设置的点就可以通过坐标信号[t],t
- 为了设置时间点的延伸,我们将第一个设置为 1,将最后一个设置为 -1 并形成 cumsum - 为了补救重叠的延伸,我们检查 >0 并将结果布尔值转换为 int
设备发送一组多路复用错误代码。您可以将多路复用错误代码视为某种 FIFO 环形缓冲区,其长度可以根据同时激活的错误代码的数量而定。我希望将错误代码多路分解为单独的布尔数组。
我正在寻找一种有效的方法(即删除 for 循环)来实现以下代码:
import numpy as np
def get_error_vector(error_mux_vector, error_id, period):
index = np.where(error_mux_vector == error_id)[0]
error_vector = np.zeros(np.size(error_mux_vector))
for i in range(0, np.size(index) - 1):
if (index[i + 1] - index[i]) <= 1 / period:
error_vector[index[i]:index[i + 1] + 1] = 1
return error_vector
这里有模拟值来说明问题。 0 表示没有错误; 1、2、3 是错误代码。我们假设误差信号的频率为 5 Hz(周期为 0.2s):
import matplotlib.pyplot as plt
error_signal = np.array([0,0,0,0,0,1,2,3,1,2,3,2,3,2,3,0,0,0,0,0,1,1,1,2,3,1,3,1,3,1,3,0,0,0,0,0,2,0,2,2])
error_vector_1 = get_error_vector(error_signal, 1, 0.2)
error_vector_2 = get_error_vector(error_signal, 2, 0.2)
error_vector_3 = get_error_vector(error_signal, 3, 0.2)
plt.plot(error_signal)
plt.plot(error_vector_1)
plt.plot(error_vector_2)
plt.plot(error_vector_3)
plt.legend(['array', 'error 1', 'error 2', 'error 3'])
plt.show()
实际设备数据可能有50k到10M点,大约有100个可能的错误代码。这意味着 for 循环对于用例来说确实效率低下。我想改进这段代码,但到目前为止我还没有找到有效的解决方案。
这是一种一次性创建所有向量的向量化方法。它有两种口味。在我的随机测试用例中,第二个更快,但这可能取决于您信号的确切统计数据。
import numpy as np
# dense strat
def demultiplex(signal,maxdist):
n = signal.max()
aux = np.zeros((n,len(signal)+1),np.int16)
nz = signal.nonzero()[0]
signal = signal[nz]
idx = signal.argsort(kind="stable")
valid = ((nz[idx[1:]]<=nz[idx[:-1]]+maxdist)&
(signal[idx[1:]]==signal[idx[:-1]])).nonzero()[0]
aux[signal[idx[valid]]-1,nz[idx[valid]]] = 1
aux[signal[idx[valid+1]]-1,nz[idx[valid+1]]+1] -= 1
out = (aux[:,:-1].cumsum(1) > 0).view(np.int8)
return out
# sparse strat
def demultiplex2(signal,maxdist):
n = signal.max()
m = signal.size
nz = signal.nonzero()[0]
signal = signal[nz]
idx = signal.argsort(kind="stable")
delta = nz[idx[1:]] - nz[idx[:-1]]
valid = ((delta<=maxdist)&(signal[idx[1:]]==signal[idx[:-1]])).nonzero()[0]
delta = delta[valid]
nz = nz[idx[valid]]
nz[1:] -= nz[:-1] + delta[:-1]
offsets = (delta+1).cumsum()
x = np.ones(offsets[-1],int)
x[0] = nz[0]
x[offsets[:-1]] = nz[1:]
out = np.zeros((n,m),np.uint8)
out[(signal[idx[valid]]-1).repeat(delta+1),x.cumsum()] = 1
return out
# OP
def get_error_vector(error_mux_vector, error_id, period):
index = np.where(error_mux_vector == error_id)[0]
error_vector = np.zeros(np.size(error_mux_vector),np.int8)
for i in range(0, np.size(index) - 1):
if (index[i + 1] - index[i]) <= 1 / period:
error_vector[index[i]:index[i + 1] + 1] = 1
return error_vector
#error_signal = np.array([0,0,0,0,0,1,2,3,1,2,3,2,3,2,3,0,0,0,0,0,1,1,1,2,3,1,3,1,3,1,3,0,0,0,0,0,2,0,2,2])
error_signal = np.random.randint(0,101,1000000)
import time
t=[]
t.append(time.time())
error_vector_1 = get_error_vector(error_signal, 1, 0.02)
error_vector_2 = get_error_vector(error_signal, 2, 0.02)
error_vector_3 = get_error_vector(error_signal, 3, 0.02)
t.append(time.time())
sol = demultiplex(error_signal,50)
t.append(time.time())
sol2 = demultiplex2(error_signal,50)
t.append(time.time())
print("time per error id [OP, pp, pp2]",np.diff(t)/(3,100,100))
print("results equal",end=" ")
print((error_vector_1==sol[0]).all(),end=" ")
print((error_vector_2==sol[1]).all(),end=" ")
print((error_vector_3==sol[2]).all(),end=" ")
print((error_vector_1==sol2[0]).all(),end=" ")
print((error_vector_2==sol2[1]).all(),end=" ")
print((error_vector_3==sol2[2]).all())
样本运行:
time per error id [OP, pp, pp2] [0.02730425 0.00912964 0.00440736]
results equal True True True True True True
一点解释:
- 我们使用 argsort 信号来轻松识别那些在足够短的时间内跟在自己后面的错误代码。
- 我们将误差向量排列在一个堆栈中,这样应该设置的点就可以通过坐标信号[t],t
- 为了设置时间点的延伸,我们将第一个设置为 1,将最后一个设置为 -1 并形成 cumsum - 为了补救重叠的延伸,我们检查 >0 并将结果布尔值转换为 int