How to retrieve the full branch path leading to each leaf node of a sklearn Decision Tree?

I have this decision tree from which I want to extract every branch. The image shows only part of the tree, since the original tree is much bigger and does not fit well in a single image.

I am not trying to print the tree's decision rules like this:
Rules used to predict sample 1400:

decision node 0 : (X[1400, 4] = 92.85714285714286) > 96.42856979370117)
decision node 4 : (X[1400, 3] = 45.03259584336583) > 53.49640464782715)

or like this:

The binary tree structure has 7 nodes and has the following tree structure:

node=0 is a split node: go to node 1 if 4 <= 96.42856979370117 else to node 4.
    node=1 is a split node: go to node 2 if 3 <= 96.42856979370117 else to node 3.
    node=4 is a split node: go to node 5 if 5 <= 0.28278614580631256 else to node 6.

What I want to achieve instead is something like:

branch 0: x[4] <= 96.429,x[3]<=96.429,class=B,gini_score=0.5
branch 1: x[4] <= 96.429,x[3]>96.429,class=B,gini_score=0.021
branch 2: x[4] > 96.429,x[5]<=0.283,class=A,gini_score=0.092
branch 4: x[4] > 96.429,x[5]>0.283,class=A,gini_score=0.01

Basically, I am trying to get every branch from the root down to a leaf node (the full path), together with its class and gini score. How can I do that?

Considering the iris dataset example from the sklearn docs, we follow the steps below.

1. Generate the example decision tree

The code is taken from the docs:

from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
import numpy as np

iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

clf = DecisionTreeClassifier(max_leaf_nodes=6, random_state=0)
clf.fit(X_train, y_train)
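
Optionally (not part of the original steps), the fitted tree can be printed with sklearn's export_text as a quick sanity check against the branches we extract below:

from sklearn.tree import export_text

# Optional: print the fitted tree in text form to cross-check the extracted branches
print(export_text(clf, feature_names=iris.feature_names))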

2. Retrieve the branch paths

First we retrieve the following values from the tree:

n_nodes = clf.tree_.node_count
children_left = clf.tree_.children_left
children_right = clf.tree_.children_right
feature = clf.tree_.feature
threshold = clf.tree_.threshold
impurity = clf.tree_.impurity
value = clf.tree_.value
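
As a side note (my addition, based on how sklearn stores the tree): in clf.tree_, leaf nodes are the ones whose children_left and children_right entries are both -1, and their feature and threshold entries are set to -2, which is why the rules printed in step 3 end with if X[:, -2] <= -2.0 at the leaf. A quick check:

# Leaf nodes have no children: both child pointers are -1 in clf.tree_
is_leaf = children_left == children_right
print('leaf node ids:', np.where(is_leaf)[0])
print('feature at leaves:', feature[is_leaf])      # -2 for leaves
print('threshold at leaves:', threshold[is_leaf])  # -2.0 for leaves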

In retrieve_branches, we compute which nodes are leaves and iterate from the root node down to the leaves; when we reach a leaf node we return its branch path with a yield statement.

def retrieve_branches(number_nodes, children_left_list, children_right_list):
    """Retrieve decision tree branches"""

    # A node is a leaf when its left and right children are equal (both -1)
    is_leaves_list = [cl == cr for cl, cr in zip(children_left_list, children_right_list)]

    # Store the branch paths built so far
    paths = []

    for i in range(number_nodes):
        if is_leaves_list[i]:
            # Search for this leaf node at the end of the previous paths
            end_node = [path[-1] for path in paths]

            # If it is a leaf node, yield the completed path
            if i in end_node:
                output = paths.pop(np.argwhere(i == np.array(end_node))[0][0])
                yield output

        else:

            # Origin and end nodes
            origin, end_l, end_r = i, children_left_list[i], children_right_list[i]

            # Iterate over the previous paths and extend them with the children
            for index, path in enumerate(paths):
                if origin == path[-1]:
                    paths[index] = path + [end_l]
                    paths.append(path + [end_r])

            # Initialize the paths on the first iteration (root node)
            if i == 0:
                paths.append([i, children_left_list[i]])
                paths.append([i, children_right_list[i]])

To call retrieve_branches, simply pass n_nodes, children_left and children_right; the list that stores and updates the branch paths is created inside the function. The result looks as follows:

all_branches = list(retrieve_branches(n_nodes, children_left, children_right))
all_branches
>>> 
[[0, 1],
 [0, 2, 3, 5],
 [0, 2, 3, 6, 7],
 [0, 2, 3, 6, 8],
 [0, 2, 4, 9],
 [0, 2, 4, 10]]

3. Path, value and Gini of each branch

The rules can be obtained from the feature and threshold arrays of clf.tree_, together with the impurity (clf.tree_.impurity) and class counts (clf.tree_.value) at the leaf node.

for index, branch in enumerate(all_branches):
    leaf_index = branch[-1]
    print(f'Branch: {index}, Path: {branch}')
    print(f'Gini {impurity[leaf_index]} at leaf node {branch[-1]}')
    print(f'Value: {value[leaf_index]}')
    print(f"Decision Rules: {[f'if X[:, {feature[elem]}] <= {threshold[elem]}' for elem in branch]}")
    print(f"---------------------------------------------------------------------------------------\n")
>>>
Branch: 0, Path: [0, 1]
Gini 0.0 at leaf node 1
Value: [[37.  0.  0.]]
Decision Rules: ['if X[:, 3] <= 0.800000011920929', 'if X[:, -2] <= -2.0']
---------------------------------------------------------------------------------------

Branch: 1, Path: [0, 2, 3, 5]
Gini 0.0 at leaf node 5
Value: [[ 0. 32.  0.]]
Decision Rules: ['if X[:, 3] <= 0.800000011920929', 'if X[:, 2] <= 4.950000047683716', 'if X[:, 3] <= 1.6500000357627869', 'if X[:, -2] <= -2.0']
---------------------------------------------------------------------------------------

Branch: 2, Path: [0, 2, 3, 6, 7]
Gini 0.0 at leaf node 7
Value: [[0. 0. 3.]]
Decision Rules: ['if X[:, 3] <= 0.800000011920929', 'if X[:, 2] <= 4.950000047683716', 'if X[:, 3] <= 1.6500000357627869', 'if X[:, 1] <= 3.100000023841858', 'if X[:, -2] <= -2.0']
---------------------------------------------------------------------------------------

Branch: 3, Path: [0, 2, 3, 6, 8]
Gini 0.0 at leaf node 8
Value: [[0. 1. 0.]]
Decision Rules: ['if X[:, 3] <= 0.800000011920929', 'if X[:, 2] <= 4.950000047683716', 'if X[:, 3] <= 1.6500000357627869', 'if X[:, 1] <= 3.100000023841858', 'if X[:, -2] <= -2.0']
---------------------------------------------------------------------------------------

Branch: 4, Path: [0, 2, 4, 9]
Gini 0.375 at leaf node 9
Value: [[0. 1. 3.]]
Decision Rules: ['if X[:, 3] <= 0.800000011920929', 'if X[:, 2] <= 4.950000047683716', 'if X[:, 2] <= 5.049999952316284', 'if X[:, -2] <= -2.0']
---------------------------------------------------------------------------------------

Branch: 5, Path: [0, 2, 4, 10]
Gini 0.0 at leaf node 10
Value: [[ 0.  0. 35.]]
Decision Rules: ['if X[:, 3] <= 0.800000011920929', 'if X[:, 2] <= 4.950000047683716', 'if X[:, 2] <= 5.049999952316284', 'if X[:, -2] <= -2.0']
---------------------------------------------------------------------------------------
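
Note that the rules above always print '<=' and also include the leaf node itself (feature -2, threshold -2). If the split direction and the predicted class are also needed, as in the format asked for in the question, a small extension along these lines could work; branch_to_rule_string is a hypothetical helper of mine that reuses the variables defined above, with class names taken from iris.target_names:

def branch_to_rule_string(branch):
    """Format one branch as comparison rules plus class and gini at its leaf (sketch)."""
    rules = []
    for parent, child in zip(branch[:-1], branch[1:]):
        # '<=' if the path went to the left child of the parent node, '>' otherwise
        op = '<=' if child == children_left[parent] else '>'
        rules.append(f'x[{feature[parent]}] {op} {threshold[parent]:.3f}')
    leaf = branch[-1]
    # Predicted class = majority class among the samples in the leaf
    predicted_class = iris.target_names[np.argmax(value[leaf])]
    return ', '.join(rules) + f', class={predicted_class}, gini_score={impurity[leaf]:.3f}'

for index, branch in enumerate(all_branches):
    print(f'branch {index}: {branch_to_rule_string(branch)}')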