遍历sklearn决策树

Traversal of sklearn decision tree

sklearn决策树的广度优先搜索遍历怎么做?

在我的代码中,我尝试了 sklearn.tree_ 库并使用了各种函数,例如 tree_.feature 和 tree_.threshold 来理解树的结构。但是这些函数做的是树的dfs遍历,如果我想做bfs应该怎么做呢?

假设

clf1 = DecisionTreeClassifier( max_depth = 2 )
clf1 = clf1.fit(x_train, y_train)

这是我的 classifier,生成的决策树是

然后我使用以下函数遍历了树

def encoding(clf, features):
l1 = list()
l2 = list()

for i in range(len(clf.tree_.feature)):
    if(clf.tree_.feature[i]>=0):
        l1.append( features[clf.tree_.feature[i]])
        l2.append(clf.tree_.threshold[i])
    else:
        l1.append(None)
        print(np.max(clf.tree_.value))
        l2.append(np.argmax(clf.tree_.value[i]))

l = [l1 , l2]

return np.array(l)

产生的输出是

array([['address', 'age', None, None, 'age', None, None],
       [0.5, 17.5, 2, 1, 15.5, 1, 1]], dtype=object)

其中第一个数组是节点的特征,或者如果它没有叶子,则它被标记为 none,第二个数组是特征节点的阈值,对于 class 节点,它是 class但这是树的 dfs 遍历我想做 bfs 遍历我应该怎么做?

由于我是堆栈溢出的新手,请建议如何改进问题描述以及我应该添加哪些其他信息以进一步解释我的问题。

X_train(示例)

y_train(示例)

应该这样做:

from collections import deque

tree = clf.tree_

stack = deque()
stack.append(0)  # push tree root to stack

while stack:
    current_node = stack.popleft()

    # do whatever you want with current node
    # ...

    left_child = tree.children_left[current_node]
    if left_child >= 0:
        stack.append(left_child)

    right_child = tree.children_right[current_node]
    if right_child >= 0:
        stack.append(right_child)

这使用 deque 来保存要处理的节点堆栈。由于我们是从左边移除元素,然后在右边添加元素,所以这应该代表广度优先遍历。


为了实际使用,我建议你把它变成一个生成器:

from collections import deque

def breadth_first_traversal(tree):
    stack = deque()
    stack.append(0)

    while stack:
        current_node = stack.popleft()

        yield current_node

        left_child = tree.children_left[current_node]
        if left_child >= 0:
            stack.append(left_child)

        right_child = tree.children_right[current_node]
        if right_child >= 0:
            stack.append(right_child)

那么,您只需要对原始函数进行最小的更改:

def encoding(clf, features):
    l1 = list()
    l2 = list()

    for i in breadth_first_traversal(clf.tree_):
        if(clf.tree_.feature[i]>=0):
            l1.append( features[clf.tree_.feature[i]])
            l2.append(clf.tree_.threshold[i])
        else:
            l1.append(None)
            print(np.max(clf.tree_.value))
            l2.append(np.argmax(clf.tree_.value[i]))

    l = [l1 , l2]

    return np.array(l)