如何避免在排序的嵌套循环生成器中进行不必要的计算?

how to avoid unnecessary computations in a sorted nested loop generator?

我有一个生成器,它在参数网格上实现函数 f() 的评估:

def gen():
  for a in vec_a:
    for b in vec_b:
      for c in vec_c:
        for ...
          res = f(a, b, c, ...)
          yield (a, b, c, ...), res

vec_* 被排序使得 f() 是给定所有其他固定的相应参数的递增函数。更准确地说:

if (a2 >= a1) and (b2 >= b1) and (c2 >= c1) and ...:
  assert f(a2, b2, c2, ...) >= f(a1, b1, c1, ...)

因此,例如,如果 f(a0, b0, c0, ...) == np.inf,则:

现在我想编写一个通用生成器,它接受 gen 并根据以下规则跳过不必要的 f 计算:

  1. 如果 f(.) == np.inf 在某个时候我会打破最内层的循环
  2. 如果最内层循环在第一次迭代时被打断,我应该打破倒数第二层循环
  3. 规则 #3 适用于嵌套循环的所有其他级别

示例:如果我在网格的第一次迭代中得到 np.inf,我应该跳过整个网格并且不再对 f 执行任何计算。

示例:如果我有一个网格 [0,1,2] x [0,1,2] x [0,1,2]f(0,1,0) == np.inf,那么我跳转到评估 f(1, 0, 0).

我将如何实现这样的生成器?

使用递归来简化生成器:

def shortcircuit_eval(*vecs, f, prefix=tuple()):
    if len(vecs) == 0:
        yield prefix, f(*prefix), True
        return

    first_vec, *rest_vecs = vecs
    for i, x in enumerate(first_vec):
        inner = shortcircuit_eval(rest_vecs, f=f, prefix=prefix + (x,))
        for args, res, all_inner_first_iter in inner:
            yield args, res, all_inner_first_iter and i == 0
            if res == np.inf and all_inner_first_iter:
                return

然后就可以像shortcircuit_eval(vec_a, vec_b, vec_c, f=f)一样使用了。它确实会生成一些辅助信息作为它生成的元组的第三个元素,如果需要,您可以编写一个简短的包装器来剥离这些信息。

请注意,这是对算法的 想法 的直接实现,但这不是最优的。例如。在遍历 [0..10]^3 时,如果发现 [1, 5, 2] 是无限的,那么您知道 [3, 6, 3] 也是无限的,但是您的算法不会跳过它。如果您有兴趣,请写一个新问题询问最佳算法 :) 这个优势图应该说明您实际可以节省多少工作:

如果任何节点是无限的,您不再需要计算以该节点为祖先的整个子图。

我阅读了一些打破循环的方法(参见示例 this and this),最后得出结论,这种方法可能会导致一些混乱且可读性差的代码。

如果您查看您的网格,您可以将 INF 元素视为矩形(或超矩形,在 n 维情况下)的角,使该区域内的所有其他元素也为 INF:

在这个例子中,您将不得不记住,对于 a >= 3,您将不得不继续 运行 通过内部 b 循环,仅使用 b in (0, 1)

另请注意,您可能拥有多个 INF 元素,从而为您提供多个可能相互重叠的 INF 区域。

所以我的建议不是 break 循环,而是决定每个元素是否必须评估 f() 或者是否可以跳过它,这取决于它是否是在 INF 区域内。

三个维度的示例:

inf_list = []

def can_skip(a, b, c):
  for inf_element in inf_list:
    if a >= inf_element[0] and b >= inf_element[1] and c >= inf_element[2]:
      return True
  return False

def gen():
  for a in vec_a:
    for b in vec_b:
      for c in vec_c:
        if can_skip(a, b, c):
          # print("Skipping element ", a, b, c)
          yield (a, b, c), np.inf
        else:
          # print ("Calculating f for element", a, b, c)
          res = f(a, b, c)
          if res == np.inf:
            inf_list.append((a, b, c))
          yield (a, b, c), res