找到具有最大总和的根到叶路径 - 无法比较问题

Find the root-to-leaf path with the max sum - can't compare issues

我现在正在寻找具有最大总和的根到叶路径。我的做法是:

def max_sum(root):
    _max = 0
    find_max(root, _max, 0)
    return _max

def find_max(node, max_sum, current_sum):
    if not node:
        return 0
    current_sum += node.value
    if not node.left and not node.right:
        print(current_sum, max_sum, current_sum > max_sum)
        max_sum = max(max_sum, current_sum)
    if node.left:
        find_max(node.left, max_sum, current_sum)
    if node.right:
        find_max(node.right, max_sum, current_sum)
    current_sum -= node.value

class TreeNode():
    def __init__(self, _value):
        self.value = _value
        self.left, self.right, self.next = None, None, None

def main():
    root = TreeNode(1)
    root.left = TreeNode(7)
    root.right = TreeNode(9)
    root.left.left = TreeNode(4)
    root.left.right = TreeNode(5)
    root.right.left = TreeNode(2)
    root.right.right = TreeNode(7)

    print(max_sum(root))

    root = TreeNode(12)
    root.left = TreeNode(7)
    root.right = TreeNode(1)
    root.left.left = TreeNode(4)
    root.right.left = TreeNode(10)
    root.right.right = TreeNode(5)

    print(max_sum(root))

main()

输出:

12 0 True
13 0 True
12 0 True
17 0 True
0
23 0 True
23 0 True
18 0 True
0

Process finished with exit code 0

预期输出为 17 和 23。

我想确认为什么我的方法不能比较max_sumcurrent_sum?即使它在比较中返回了 true,但不会更新 max_sum。感谢您的帮助。

错误修复

这是我们可以修复您的 find_sum 函数的方法 -

def find_max(node, current_sum = 0):
  # empty tree
  if not node:
      return current_sum

  # branch
  elif node.left or node.right:
    next_sum = current_sum + node.value
    left = find_max(node.left, next_sum)
    right = find_max(node.right, next_sum)
    return max(left, right)
  
  # leaf
  else:
    return current_sum + node.value
t1 = TreeNode \
  ( 1
  , TreeNode(7, TreeNode(4), TreeNode(5))
  , TreeNode(9, TreeNode(2), TreeNode(7))
  )
  
t2 = TreeNode \
  ( 12
  , TreeNode(7, TreeNode(4), None)
  , TreeNode(1, TreeNode(10), TreeNode(5))
  )

print(find_max(t1))
print(find_max(t2))  
17
23

看到过程

我们可以通过跟踪示例之一来可视化计算过程,find_max(t2) -

             12
          /       \
         7         1
        / \       / \
       4   None  10  5
     find_max(12,0)
          /      \
         7        1
        / \      / \
       4  None  10  5
          find_max(12,0)
          /           \
max(find_max(7,12), find_max(1,12))
     / \                / \
    4  None           10   5
                                find_max(12,0)
                           /                         \
         max(find_max(7,12),                          find_max(1,12))
            /              \                          /             \
max(find_max(4,19), find_max(None,19))  max(find_max(10,13), find_max(5,13))
                        find_max(12,0)
                       /              \     
     max(find_max(7,12),              find_max(1,12))
      /              \                /             \
 max(23,             19)         max(23,            18)
                        find_max(12,0)
                       /              \     
     max(find_max(7,12),              find_max(1,12))
            |                                |
           23                               23
            find_max(12,0)
            /            \     
     max(23,              23)  
            find_max(12,0)
                 |
                23
23

细化

不过我认为我们可以改进。就像我们在您的 中所做的那样,我们可以再次使用数学归纳法 -

  1. 如果输入树 t 为空,return 为空结果
  2. (归纳)t不为空。如果存在子问题 t.leftt.right 分支,将 t.value 添加到累积结果 r 并在每个
  3. 上重复
  4. (归纳)t不为空且t.leftt.right均为空;已到达叶节点;将 t.value 添加到累加结果 r 并得出总和
def sum_branch (t, r = 0):
  if not t:
    return                                       # (1)
  elif t.left or t.right:
    yield from sum_branch(t.left, r + t.value)   # (2)
    yield from sum_branch(t.right, r + t.value)
  else:
    yield r + t.value                            # (3)
t1 = TreeNode \
  ( 1
  , TreeNode(7, TreeNode(4), TreeNode(5))
  , TreeNode(9, TreeNode(2), TreeNode(7))
  )
  
t2 = TreeNode \
  ( 12
  , TreeNode(7, TreeNode(4), None)
  , TreeNode(1, TreeNode(10), TreeNode(5))
  )

print(max(sum_branch(t1)))
print(max(sum_branch(t2)))
17
23

仿制药

也许更有趣的写法是先写一个通用的paths函数-

def paths (t, p = []):
  if not t:
    return                                     # (1)
  elif t.left or t.right:
    yield from paths(t.left, [*p, t.value])    # (2)
    yield from paths(t.right, [*p, t.value])
  else:
    yield [*p, t.value]                        # (3)

然后我们可以将最大和问题作为通用函数 maxsumpaths -

的组合来解决
print(max(sum(x) for x in paths(t1)))
print(max(sum(x) for x in paths(t2)))
17
23