分布的迭代排列
Iterative Permutations of Distribution
我正在尝试生成各种分布的所有可能组合。
例如,假设您有 5 个积分可用于 4 个类别,但您最多只能在任何给定类别上花费 2 个积分。
在这种情况下,所有可能的解决方案如下:
[0, 1, 2, 2]
[0, 2, 1, 2]
[0, 2, 2, 1]
[1, 0, 2, 2]
[1, 1, 1, 2]
[1, 1, 2, 1]
[1, 2, 0, 2]
[1, 2, 1, 1]
[1, 2, 2, 0]
[2, 0, 1, 2]
[2, 0, 2, 1]
[2, 1, 0, 2]
[2, 1, 1, 1]
[2, 1, 2, 0]
[2, 2, 0, 1]
[2, 2, 1, 0]
我已经成功地创建了一个递归函数来实现这一点,但是对于更多的类别,生成它需要非常长的时间。我曾尝试制作一个迭代函数,希望能加快速度,但我似乎无法让它解释类别最大值。
这是我的递归函数(count = points,dist = zero-filled array w/ same size as max_allo)
def distribute_recursive(count, max_allo, dist, depth=0):
for ration in range(max(count - sum(max_allo[depth + 1:]), 0), min(count, max_allo[depth]) + 1):
dist[depth] = ration
count -= ration
if depth + 1 < len(dist):
distribute_recursive(count, max_allo, dist, depth + 1)
else:
print(dist)
count += ration
您可以使用生成器函数进行递归,同时应用额外的逻辑来减少所需的递归调用次数:
def listings(_cat, points, _max, current = []):
if len(current) == _cat:
yield current
else:
for i in range(_max+1):
if sum(current+[i]) <= points:
if sum(current+[i]) == points or len(current+[i]) < _cat:
yield from listings(_cat, points, _max, current+[i])
print(list(listings(4, 5, 2)))
输出:
[[0, 1, 2, 2], [0, 2, 1, 2], [0, 2, 2, 1], [1, 0, 2, 2], [1, 1, 1, 2], [1, 1, 2, 1], [1, 2, 0, 2], [1, 2, 1, 1], [1, 2, 2, 0], [2, 0, 1, 2], [2, 0, 2, 1], [2, 1, 0, 2], [2, 1, 1, 1], [2, 1, 2, 0], [2, 2, 0, 1], [2, 2, 1, 0]]
虽然不清楚您的解决方案在多大的类别大小下会显着变慢,但对于最大 24
的类别大小,此解决方案将 运行 一秒以内,搜索总共五个点最大插槽值为 2。请注意,对于较大的点和槽值,每秒计算的可能类别大小的数量会增加:
import time
def timeit(f):
def wrapper(*args):
c = time.time()
_ = f(*args)
return time.time() - c
return wrapper
@timeit
def wrap_calls(category_size:int) -> float:
_ = list(listings(category_size, 5, 2))
benchmark = 0
category_size = 1
while benchmark < 1:
benchmark = wrap_calls(category_size)
category_size += 1
print(category_size)
输出:
24
递归并不慢
递归不是让它变慢的原因;考虑一个更好的算法
def dist (count, limit, points, acc = []):
if count is 0:
if sum (acc) is points:
yield acc
else:
for x in range (limit + 1):
yield from dist (count - 1, limit, points, acc + [x])
您可以将生成的结果收集在一个列表中
print (list (dist (count = 4, limit = 2, points = 5)))
p运行无效组合
上面,我们使用了 limit + 1
的固定范围,但如果我们生成具有(例如)limit = 2
和 points = 5
的组合,请观察会发生什么......
[ 2, ... ] # 3 points remaining
[ 2, 2, ... ] # 1 point remaining
在这一点上,使用 limit + 1
([ 0, 1, 2 ]
) 的固定范围是愚蠢的,因为我们知道我们只剩下 1 点可以花费了。这里唯一剩下的选项是 0
或 1
...
[ 2, 2, 1 ... ] # 0 points remaining
上面我们知道我们可以使用空范围 [ 0 ]
因为没有剩余的点数可以花费。这将阻止我们尝试验证像
这样的组合
[ 2, 2, 2, ... ] # -1 points remaining
[ 2, 2, 2, 0, ... ] # -1 points remaining
[ 2, 2, 2, 1, ... ] # -2 points remaining
[ 2, 2, 2, 2, ... ] # -3 points remaining
如果 count
非常大,这可以排除 大量 无效组合
[ 2, 2, 2, 2, 2, 2, 2, 2, 2, ... ] # -15 points remaining
要实现此优化,我们可以向 dist
函数添加另一个参数,但在 5 个参数时,它会开始看起来很乱。相反,我们引入了一个辅助函数来控制 loop
。添加我们的优化,我们将固定范围换成 min (limit, remaining) + 1
的 动态 范围。最后,因为我们知道分配了多少点,所以我们不再需要测试每个组合的 sum
;从我们的算法中删除了另一个昂贵的操作
# revision: prune invalid combinations
def dist (count, limit, points):
def loop (count, remaining, acc):
if count is 0:
if remaining is 0:
yield acc
else:
for x in range (min (limit, remaining) + 1):
yield from loop (count - 1, remaining - x, acc + [x])
yield from loop (count, points, [])
基准测试
在下面的基准测试中,我们程序的第一个版本被重命名为 dist1
,而更快的程序使用动态范围 dist2
。我们设置了三个测试,small
、medium
和 large
def small (prg):
return list (prg (count = 4, limit = 2, points = 5))
def medium (prg):
return list (prg (count = 8, limit = 3, points = 7))
def large (prg):
return list (prg (count = 16, limit = 5, points = 10))
现在我们 运行 测试,将每个程序作为参数传递。请注意 large
测试,仅完成 1 次通过,因为 dist1
需要一段时间才能生成结果
print (timeit ('small (dist1)', number = 10000, globals = globals ()))
print (timeit ('small (dist2)', number = 10000, globals = globals ()))
print (timeit ('medium (dist1)', number = 100, globals = globals ()))
print (timeit ('medium (dist2)', number = 100, globals = globals ()))
print (timeit ('large (dist1)', number = 1, globals = globals ()))
print (timeit ('large (dist2)', number = 1, globals = globals ()))
small
测试的结果表明,p运行ing 无效组合并没有太大的区别。然而,在 medium
和 large
的情况下,差异是巨大的。我们的旧程序处理大集合需要 30 多分钟,但使用新程序只需 1 秒多一点!
<i>dist1 小</i> 0.8512216459494084
<i>dist2 小</i> 0.8610155049245805 <i>(0.98 倍加速)</i>
<i>dist1 中</i> 6.142372329952195
<i>dist2 中</i> 0.9355670949444175 <i>(6.57 倍加速)</i>
<i>dist1 大</i>1933.0877765258774
<i>dist2 大</i> 1.4107366011012346 <i>(1370.26 倍加速)</i>
参考框架,每个结果的大小打印在下面
print (len (small (dist2))) # 16 (this is the example in your question)
print (len (medium (dist2))) # 2472
print (len (large (dist2))) # 336336
检查我们的理解
在使用 count = 12
和 limit = 5
的 large
基准测试中,我们使用未优化的程序迭代 512,或 244,140,625 可能种组合。使用我们优化的程序,我们跳过所有无效组合,从而产生 336,336 个有效答案。仅通过分析组合计数,我们发现 可能的 组合中有 99.86% 是惊人的无效组合。如果每个组合的分析花费相同的时间,我们可以预期我们的优化程序的性能至少提高 725.88 倍,因为无效组合 p运行ing.
在 large
基准测试中,优化后的程序速度提高了 1370.26 倍,达到甚至超出了我们的预期。额外的加速可能是由于我们消除了对 sum
的调用
呵呵
要证明此技术适用于极大的数据集,请考虑 huge
基准。我们的程序在 716 或 33,232,930,569,601 种可能性中找到 17,321,844 种有效组合。
在这个测试中,我们的优化程序 p运行es 了 99.99479% 的无效组合。将这些数字与之前的数据集相关联,我们估计优化后的程序 运行 比未优化的版本快 1,918,556.16 倍。
该基准测试使用未优化程序的理论运行宁时间是117.60 年。优化后的程序仅需 1 分钟多一点就可以找到答案。
def huge (prg):
return list (prg (count = 16, limit = 7, points = 12))
print (timeit ('huge (dist2)', number = 1, globals = globals ()))
# 68.06868170504458
print (len (huge (dist2)))
# 17321844
我正在尝试生成各种分布的所有可能组合。
例如,假设您有 5 个积分可用于 4 个类别,但您最多只能在任何给定类别上花费 2 个积分。 在这种情况下,所有可能的解决方案如下:
[0, 1, 2, 2]
[0, 2, 1, 2]
[0, 2, 2, 1]
[1, 0, 2, 2]
[1, 1, 1, 2]
[1, 1, 2, 1]
[1, 2, 0, 2]
[1, 2, 1, 1]
[1, 2, 2, 0]
[2, 0, 1, 2]
[2, 0, 2, 1]
[2, 1, 0, 2]
[2, 1, 1, 1]
[2, 1, 2, 0]
[2, 2, 0, 1]
[2, 2, 1, 0]
我已经成功地创建了一个递归函数来实现这一点,但是对于更多的类别,生成它需要非常长的时间。我曾尝试制作一个迭代函数,希望能加快速度,但我似乎无法让它解释类别最大值。
这是我的递归函数(count = points,dist = zero-filled array w/ same size as max_allo)
def distribute_recursive(count, max_allo, dist, depth=0):
for ration in range(max(count - sum(max_allo[depth + 1:]), 0), min(count, max_allo[depth]) + 1):
dist[depth] = ration
count -= ration
if depth + 1 < len(dist):
distribute_recursive(count, max_allo, dist, depth + 1)
else:
print(dist)
count += ration
您可以使用生成器函数进行递归,同时应用额外的逻辑来减少所需的递归调用次数:
def listings(_cat, points, _max, current = []):
if len(current) == _cat:
yield current
else:
for i in range(_max+1):
if sum(current+[i]) <= points:
if sum(current+[i]) == points or len(current+[i]) < _cat:
yield from listings(_cat, points, _max, current+[i])
print(list(listings(4, 5, 2)))
输出:
[[0, 1, 2, 2], [0, 2, 1, 2], [0, 2, 2, 1], [1, 0, 2, 2], [1, 1, 1, 2], [1, 1, 2, 1], [1, 2, 0, 2], [1, 2, 1, 1], [1, 2, 2, 0], [2, 0, 1, 2], [2, 0, 2, 1], [2, 1, 0, 2], [2, 1, 1, 1], [2, 1, 2, 0], [2, 2, 0, 1], [2, 2, 1, 0]]
虽然不清楚您的解决方案在多大的类别大小下会显着变慢,但对于最大 24
的类别大小,此解决方案将 运行 一秒以内,搜索总共五个点最大插槽值为 2。请注意,对于较大的点和槽值,每秒计算的可能类别大小的数量会增加:
import time
def timeit(f):
def wrapper(*args):
c = time.time()
_ = f(*args)
return time.time() - c
return wrapper
@timeit
def wrap_calls(category_size:int) -> float:
_ = list(listings(category_size, 5, 2))
benchmark = 0
category_size = 1
while benchmark < 1:
benchmark = wrap_calls(category_size)
category_size += 1
print(category_size)
输出:
24
递归并不慢
递归不是让它变慢的原因;考虑一个更好的算法
def dist (count, limit, points, acc = []):
if count is 0:
if sum (acc) is points:
yield acc
else:
for x in range (limit + 1):
yield from dist (count - 1, limit, points, acc + [x])
您可以将生成的结果收集在一个列表中
print (list (dist (count = 4, limit = 2, points = 5)))
p运行无效组合
上面,我们使用了 limit + 1
的固定范围,但如果我们生成具有(例如)limit = 2
和 points = 5
的组合,请观察会发生什么......
[ 2, ... ] # 3 points remaining
[ 2, 2, ... ] # 1 point remaining
在这一点上,使用 limit + 1
([ 0, 1, 2 ]
) 的固定范围是愚蠢的,因为我们知道我们只剩下 1 点可以花费了。这里唯一剩下的选项是 0
或 1
...
[ 2, 2, 1 ... ] # 0 points remaining
上面我们知道我们可以使用空范围 [ 0 ]
因为没有剩余的点数可以花费。这将阻止我们尝试验证像
[ 2, 2, 2, ... ] # -1 points remaining
[ 2, 2, 2, 0, ... ] # -1 points remaining
[ 2, 2, 2, 1, ... ] # -2 points remaining
[ 2, 2, 2, 2, ... ] # -3 points remaining
如果 count
非常大,这可以排除 大量 无效组合
[ 2, 2, 2, 2, 2, 2, 2, 2, 2, ... ] # -15 points remaining
要实现此优化,我们可以向 dist
函数添加另一个参数,但在 5 个参数时,它会开始看起来很乱。相反,我们引入了一个辅助函数来控制 loop
。添加我们的优化,我们将固定范围换成 min (limit, remaining) + 1
的 动态 范围。最后,因为我们知道分配了多少点,所以我们不再需要测试每个组合的 sum
;从我们的算法中删除了另一个昂贵的操作
# revision: prune invalid combinations
def dist (count, limit, points):
def loop (count, remaining, acc):
if count is 0:
if remaining is 0:
yield acc
else:
for x in range (min (limit, remaining) + 1):
yield from loop (count - 1, remaining - x, acc + [x])
yield from loop (count, points, [])
基准测试
在下面的基准测试中,我们程序的第一个版本被重命名为 dist1
,而更快的程序使用动态范围 dist2
。我们设置了三个测试,small
、medium
和 large
def small (prg):
return list (prg (count = 4, limit = 2, points = 5))
def medium (prg):
return list (prg (count = 8, limit = 3, points = 7))
def large (prg):
return list (prg (count = 16, limit = 5, points = 10))
现在我们 运行 测试,将每个程序作为参数传递。请注意 large
测试,仅完成 1 次通过,因为 dist1
需要一段时间才能生成结果
print (timeit ('small (dist1)', number = 10000, globals = globals ()))
print (timeit ('small (dist2)', number = 10000, globals = globals ()))
print (timeit ('medium (dist1)', number = 100, globals = globals ()))
print (timeit ('medium (dist2)', number = 100, globals = globals ()))
print (timeit ('large (dist1)', number = 1, globals = globals ()))
print (timeit ('large (dist2)', number = 1, globals = globals ()))
small
测试的结果表明,p运行ing 无效组合并没有太大的区别。然而,在 medium
和 large
的情况下,差异是巨大的。我们的旧程序处理大集合需要 30 多分钟,但使用新程序只需 1 秒多一点!
<i>dist1 小</i> 0.8512216459494084
<i>dist2 小</i> 0.8610155049245805 <i>(0.98 倍加速)</i>
<i>dist1 中</i> 6.142372329952195
<i>dist2 中</i> 0.9355670949444175 <i>(6.57 倍加速)</i>
<i>dist1 大</i>1933.0877765258774
<i>dist2 大</i> 1.4107366011012346 <i>(1370.26 倍加速)</i>
参考框架,每个结果的大小打印在下面
print (len (small (dist2))) # 16 (this is the example in your question)
print (len (medium (dist2))) # 2472
print (len (large (dist2))) # 336336
检查我们的理解
在使用 count = 12
和 limit = 5
的 large
基准测试中,我们使用未优化的程序迭代 512,或 244,140,625 可能种组合。使用我们优化的程序,我们跳过所有无效组合,从而产生 336,336 个有效答案。仅通过分析组合计数,我们发现 可能的 组合中有 99.86% 是惊人的无效组合。如果每个组合的分析花费相同的时间,我们可以预期我们的优化程序的性能至少提高 725.88 倍,因为无效组合 p运行ing.
在 large
基准测试中,优化后的程序速度提高了 1370.26 倍,达到甚至超出了我们的预期。额外的加速可能是由于我们消除了对 sum
呵呵
要证明此技术适用于极大的数据集,请考虑 huge
基准。我们的程序在 716 或 33,232,930,569,601 种可能性中找到 17,321,844 种有效组合。
在这个测试中,我们的优化程序 p运行es 了 99.99479% 的无效组合。将这些数字与之前的数据集相关联,我们估计优化后的程序 运行 比未优化的版本快 1,918,556.16 倍。
该基准测试使用未优化程序的理论运行宁时间是117.60 年。优化后的程序仅需 1 分钟多一点就可以找到答案。
def huge (prg):
return list (prg (count = 16, limit = 7, points = 12))
print (timeit ('huge (dist2)', number = 1, globals = globals ()))
# 68.06868170504458
print (len (huge (dist2)))
# 17321844