如何避免在排序的嵌套循环生成器中进行不必要的计算?
how to avoid unnecessary computations in a sorted nested loop generator?
我有一个生成器,它在参数网格上实现函数 f()
的评估:
def gen():
for a in vec_a:
for b in vec_b:
for c in vec_c:
for ...
res = f(a, b, c, ...)
yield (a, b, c, ...), res
vec_* 被排序使得 f()
是给定所有其他固定的相应参数的递增函数。更准确地说:
if (a2 >= a1) and (b2 >= b1) and (c2 >= c1) and ...:
assert f(a2, b2, c2, ...) >= f(a1, b1, c1, ...)
因此,例如,如果 f(a0, b0, c0, ...)
== np.inf,则:
f(a, b0, c0, ...) == np.inf
对于每个 a >= a0
f(a0, b, c0, ...) == np.inf
对于每个 b >= b0
f(a0, b0, c, ...) == np.inf
对于每个 c >= c0
现在我想编写一个通用生成器,它接受 gen
并根据以下规则跳过不必要的 f
计算:
- 如果
f(.) == np.inf
在某个时候我会打破最内层的循环
- 如果最内层循环在第一次迭代时被打断,我应该打破倒数第二层循环
- 规则 #3 适用于嵌套循环的所有其他级别
示例:如果我在网格的第一次迭代中得到 np.inf
,我应该跳过整个网格并且不再对 f 执行任何计算。
示例:如果我有一个网格 [0,1,2] x [0,1,2] x [0,1,2]
和 f(0,1,0) == np.inf
,那么我跳转到评估 f(1, 0, 0)
.
我将如何实现这样的生成器?
使用递归来简化生成器:
def shortcircuit_eval(*vecs, f, prefix=tuple()):
if len(vecs) == 0:
yield prefix, f(*prefix), True
return
first_vec, *rest_vecs = vecs
for i, x in enumerate(first_vec):
inner = shortcircuit_eval(rest_vecs, f=f, prefix=prefix + (x,))
for args, res, all_inner_first_iter in inner:
yield args, res, all_inner_first_iter and i == 0
if res == np.inf and all_inner_first_iter:
return
然后就可以像shortcircuit_eval(vec_a, vec_b, vec_c, f=f)
一样使用了。它确实会生成一些辅助信息作为它生成的元组的第三个元素,如果需要,您可以编写一个简短的包装器来剥离这些信息。
请注意,这是对算法的 想法 的直接实现,但这不是最优的。例如。在遍历 [0..10]^3
时,如果发现 [1, 5, 2]
是无限的,那么您知道 [3, 6, 3]
也是无限的,但是您的算法不会跳过它。如果您有兴趣,请写一个新问题询问最佳算法 :) 这个优势图应该说明您实际可以节省多少工作:
如果任何节点是无限的,您不再需要计算以该节点为祖先的整个子图。
我阅读了一些打破循环的方法(参见示例 this and this),最后得出结论,这种方法可能会导致一些混乱且可读性差的代码。
如果您查看您的网格,您可以将 INF 元素视为矩形(或超矩形,在 n 维情况下)的角,使该区域内的所有其他元素也为 INF:
在这个例子中,您将不得不记住,对于 a >= 3
,您将不得不继续 运行 通过内部 b
循环,仅使用 b in (0, 1)
。
另请注意,您可能拥有多个 INF 元素,从而为您提供多个可能相互重叠的 INF 区域。
所以我的建议不是 break
循环,而是决定每个元素是否必须评估 f()
或者是否可以跳过它,这取决于它是否是在 INF 区域内。
三个维度的示例:
inf_list = []
def can_skip(a, b, c):
for inf_element in inf_list:
if a >= inf_element[0] and b >= inf_element[1] and c >= inf_element[2]:
return True
return False
def gen():
for a in vec_a:
for b in vec_b:
for c in vec_c:
if can_skip(a, b, c):
# print("Skipping element ", a, b, c)
yield (a, b, c), np.inf
else:
# print ("Calculating f for element", a, b, c)
res = f(a, b, c)
if res == np.inf:
inf_list.append((a, b, c))
yield (a, b, c), res
我有一个生成器,它在参数网格上实现函数 f()
的评估:
def gen():
for a in vec_a:
for b in vec_b:
for c in vec_c:
for ...
res = f(a, b, c, ...)
yield (a, b, c, ...), res
vec_* 被排序使得 f()
是给定所有其他固定的相应参数的递增函数。更准确地说:
if (a2 >= a1) and (b2 >= b1) and (c2 >= c1) and ...:
assert f(a2, b2, c2, ...) >= f(a1, b1, c1, ...)
因此,例如,如果 f(a0, b0, c0, ...)
== np.inf,则:
f(a, b0, c0, ...) == np.inf
对于每个 a >= a0f(a0, b, c0, ...) == np.inf
对于每个 b >= b0f(a0, b0, c, ...) == np.inf
对于每个 c >= c0
现在我想编写一个通用生成器,它接受 gen
并根据以下规则跳过不必要的 f
计算:
- 如果
f(.) == np.inf
在某个时候我会打破最内层的循环 - 如果最内层循环在第一次迭代时被打断,我应该打破倒数第二层循环
- 规则 #3 适用于嵌套循环的所有其他级别
示例:如果我在网格的第一次迭代中得到 np.inf
,我应该跳过整个网格并且不再对 f 执行任何计算。
示例:如果我有一个网格 [0,1,2] x [0,1,2] x [0,1,2]
和 f(0,1,0) == np.inf
,那么我跳转到评估 f(1, 0, 0)
.
我将如何实现这样的生成器?
使用递归来简化生成器:
def shortcircuit_eval(*vecs, f, prefix=tuple()):
if len(vecs) == 0:
yield prefix, f(*prefix), True
return
first_vec, *rest_vecs = vecs
for i, x in enumerate(first_vec):
inner = shortcircuit_eval(rest_vecs, f=f, prefix=prefix + (x,))
for args, res, all_inner_first_iter in inner:
yield args, res, all_inner_first_iter and i == 0
if res == np.inf and all_inner_first_iter:
return
然后就可以像shortcircuit_eval(vec_a, vec_b, vec_c, f=f)
一样使用了。它确实会生成一些辅助信息作为它生成的元组的第三个元素,如果需要,您可以编写一个简短的包装器来剥离这些信息。
请注意,这是对算法的 想法 的直接实现,但这不是最优的。例如。在遍历 [0..10]^3
时,如果发现 [1, 5, 2]
是无限的,那么您知道 [3, 6, 3]
也是无限的,但是您的算法不会跳过它。如果您有兴趣,请写一个新问题询问最佳算法 :) 这个优势图应该说明您实际可以节省多少工作:
如果任何节点是无限的,您不再需要计算以该节点为祖先的整个子图。
我阅读了一些打破循环的方法(参见示例 this and this),最后得出结论,这种方法可能会导致一些混乱且可读性差的代码。
如果您查看您的网格,您可以将 INF 元素视为矩形(或超矩形,在 n 维情况下)的角,使该区域内的所有其他元素也为 INF:
在这个例子中,您将不得不记住,对于 a >= 3
,您将不得不继续 运行 通过内部 b
循环,仅使用 b in (0, 1)
。
另请注意,您可能拥有多个 INF 元素,从而为您提供多个可能相互重叠的 INF 区域。
所以我的建议不是 break
循环,而是决定每个元素是否必须评估 f()
或者是否可以跳过它,这取决于它是否是在 INF 区域内。
三个维度的示例:
inf_list = []
def can_skip(a, b, c):
for inf_element in inf_list:
if a >= inf_element[0] and b >= inf_element[1] and c >= inf_element[2]:
return True
return False
def gen():
for a in vec_a:
for b in vec_b:
for c in vec_c:
if can_skip(a, b, c):
# print("Skipping element ", a, b, c)
yield (a, b, c), np.inf
else:
# print ("Calculating f for element", a, b, c)
res = f(a, b, c)
if res == np.inf:
inf_list.append((a, b, c))
yield (a, b, c), res