将递归 python 代码转换为非递归版本
Convert a recursive python code to a non-recursive version
这里给出的代码本身可以正常工作,但当 `distinct`、`n_symbols` 和 `length` 增大时就会出问题。例如,在我的计算机上,`n_symbols=512, length=512, distinct=300` 最终会报错 `RecursionError: maximum recursion depth exceeded in comparison`(比较时超出最大递归深度);如果我调大 `lru_cache` 的值,则会出现溢出错误。
我想要的是此代码的非递归版本。
from functools import lru_cache
@lru_cache
def get_permutations_count(n_symbols, length, distinct, used=0):
    '''
    Count sequences of `length` symbols drawn from an alphabet of
    `n_symbols` symbols that contain exactly `distinct` different symbols.

    Evaluated bottom-up with a single DP row instead of recursion, so the
    call depth no longer grows with `length` (no RecursionError).

    - n_symbols: number of symbols in the alphabet
    - length: the number of symbols in each sequence
    - distinct: the number of distinct symbols in each sequence
    - used: number of symbols already introduced before the sequence starts
      (state of the original recurrence; normally left at 0)

    Returns 0 when `distinct` is negative or larger than `length`.
    '''
    if distinct < 0:
        return 0
    if distinct > length:
        # Each position introduces at most one new symbol.
        return 0
    # In the recurrence f(L, d, u) = f(L-1, d, u) * u + f(L-1, d-1, u+1) * (n - u)
    # the sum d + u never changes, so a state is described by the number of
    # new symbols k still to introduce; the matching `used` is total - k.
    total = distinct + used
    # Zero remaining positions: success only if no new symbols are pending.
    row = [1] + [0] * distinct
    for _ in range(length):
        row = [
            row[k] * (total - k)  # reuse one of the (total - k) symbols in use
            + (row[k - 1] * (n_symbols - (total - k)) if k else 0)  # introduce a new one
            for k in range(distinct + 1)
        ]
    return row[distinct]
然后
# Example call with the recursive version (reported to take about 0.5 s):
get_permutations_count(n_symbols=300, length=300, distinct=270)
大约 0.5 秒即可给出结果:
2729511887951350984580070745513114266766906881300774347439917775
7093985721949669285469996223829969654724957176705978029888262889
8157939885553971500652353177628564896814078569667364402373549268
5524290993833663948683375995196081654415976659499171897405039547
1546236260377859451955180752885715923847446106509971875543496023
2494854876774756172488117802642800540206851318332940739395445903
6305051887120804168979339693187702655904071331731936748927759927
3688881301614948043182289382736687065840703041231428800720854767
0713406956719647313048146023960093662879015837313428567467555885
3564982943420444850950866922223974844727296000000000000000000000
000000000000000000000000000000000000000000000000
这是我的:
def get_permutations_count_improved(n_symbols, length, distinct):
    """Count length-`length` sequences over `n_symbols` symbols that use
    exactly `distinct` different symbols, without recursion.

    The result is the product of three factors:
    - comb(n_symbols, distinct): which symbols are actually used;
    - factorial(distinct): the order in which those symbols first appear;
    - ways[distinct]: the number of sequences with that first-use order fixed.

    Returns 0 for impossible requests: negative `distinct`, or more distinct
    symbols than positions or than available symbols.
    """
    # Imported locally so this snippet also runs on its own.
    from math import comb, factorial

    # Without the negative check, ways[distinct] / comb() would raise instead
    # of returning 0 like the recursive reference does.
    if distinct < 0 or distinct > length or distinct > n_symbols:
        return 0
    # After i iterations, ways[d] is the number of ways to fill i remaining
    # positions when d symbols still have to appear (in a fixed order) and
    # the other distinct - d symbols may be reused.  Only the newest row of
    # the conceptual 2-D table ways[length][d] is kept, to save memory.
    ways = [1]
    for _ in range(length):
        ways = [
            reuse * (distinct - d) + introduce
            for d, reuse, introduce in zip(range(distinct + 1), [*ways, 0], [0, *ways])
        ]
    return ways[distinct] * comb(n_symbols, distinct) * factorial(distinct)
一些参数集的速度比较:
| n_symbols | length | distinct | yours | mine |
|---|---|---|---|---|
| 300 | 300 | 270 | 0.62 s | 0.012 s (~51 times faster) |
| 512 | 512 | 300 | - | 0.035 s |
| 1024 | 1024 | 600 | - | 0.22 s |
| 3000 | 3000 | 2700 | - | 6.0 s |
在我的最后一行中可以看到,我把总体结果拆分成了三个因子:
- `comb(n_symbols, distinct)`:从 `n_symbols` 个符号中选出实际使用的 `distinct` 个符号。这基本上消除了 `n_symbols` 参数,或者说相当于把它设为 `n_symbols = distinct`。
- `factorial(distinct)`:这些符号首次出现的顺序。这消除了你代码中反复出现的 `* (n_symbols - used)`。
- `ways[distinct]`:构建长度为 `length`、恰好包含 `distinct` 个不同符号、且这些符号首次出现的顺序固定的序列的方法数。
把 `ways` 表想象成二维表 `ways[length][distinct]` 可能更容易理解。但为了提高内存效率,我逐行计算,只保留最新的一行。
基准测试和一些正确性检查 (Try it online!):
from timeit import timeit
from functools import lru_cache
from math import comb, factorial
@lru_cache
def get_permutations_count(n_symbols, length, distinct, used=0):
    '''
    Reference (recursive) count of sequences of `length` symbols drawn from
    an alphabet of `n_symbols` symbols that use exactly `distinct` different
    symbols.

    - n_symbols: number of symbols in the alphabet
    - length: the number of symbols in each sequence
    - distinct: the number of distinct symbols in each sequence
    - used: symbols already introduced so far (recursion state; leave at 0)

    Every call recurses with length - 1, so the recursion depth grows with
    `length`; large lengths can raise RecursionError.
    '''
    if distinct < 0:
        return 0
    if not length:
        return int(distinct == 0)
    # The next symbol either repeats one of the `used` symbols seen so far...
    reuse = get_permutations_count(n_symbols, length - 1, distinct, used)
    # ...or is one of the n_symbols - used symbols not seen yet.
    fresh = get_permutations_count(n_symbols, length - 1, distinct - 1, used + 1)
    return reuse * used + fresh * (n_symbols - used)
def get_permutations_count_improved(n_symbols, length, distinct):
    """Count length-`length` sequences over `n_symbols` symbols that use
    exactly `distinct` different symbols, without recursion.

    The result is the product of three factors:
    - comb(n_symbols, distinct): which symbols are actually used;
    - factorial(distinct): the order in which those symbols first appear;
    - ways[distinct]: the number of sequences with that first-use order fixed.

    Returns 0 for impossible requests: negative `distinct`, or more distinct
    symbols than positions or than available symbols.
    """
    # Without the negative check, ways[distinct] / comb() would raise instead
    # of returning 0 like the recursive reference does.
    if distinct < 0 or distinct > length or distinct > n_symbols:
        return 0
    # After i iterations, ways[d] is the number of ways to fill i remaining
    # positions when d symbols still have to appear (in a fixed order) and
    # the other distinct - d symbols may be reused.  Only the newest row of
    # the conceptual 2-D table ways[length][d] is kept, to save memory.
    ways = [1]
    for _ in range(length):
        ways = [
            reuse * (distinct - d) + introduce
            for d, reuse, introduce in zip(range(distinct + 1), [*ways, 0], [0, *ways])
        ]
    return ways[distinct] * comb(n_symbols, distinct) * factorial(distinct)
funcs = get_permutations_count, get_permutations_count_improved

# Correctness: the non-recursive version must agree with the recursive
# reference for every (n_symbols, length, distinct) in [0, stop)^3.
stop = 20
for n in range(stop):
    for size in range(stop):
        for k in range(stop):
            reference = get_permutations_count(n, size, k)
            candidate = get_permutations_count_improved(n, size, k)
            assert candidate == reference, (n, size, k, reference, candidate)

# Benchmark: time both functions three times each, alternating.  The
# recursive reference's cache is cleared before every run so each timing
# starts cold.
n_symbols, length, distinct = 300, 300, 270
# Larger sets below: the recursive reference is expected to hit RecursionError.
#n_symbols, length, distinct = 512, 512, 300
#n_symbols, length, distinct = 1024, 1024, 600
#n_symbols, length, distinct = 3000, 3000, 2700
for func in funcs * 3:
    funcs[0].cache_clear()
    elapsed = timeit(lambda: func(n_symbols, length, distinct), number=1)
    print(f'{elapsed:.3f} seconds ', func.__name__)
这里给出的代码本身可以正常工作,但当 `distinct`、`n_symbols` 和 `length` 增大时就会出问题。例如,在我的计算机上,`n_symbols=512, length=512, distinct=300` 最终会报错 `RecursionError: maximum recursion depth exceeded in comparison`(比较时超出最大递归深度);如果我调大 `lru_cache` 的值,则会出现溢出错误。
我想要的是此代码的非递归版本。
from functools import lru_cache
@lru_cache
def get_permutations_count(n_symbols, length, distinct, used=0):
    '''
    Count sequences of `length` symbols drawn from an alphabet of
    `n_symbols` symbols that contain exactly `distinct` different symbols.

    Evaluated bottom-up with a single DP row instead of recursion, so the
    call depth no longer grows with `length` (no RecursionError).

    - n_symbols: number of symbols in the alphabet
    - length: the number of symbols in each sequence
    - distinct: the number of distinct symbols in each sequence
    - used: number of symbols already introduced before the sequence starts
      (state of the original recurrence; normally left at 0)

    Returns 0 when `distinct` is negative or larger than `length`.
    '''
    if distinct < 0:
        return 0
    if distinct > length:
        # Each position introduces at most one new symbol.
        return 0
    # In the recurrence f(L, d, u) = f(L-1, d, u) * u + f(L-1, d-1, u+1) * (n - u)
    # the sum d + u never changes, so a state is described by the number of
    # new symbols k still to introduce; the matching `used` is total - k.
    total = distinct + used
    # Zero remaining positions: success only if no new symbols are pending.
    row = [1] + [0] * distinct
    for _ in range(length):
        row = [
            row[k] * (total - k)  # reuse one of the (total - k) symbols in use
            + (row[k - 1] * (n_symbols - (total - k)) if k else 0)  # introduce a new one
            for k in range(distinct + 1)
        ]
    return row[distinct]
然后
# Example call with the recursive version (reported to take about 0.5 s):
get_permutations_count(n_symbols=300, length=300, distinct=270)
大约 0.5 秒即可给出结果:
2729511887951350984580070745513114266766906881300774347439917775
7093985721949669285469996223829969654724957176705978029888262889
8157939885553971500652353177628564896814078569667364402373549268
5524290993833663948683375995196081654415976659499171897405039547
1546236260377859451955180752885715923847446106509971875543496023
2494854876774756172488117802642800540206851318332940739395445903
6305051887120804168979339693187702655904071331731936748927759927
3688881301614948043182289382736687065840703041231428800720854767
0713406956719647313048146023960093662879015837313428567467555885
3564982943420444850950866922223974844727296000000000000000000000
000000000000000000000000000000000000000000000000
这是我的:
def get_permutations_count_improved(n_symbols, length, distinct):
    """Count length-`length` sequences over `n_symbols` symbols that use
    exactly `distinct` different symbols, without recursion.

    The result is the product of three factors:
    - comb(n_symbols, distinct): which symbols are actually used;
    - factorial(distinct): the order in which those symbols first appear;
    - ways[distinct]: the number of sequences with that first-use order fixed.

    Returns 0 for impossible requests: negative `distinct`, or more distinct
    symbols than positions or than available symbols.
    """
    # Imported locally so this snippet also runs on its own.
    from math import comb, factorial

    # Without the negative check, ways[distinct] / comb() would raise instead
    # of returning 0 like the recursive reference does.
    if distinct < 0 or distinct > length or distinct > n_symbols:
        return 0
    # After i iterations, ways[d] is the number of ways to fill i remaining
    # positions when d symbols still have to appear (in a fixed order) and
    # the other distinct - d symbols may be reused.  Only the newest row of
    # the conceptual 2-D table ways[length][d] is kept, to save memory.
    ways = [1]
    for _ in range(length):
        ways = [
            reuse * (distinct - d) + introduce
            for d, reuse, introduce in zip(range(distinct + 1), [*ways, 0], [0, *ways])
        ]
    return ways[distinct] * comb(n_symbols, distinct) * factorial(distinct)
一些参数集的速度比较:
| n_symbols | length | distinct | yours | mine |
|---|---|---|---|---|
| 300 | 300 | 270 | 0.62 s | 0.012 s (~51 times faster) |
| 512 | 512 | 300 | - | 0.035 s |
| 1024 | 1024 | 600 | - | 0.22 s |
| 3000 | 3000 | 2700 | - | 6.0 s |
在我的最后一行中可以看到,我把总体结果拆分成了三个因子:
- `comb(n_symbols, distinct)`:从 `n_symbols` 个符号中选出实际使用的 `distinct` 个符号。这基本上消除了 `n_symbols` 参数,或者说相当于把它设为 `n_symbols = distinct`。
- `factorial(distinct)`:这些符号首次出现的顺序。这消除了你代码中反复出现的 `* (n_symbols - used)`。
- `ways[distinct]`:构建长度为 `length`、恰好包含 `distinct` 个不同符号、且这些符号首次出现的顺序固定的序列的方法数。
把 `ways` 表想象成二维表 `ways[length][distinct]` 可能更容易理解。但为了提高内存效率,我逐行计算,只保留最新的一行。
基准测试和一些正确性检查 (Try it online!):
from timeit import timeit
from functools import lru_cache
from math import comb, factorial
@lru_cache
def get_permutations_count(n_symbols, length, distinct, used=0):
    '''
    Reference (recursive) count of sequences of `length` symbols drawn from
    an alphabet of `n_symbols` symbols that use exactly `distinct` different
    symbols.

    - n_symbols: number of symbols in the alphabet
    - length: the number of symbols in each sequence
    - distinct: the number of distinct symbols in each sequence
    - used: symbols already introduced so far (recursion state; leave at 0)

    Every call recurses with length - 1, so the recursion depth grows with
    `length`; large lengths can raise RecursionError.
    '''
    if distinct < 0:
        return 0
    if not length:
        return int(distinct == 0)
    # The next symbol either repeats one of the `used` symbols seen so far...
    reuse = get_permutations_count(n_symbols, length - 1, distinct, used)
    # ...or is one of the n_symbols - used symbols not seen yet.
    fresh = get_permutations_count(n_symbols, length - 1, distinct - 1, used + 1)
    return reuse * used + fresh * (n_symbols - used)
def get_permutations_count_improved(n_symbols, length, distinct):
    """Count length-`length` sequences over `n_symbols` symbols that use
    exactly `distinct` different symbols, without recursion.

    The result is the product of three factors:
    - comb(n_symbols, distinct): which symbols are actually used;
    - factorial(distinct): the order in which those symbols first appear;
    - ways[distinct]: the number of sequences with that first-use order fixed.

    Returns 0 for impossible requests: negative `distinct`, or more distinct
    symbols than positions or than available symbols.
    """
    # Without the negative check, ways[distinct] / comb() would raise instead
    # of returning 0 like the recursive reference does.
    if distinct < 0 or distinct > length or distinct > n_symbols:
        return 0
    # After i iterations, ways[d] is the number of ways to fill i remaining
    # positions when d symbols still have to appear (in a fixed order) and
    # the other distinct - d symbols may be reused.  Only the newest row of
    # the conceptual 2-D table ways[length][d] is kept, to save memory.
    ways = [1]
    for _ in range(length):
        ways = [
            reuse * (distinct - d) + introduce
            for d, reuse, introduce in zip(range(distinct + 1), [*ways, 0], [0, *ways])
        ]
    return ways[distinct] * comb(n_symbols, distinct) * factorial(distinct)
funcs = get_permutations_count, get_permutations_count_improved

# Correctness: the non-recursive version must agree with the recursive
# reference for every (n_symbols, length, distinct) in [0, stop)^3.
stop = 20
for n in range(stop):
    for size in range(stop):
        for k in range(stop):
            reference = get_permutations_count(n, size, k)
            candidate = get_permutations_count_improved(n, size, k)
            assert candidate == reference, (n, size, k, reference, candidate)

# Benchmark: time both functions three times each, alternating.  The
# recursive reference's cache is cleared before every run so each timing
# starts cold.
n_symbols, length, distinct = 300, 300, 270
# Larger sets below: the recursive reference is expected to hit RecursionError.
#n_symbols, length, distinct = 512, 512, 300
#n_symbols, length, distinct = 1024, 1024, 600
#n_symbols, length, distinct = 3000, 3000, 2700
for func in funcs * 3:
    funcs[0].cache_clear()
    elapsed = timeit(lambda: func(n_symbols, length, distinct), number=1)
    print(f'{elapsed:.3f} seconds ', func.__name__)