删除给定(非二叉)树的一个或多个叶子后获取所有树

Obtain all trees after deleting one or more leaves of a given (non-binary) tree

给定一棵树(非二进制),获得 all 删除 one 后出现的树的最佳方法是什么或 more leaf 个来自原始树的节点?[​​=12=]

我正在寻找算法或伪代码的简单 Pythonic 实现(不使用任何图形库)。

举个例子:

对于一棵树:

     A 
   / | \
  B  C  D
 / \     \
E   F     G

我想得到以下树作为输出:

     A 
   / | \
  B  C  D


     A 
   /   \
  B     D


     A 
   /   \
  B     D
 / \     \
E   F     G


     A 
   / | \
  B  C  D
 /       \
E         G


     A 
   / | \
  B  C  D
   \     \
    F     G

.....等等。

一种方法是识别输入树的所有叶子,对它们进行计数,然后递增一个“掩码”数字,该数字的二进制位指示是否应包括或排除相应的叶子。最后创建一个复制给定树的函数,同时考虑到这样的掩码编号。

首先,我将假设一个通用的 Node 构造函数,它接受节点的值,但也接受将成为该节点的子节点的任意数量的参数:

class Node:
    # This constructor accepts any number of children as arguments
    def __init__(self, value, *children):
        self.value = value
        self.children = children  # This is a list (that could be empty)

然后定义这些效用函数(也可以是Nodeclass的方法):

def get_leaves(node):  # Depth first traversal yielding all leaves
    if node.children:
        for child in node.children:
            yield from get_leaves(child)
    else:
        yield node

def print_tree(root, tab=""):  # Very simple print implementation
    print(tab + str(root.value))
    for child in root.children:
        print_tree(child, tab + "  ")

现在进入实际逻辑:

def copy_tree(root, mask):
    def recur(node):
        nonlocal mask
        if not node.children:
            # Extract the next bit from the mask and let it determine whether
            # to create the leaf copy or not: 
            keep = mask & 1
            mask >>= 1
            if not keep:
                return
        # Get the children through recursion and kick out None returns
        return Node(node.value, *filter(None, map(recur, node.children)))
    
    return recur(root)

def get_trimmed_trees(root):
    leaf_count = sum(1 for _ in get_leaves(root))
    # Iterate all possible mask numbers (except the one that keeps all leaves)
    return [copy_tree(root, mask) for mask in range(2 ** leaf_count - 1)]

最后是创建您作为示例提供的树并运行算法的演示:

root = Node("A",
    Node("B", Node("E"), Node("F")),
    Node("C"),
    Node("D", Node("G"))
)

print_tree(root)
for trimmed_root in get_trimmed_trees(root):
    print("===")
    print_tree(trimmed_root)