Numba: how to parse arbitrary logic string into sequence of jitclassed instances in a loop

TL;DR. If I had to explain the problem briefly:

  1. I have signals:
np.random.seed(42)
x = np.random.randn(1000)
y = np.random.randn(1000)
z = np.random.randn(1000)
  2. and human-readable string-tuple logic such as:
entry_sig_ = ((x,y,'crossup',False),)
exit_sig_ = ((x,z,'crossup',False), 'or_',(x,y,'crossdown',False))

where:

  3. the output is generated via:
@njit
def run(x, entry_sig, exit_sig):
    '''
    x: np.array
    entry_sig, exit_sig: homogeneous tuples of tuple signals
    Returns: sequence of 0 and 1 satisfying entry and exit sigs
    ''' 
    L = x.shape[0]
    out = np.empty(L)
    out[0] = 0.0
    out[-1] = 0.0
    i = 1
    trade = True
    while i < L-1:
        out[i] = 0.0
        if reduce_sig(entry_sig,i) and i<L-1:
            out[i] = 1.0
            trade = True
            while trade and i<L-2:
                i += 1
                out[i] = 1.0
                if reduce_sig(exit_sig,i):
                    trade = False
        i+= 1
    return out

reduce_sig(sig, i) is a function (see the definition below) that parses the tuples and returns the resulting output at a given point in time.

The problem:

As of now, objects of the SingleSig class are instantiated from scratch inside a for loop at every point in time; hence there is no "memory", which completely negates the advantage of using a class, and a simple function would do just as well. Is there a workaround (a different class template, a different approach, etc.) so that:

  1. the value of the combined tuple signal can be queried at a specific point in time i
  2. the "memory" can be reset; i.e., for example, MultiSig(sig_tuple).memory_field can be set to 0 at the constituent-signal level

The code below adds memory to the signals; the memory can be erased with MultiSig.reset(), which resets the counts of all signals to 0. MultiSig.query_memory(key) can be used to return the number of hits on that signal up to that point in time.

To make the memory work, I had to add unique keys to the signals to identify them.

from numba import njit, int64, float64, types
from numba.types import Array, string, boolean
from numba.experimental import jitclass  # jitclass moved out of the numba top level in numba 0.49+
import numpy as np

np.random.seed(42)
x = np.random.randn(1000000)
y = np.random.randn(1000000)
z = np.random.randn(1000000)

# Example of "human-readable" signals
entry_sig_ = ((x,y,'crossup',False),)
exit_sig_ = ((x,z,'crossup',False), 'or_',(x,y,'crossdown',False))

# Turn signals into homogeneous tuple
#entry_sig_
entry_sig = (((x,y,'crossup',False),'NOP','1'),)
#exit_sig_
exit_sig = (((x,z,'crossup',False),'or_','2'),((x,y,'crossdown',False),'NOP','3'))

@njit
def cross(x, y, i):
    '''
    x,y: np.array
    i: int - point in time
    Returns: 1 or 0 when condition is met
    '''
    if (x[i - 1] - y[i - 1])*(x[i] - y[i]) < 0:
        out = 1
    else:
        out = 0
    return out
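As a quick sanity check outside of Numba (the toy arrays below are my own, not from the question), the same comparison can be exercised in plain Python: cross fires exactly at the index where x - y changes sign.

```python
import numpy as np

def cross(x, y, i):
    # 1 when x - y changed sign between i-1 and i, else 0
    return 1 if (x[i - 1] - y[i - 1]) * (x[i] - y[i]) < 0 else 0

x = np.array([0.0, 1.0, 2.0, 3.0])
y = np.array([1.5, 1.5, 1.5, 1.5])  # x crosses above y between i=1 and i=2

hits = [cross(x, y, i) for i in range(1, len(x))]
print(hits)  # [0, 1, 0]
```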


kv_ty = (types.string,types.int64)

spec = [
    ('memory', types.DictType(*kv_ty)),
]

@njit
def single_signal(x, y, how, acc, i):
    '''
    x, y: np.array; how: 'crossup' or 'crossdown'
    acc: carried in the signal tuples but unused here
    i: int - point in time
    Returns: 1 when a cross in the requested direction occurs at i, else 0
    '''
    if cross(x, y, i):
        if x[i] < y[i] and how == 'crossdown':
            out = 1
        elif x[i] > y[i] and how == "crossup":
            out = 1
        else:
            out = 0
    else:
        out = 0
    return out
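A plain-Python analogue (toy data mine, not the compiled @njit version) showing that single_signal only fires when the cross is in the requested direction:

```python
import numpy as np

def cross(x, y, i):
    return 1 if (x[i - 1] - y[i - 1]) * (x[i] - y[i]) < 0 else 0

def single_signal(x, y, how, acc, i):
    # fires only when a cross occurred AND it is in the requested direction
    if cross(x, y, i):
        if x[i] < y[i] and how == 'crossdown':
            return 1
        if x[i] > y[i] and how == 'crossup':
            return 1
    return 0

x = np.array([0.0, 2.0, 0.0])
y = np.array([1.0, 1.0, 1.0])
print(single_signal(x, y, 'crossup', False, 1))    # 1: x moved above y
print(single_signal(x, y, 'crossdown', False, 1))  # 0: wrong direction
print(single_signal(x, y, 'crossdown', False, 2))  # 1: x moved back below y
```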
    
@jitclass(spec)
class MultiSig:
    def __init__(self,entry,exit):
        '''
        initialize memory at single signal level
        '''
        memory_dict = {}
        for i in entry:
            memory_dict[str(i[2])] = 0
        
        for i in exit:
            memory_dict[str(i[2])] = 0
        
        self.memory = memory_dict
        
    def reduce_sig(self, sig, i):
        '''
        Parses multisignal
        sig: homogeneous tuple of tuples ("human-readable" signal definition)
        i: int - point in time
        Returns: resulting value of multisignal
        '''
        L = len(sig)
        out = single_signal(*sig[0][0],i)
        logic = sig[0][1]
        if out:
            self.update_memory(sig[0][2])
        for cnt in range(1, L):
            s = single_signal(*sig[cnt][0],i)
            if s:
                self.update_memory(sig[cnt][2])
            out = out | s if logic == 'or_' else out & s
            logic = sig[cnt][1]
        return out
    
    def update_memory(self, key):
        '''
        update memory
        '''
        self.memory[str(key)] += 1
    
    def reset(self):
        '''
        reset memory
        '''
        dicti = {}
        for i in self.memory:
            dicti[i] = 0
        self.memory = dicti
        
    def query_memory(self, key):
        '''
        return number of hits on signal
        '''
        return self.memory[str(key)]
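Stripped of the jitclass machinery, the memory round-trip looks like this (a pure-Python sketch with hypothetical keys, not the compiled class):

```python
class MemoryOnly:
    # stripped-down, pure-Python analogue of MultiSig's memory handling
    def __init__(self, keys):
        self.memory = {str(k): 0 for k in keys}

    def update_memory(self, key):
        self.memory[str(key)] += 1

    def reset(self):
        self.memory = {k: 0 for k in self.memory}

    def query_memory(self, key):
        return self.memory[str(key)]

m = MemoryOnly(['1', '2', '3'])
m.update_memory('2'); m.update_memory('2'); m.update_memory('3')
print(m.query_memory('2'))  # 2
m.reset()
print(m.query_memory('2'))  # 0
```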

@njit
def run(x, entry_sig, exit_sig):
    '''
    x: np.array
    entry_sig, exit_sig: homogeneous tuples of tuples
    Returns: sequence of 0 and 1 satisfying entry and exit sigs
    '''
    L = x.shape[0]
    out = np.empty(L)
    out[0] = 0.0
    out[-1] = 0.0
    i = 1
    multi = MultiSig(entry_sig,exit_sig)
    while i < L-1:
        out[i] = 0.0
        if multi.reduce_sig(entry_sig,i) and i<L-1:
            out[i] = 1.0
            trade = True
            while trade and i<L-2:
                i += 1
                out[i] = 1.0
                if multi.reduce_sig(exit_sig,i):
                    trade = False
        i+= 1
    return out

run(x, entry_sig, exit_sig)

To reiterate what I said in the comments, | and & are bitwise operators, not logical operators. 1 & 2 outputs 0/False, which is not what I think you want, so I made sure the outs can only be 0/1 so that the expression produces the expected output.

You are already aware of this, because of:

out = out | s if logic == 'or_' else out & s
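The pitfall is easy to reproduce: as long as both operands are restricted to 0/1, | and & agree with logical or/and; with other integers the bit patterns take over.

```python
# with 0/1 operands, bitwise and logical operators agree
print(1 | 0, 1 & 0)         # 1 0

# with arbitrary "truthy" integers they do not: 0b01 & 0b10 shares no bits
print(1 & 2)                # 0, i.e. falsy
print(bool(1) and bool(2))  # True
```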

Does the order of the time series inside entry_sig and exit_sig matter?

Let (output, logic) be tuples where output is 0 or 1, depending on how crossup or crossdown evaluates the information passed in the tuple, and logic is or_ or and_.

tuples = ((0,'or_'),(1,'or_'),(0,'and_'))

out = tuples[0][0]
logic = tuples[0][1]
for i in range(1,len(tuples)):
    s = tuples[i][0]
    out = out | s if logic == 'or_' else out & s
    logic = tuples[i][1]

print(out)
1

Changing the order of the tuples produces a different signal:

tuples = ((0,'or_'),(0,'and_'),(1,'or_'))

out = tuples[0][0]
logic = tuples[0][1]
for i in range(1,len(tuples)):
    s = tuples[i][0]
    out = out | s if logic == 'or_' else out & s
    logic = tuples[i][1]

print(out)
0
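This is the same left fold that reduce_sig performs: each tuple's logic says how to combine the running result with the next value, so reordering the tuples changes the implicit parenthesization. A compact sketch (the helper name fold is mine):

```python
def fold(tuples):
    # left fold: each element's logic combines the running result with the NEXT value
    out, logic = tuples[0]
    for s, next_logic in tuples[1:]:
        out = out | s if logic == 'or_' else out & s
        logic = next_logic
    return out

print(fold(((0, 'or_'), (1, 'or_'), (0, 'and_'))))  # ((0 | 1) | 0) = 1
print(fold(((0, 'or_'), (0, 'and_'), (1, 'or_'))))  # ((0 | 0) & 1) = 0
```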

Performance depends on how many times the counts need to be updated. Using n=1,000,000 for all three time series, your code has an average run time of 0.6 s on my machine and mine 0.63 s.

I then changed the cross logic slightly to save on the number of if/else branches, so that the nested if/else only triggers when the time series actually cross, which can be checked with a single comparison. That halved the run-time difference again, so the code above now runs about 2.5% longer than the original.