根据提供的参数优化 python 代码以更快地生成排列和索引

Optimization of a python code according to provided parameters to generate permutation and index faster

如何优化 以使用给定的 256 字节为例?
示例字节是:

bytez = [197 215 20 156 94 67 20 100 27 208 186 248 71 48 128 75 7 165 148 223 94 163 233 15 161 104 246 66 242 142 118 165 204 0 252 22 233 28 136 197 113 122 72 229 11 91 133 142 20 204 119 211 170 104 63 39 46 68 150 123 148 95 96 95 17 133 243 35 45 66 76 19 41 200 141 120 110 215 140 230 252 182 42 166 59 249 171 97 124 8 138 59 112 191 87 170 218 31 51 74 112 23 37 13 63 96 61 200 110 189 59 18 11 99 94 63 245 107 31 11 217 51 133 35 113 36 154 179 223 92 31 239 20 51 200 102 133 183 240 88 104 29 81 122 28 246 161 90 89 6 241 241 19 40 43 248 78 6 234 40 171 23 143 70 122 246 180 148 183 67 158 198 212 41 0 98 171 81 122 114 229 193 213 212 65 72 120 191 228 32 132 172 88 100 104 119 253 166 159 242 246 6 66 190 31 57 175 105 161 1 109 8 1 50 97 60 101 25 131 93 51 243 203 41 11 140 231 59 131 68 177 58 80 142 9 21 20 106 132 161 187 21 253 234 222 190 91 106 192 149 4 70 77 139 170 172]  
distinct = 156  

我希望 item_atitem_index 这两种方法都能处理 256 字节的列表,而不是字符串或整数。

更多详情:
对于 item_at,将提供 256 字节的列表l256b 作为 index,对于 distinct(存在不同字节数在提供的 l256b) 中,方法的用户将提供一个值 0<=x<256
不需要参数 alphabetlength,因为它们始终是常量,并且 alphabet 的字节 0<=x<256length 的字节 256.
item_at 必须 return 一个 256 字节的列表,这是提供的索引的排列。

对于 item_index,将提供 256 字节的列表l256b 作为 item(排列),对于 distinct(所提供的 l256b 中存在不同的字节数)该方法的用户将提供 0<=d<256 之间的值。
不需要参数 alphabetlength,因为它们始终是常量,并且 alphabet 的字节 0<=x<256length 的字节 256.
item_index 必须 return 一个 256 字节的列表,这是提供的排列的索引。

1 - 处理列表

有两个细节,一个是该解决方案设计用于处理字符串,但可以轻松修改以处理列表

2 - 仅跟踪前缀的计数

该代码正在使用构建整个前缀,从而在每次调用时扩展它,这会导致大量复制。在函数 item_index 中,此前缀仅用于了解是否使用了给定符号。相反,可以做的是有一个字典,其中包含每个符号在前缀中的数字。然后不检查 d in prefix 你使用 prefixCount[d] != 0.

3 - 调整缓存大小

你可以看到解决方案使用了LRU cache,这种类型的缓存默认只会记住最近的128个元素。可以用lru_cache(maxsize=None)或者干脆cache()修饰函数,如果知道输入的最大长度是256,用lru_cache(maxsize=256**2)就够了。

@lru_cache(maxsize=256**2)
def count_seq(n_symbols, length, distinct, used=0):
    if distinct < 0:
        return 0
    if length == 0:
        return 1 if distinct == 0 else 0
    else:
        return \
          count_seq(n_symbols, length-1, distinct-0, used+0) * used + \
          count_seq(n_symbols, length-1, distinct-1, used+1) * (n_symbols - used)
def item_at(idx, alphabet, length, distinct, used=0, prefix=None):
    if prefix is None:
        prefix = [];
    if distinct < 0:
        return
    if length == 0:
        return prefix
    else:
        for d in alphabet:
            if d in prefix:
                branch_count = count_seq(len(alphabet), 
                                         length-1, distinct, used)
                if branch_count <= idx:
                    idx -= branch_count
                else:
                    prefix.append(d);
                    return item_at(idx, alphabet, 
                                   length-1, distinct, used, prefix)
            else:
                branch_count = count_seq(len(alphabet),
                                         length-1, distinct-1, used+1)
                if branch_count <= idx:
                    idx -= branch_count
                else:
                    prefix.append(d);
                    return item_at(idx, alphabet,
                                   length-1, distinct-1, used+1, prefix)

def item_index(item, alphabet, length, distinct, used=0, prefixCount=None, idx=0):
    if prefixCount is None:
        prefixCount = {a: 0 for a in alphabet}
    if distinct < 0:
        return 0
    if length == 0:
        return 0
    else:
        offset = 0
        for d in alphabet:
            if prefixCount[d] != 0:
                if d == item[idx]:
                    prefixCount[d] += 1
                    return offset + item_index(item, alphabet, 
                               length-1, distinct, used, prefixCount, idx+1)
                else:
                    offset += count_seq(len(alphabet), 
                                length-1, distinct, used)
            else:
                if d == item[idx]:
                    prefixCount[d] += 1;
                    return offset + item_index(item, alphabet, 
                             length-1, distinct-1, used+1, prefixCount, idx+1)
                else:
                    offset += count_seq(len(alphabet), 
                                 length-1, distinct-1, used+1)

然后它将在现代计算机中 运行 几毫秒后

迭代实现

我正在写一个 class,您将实例化它,给出一个字母表和您想要的不同符号的数量,在这种情况下,distinct + used 在所有重复中都是不变的。 count_seq 的结果在构造时预先计算在矩阵 C 中。 item_atitem_index 方法是基于 C.

计算结果的迭代实现

在我看来,这变得不那么可读了,因为在递归实现中,一切都是根据具有明确概念关联的函数调用来表达的。

class SequenceLookup:
    def __init__(self, alphabet, length, distinct):
        self.alphabet = list(alphabet)
        self.distinct = distinct
        n_symbols = len(alphabet)
        c = [0] * distinct + [1, 0]
        C = [c]
        for l in range(2,length+1):
            c = [
                c[d] * d + c[d+1] * (n_symbols - d)
                for d in range(distinct+1)
            ] + [0]
            C.append(c)
        self.C = C
    
    def item_index(self, item):
        length = len(item)
        offset = 0
        seen = set()
        for i,di in enumerate(item):
            for d in self.alphabet:
                if d == di:
                    break;
                if d in seen:
                    offset += self.C[length-1-i][len(seen)]
                else:
                    offset += self.C[length-1-i][len(seen)+1]
            seen.add(di)
        return offset
    def item_at(self, idx, length):
        seen = set()
        prefix = []
        for i in range(length):
            for d in self.alphabet:
                if d in prefix:
                    branch_count = self.C[length-1-i][len(seen)]
                else:
                    branch_count = self.C[length-1-i][len(seen)+1]
                if branch_count <= idx:
                    idx -= branch_count
                else:
                    prefix.append(d)
                    seen.add(d)
                    break
        return prefix
bytez=[197, 215, 20, 156, 94, 67, 20, 100, 27, 208, 186, 248, 
       71, 48, 128, 75, 7, 165, 148, 223, 94, 163, 233, 15,
       161, 104, 246, 66, 242, 142, 118, 165, 204, 0, 252,
       22, 233, 28, 136, 197, 113, 122, 72, 229, 11, 91, 133,
       142, 20, 204, 119, 211, 170, 104, 63, 39, 46, 68, 150,
       123, 148, 95, 96, 95, 17, 133, 243, 35, 45, 66, 76, 19,
       41, 200, 141, 120, 110, 215, 140, 230, 252, 182, 42, 
       166, 59, 249, 171, 97, 124, 8, 138, 59, 112, 191, 87, 
       170, 218, 31, 51, 74, 112, 23, 37, 13, 63, 96, 61, 200, 
       110, 189, 59, 18, 11, 99, 94, 63, 245, 107, 31, 11, 
       217, 51, 133, 35, 113, 36, 154, 179, 223, 92, 31, 239, 
       20, 51, 200, 102, 133, 183, 240, 88, 104, 29, 81, 122,
       28, 246, 161, 90, 89, 6, 241, 241, 19, 40, 43, 248, 78,
       6, 234, 40, 171, 23, 143, 70, 122, 246, 180, 148, 183,
       67, 158, 198, 212, 41, 0, 98, 171, 81, 122, 114, 229,
       193, 213, 212, 65, 72, 120, 191, 228, 32, 132, 172, 88,
       100, 104, 119, 253, 166, 159, 242, 246, 6, 66, 190, 31,
       57, 175, 105, 161, 1, 109, 8, 1, 50, 97, 60, 101, 25,
       131, 93, 51, 243, 203, 41, 11, 140, 231, 59, 131, 68,
       177, 58, 80, 142, 9, 21, 20, 106, 132, 161, 187, 21, 253, 
       234, 222, 190, 91, 106, 192, 149, 4, 70, 77, 139, 170, 172]
v = SequenceLookup(range(256), len(bytez), len(set(bytez)))
%%timeit
v = SequenceLookup(range(256), len(bytez), len(set(bytez)))

11.4 ms ± 229 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

#%%timeit
v.item_index(bytez)

7.57 ms ± 132 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

#%%timeit
v.item_at(t, 256)

33.6 ms ± 598 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

专用于字节

使用固定字母表的实现 [0,..255]

class SequenceLookup:
    def __init__(self, length, distinct):
        self.distinct = distinct
        c = [0] * distinct + [1, 0]
        C = [c]
        for l in range(2,length+1):
            c = [
                c[d] * d + c[d+1] * (256 - d)
                for d in range(distinct+1)
            ] + [0]
            C.append(c)
        self.C = C
    
    def item_index(self, item):
        length = len(item)
        offset = 0
        seen = set()
        for i,di in enumerate(item):
            for d in range(256):
                if d == di:
                    break;
                if d in seen:
                    offset += self.C[length-1-i][len(seen)]
                else:
                    offset += self.C[length-1-i][len(seen)+1]
            seen.add(di)
        return offset

    def item_at(self, idx, length):
        seen = [0] * 256
        prefix = [0] * length
        used = 0
        for i in range(length):
            for d in range(256):
                if seen[d] != 0:
                    branch_count = self.C[length-1-i][used]
                else:
                    branch_count = self.C[length-1-i][used+1]
                if branch_count <= idx:
                    idx -= branch_count
                else:
                    prefix[i] = d;
                    if seen[d] == 0:
                        used += 1;
                    seen[d] = 1
                    break
        return prefix

使用此实现构造和 item_index 花费的时间基本相同,但 item_at 运行 在我的测试中更快

6.32 ms ± 91.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

这当然会有所不同,因此您可能想自己尝试使用不同数据结构的相同算法。