How to retrieve the full branch path leading to each leaf node of a sklearn Decision Tree?
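I have this decision tree, from which I would like to extract every branch. The image shows only part of the tree, because the original tree is much larger and doesn't fit well in a single image.
I am not trying to print the tree's rules like this: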
Rules used to predict sample 1400:
decision node 0 : (X[1400, 4] = 92.85714285714286) > 96.42856979370117)
decision node 4 : (X[1400, 3] = 45.03259584336583) > 53.49640464782715)
or like this:
The binary tree structure has 7 nodes and has the following tree structure:
node=0 is a split node: go to node 1 if 4 <= 96.42856979370117 else to node 4.
node=1 is a split node: go to node 2 if 3 <= 96.42856979370117 else to node 3.
node=4 is a split node: go to node 5 if 5 <= 0.28278614580631256 else to node 6.
What I want to achieve is:
branch 0: x[4] <= 96.429,x[3]<=96.429,class=B,gini_score=0.5
branch 1: x[4] <= 96.429,x[3]>96.429,class=B,gini_score=0.021
branch 2: x[4] > 96.429,x[5]<=0.283,class=A,gini_score=0.092
branch 4: x[4] > 96.429,x[5]>0.283,class=A,gini_score=0.01
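Basically, I am trying to get every branch from the top of the tree down to a leaf node (the full path), together with its class and Gini score. How can I do this?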
Taking the iris dataset example from the sklearn docs, we follow the steps below.
1. Generate an example decision tree
The code is taken from the docs:
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
import numpy as np
iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = DecisionTreeClassifier(max_leaf_nodes=6, random_state=0)
clf.fit(X_train, y_train)
2. Retrieve the branch paths
First, we retrieve the following arrays from the fitted tree:
n_nodes = clf.tree_.node_count
children_left = clf.tree_.children_left
children_right = clf.tree_.children_right
feature = clf.tree_.feature
threshold = clf.tree_.threshold
impurity = clf.tree_.impurity
value = clf.tree_.value
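For reference, a node is a leaf exactly when it has no children: sklearn stores -1 in both children_left and children_right for a leaf, and the leaf's feature and threshold hold the placeholder -2. A minimal sketch of that check, using a small hypothetical hand-built tree in place of clf.tree_ so that it runs on its own; `leaf_nodes` is an illustrative helper name, not part of sklearn:

```python
# Hypothetical 5-node tree laid out like the clf.tree_ arrays:
# node 0 splits into 1 (left) and 2 (right); node 2 splits into 3 and 4.
children_left = [1, -1, 3, -1, -1]
children_right = [2, -1, 4, -1, -1]
feature = [3, -2, 2, -2, -2]  # -2 is the placeholder stored at leaves

LEAF_CHILD = -1  # value sklearn stores for a missing child


def leaf_nodes(children_left, children_right):
    """Return the ids of all nodes that have no children."""
    return [
        node
        for node, (left, right) in enumerate(zip(children_left, children_right))
        if left == LEAF_CHILD and right == LEAF_CHILD
    ]


leaves = leaf_nodes(children_left, children_right)
print(leaves)  # [1, 3, 4]
# Leaves carry no real split, so they contribute no decision rule
print([feature[n] for n in leaves])  # [-2, -2, -2]
```

The same call should work on the real arrays, e.g. `leaf_nodes(clf.tree_.children_left, clf.tree_.children_right)`.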
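In retrieve_branches, we first mark which nodes are leaves, then visit the nodes in index order starting from the root: each split node extends the partial path that ends at it with its left and right children, and when a leaf node is reached, its completed branch path is returned with a yield statement.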
def retrieve_branches(number_nodes, children_left_list, children_right_list):
    """Retrieve decision tree branches"""
    # A node is a leaf when it has no children (both child ids are -1)
    is_leaves_list = [cl == cr for cl, cr in zip(children_left_list, children_right_list)]
    # Store the (partial) branch paths
    paths = []
    for i in range(number_nodes):
        if is_leaves_list[i]:
            # Search the leaf node among the ends of the previous paths
            end_node = [path[-1] for path in paths]
            # If a path ends at this leaf, it is complete: yield it
            if i in end_node:
                output = paths.pop(end_node.index(i))
                yield output
        else:
            # Origin and end nodes
            origin, end_l, end_r = i, children_left_list[i], children_right_list[i]
            # Extend the path that ends at this split node with both children
            for index, path in enumerate(paths):
                if origin == path[-1]:
                    paths[index] = path + [end_l]
                    paths.append(path + [end_r])
                    break  # only one path can end at a given node
            # Initialize the paths in the first iteration (root node)
            if i == 0:
                paths.append([i, children_left_list[i]])
                paths.append([i, children_right_list[i]])
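retrieve_branches relies on sklearn numbering every child after its parent, so that the partial path ending at node i already exists when i is visited. As an alternative (not part of the original answer), an explicit depth-first traversal avoids that assumption and can also record, at each step, whether the path went to the left child (`<=`) or the right child (`>`). A minimal sketch; `branch_paths` is a hypothetical helper, and the toy arrays stand in for clf.tree_.children_left and clf.tree_.children_right:

```python
def branch_paths(children_left, children_right, root=0):
    """Yield (path, directions) for every root-to-leaf branch.

    path is the list of node ids; directions[k] is "left" or "right",
    telling which child of path[k] the branch continues to.
    """
    # Stack of (node, path so far, directions so far)
    stack = [(root, [root], [])]
    while stack:
        node, path, directions = stack.pop()
        left, right = int(children_left[node]), int(children_right[node])
        if left == -1 and right == -1:  # -1 marks a missing child: leaf
            yield path, directions
            continue
        # Push the right child first so the left subtree is explored first
        stack.append((right, path + [right], directions + ["right"]))
        stack.append((left, path + [left], directions + ["left"]))


# Hypothetical toy tree: 0 -> (1, 2), 2 -> (3, 4)
children_left = [1, -1, 3, -1, -1]
children_right = [2, -1, 4, -1, -1]

for path, directions in branch_paths(children_left, children_right):
    print(path, directions)
# [0, 1] ['left']
# [0, 2, 3] ['right', 'left']
# [0, 2, 4] ['right', 'right']
```

With the arrays from clf.tree_ this should produce the same paths as retrieve_branches, although for other trees the branches may come out in a different order (depth-first, left before right, rather than by leaf id). The int(...) conversion keeps the ids as plain Python ints when NumPy arrays are passed in.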
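To call retrieve_branches, just pass n_nodes, children_left and children_right; since it is a generator that yields one branch path at a time, we collect the paths in a list. The result is shown below: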
all_branches = list(retrieve_branches(n_nodes, children_left, children_right))
all_branches
>>>
[[0, 1],
[0, 2, 3, 5],
[0, 2, 3, 6, 7],
[0, 2, 3, 6, 8],
[0, 2, 4, 9],
[0, 2, 4, 10]]
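One way to sanity-check these paths is to compare them with sklearn's own `DecisionTreeClassifier.decision_path`, which returns a sparse node-indicator matrix (row i marks the nodes sample i passes through), and `DecisionTreeClassifier.apply`, which returns the leaf each sample ends up in. A sketch that rebuilds the same tree and asserts, for every training sample, that the visited nodes form exactly the branch ending at its leaf; `root_to_leaf_paths` is a hypothetical recursive helper written for this check:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Same tree as in step 1
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=0)
clf = DecisionTreeClassifier(max_leaf_nodes=6, random_state=0).fit(X_train, y_train)


def root_to_leaf_paths(tree, node=0, prefix=()):
    """Hypothetical helper: map every leaf id to its root-to-leaf path."""
    path = prefix + (node,)
    left, right = int(tree.children_left[node]), int(tree.children_right[node])
    if left == right:  # both are -1, so this node is a leaf
        return {node: list(path)}
    paths = root_to_leaf_paths(tree, left, path)
    paths.update(root_to_leaf_paths(tree, right, path))
    return paths


paths_by_leaf = root_to_leaf_paths(clf.tree_)

indicator = clf.decision_path(X_train)  # sparse matrix, shape (n_samples, n_nodes)
leaf_ids = clf.apply(X_train)           # id of the leaf each sample ends up in
for i, leaf in enumerate(leaf_ids):
    # Non-zero columns of row i are the nodes sample i passes through
    visited = indicator.indices[indicator.indptr[i]:indicator.indptr[i + 1]]
    assert sorted(visited.tolist()) == sorted(paths_by_leaf[int(leaf)])

for leaf, path in sorted(paths_by_leaf.items()):
    print(leaf, path)
```

If the assertions pass, every training sample follows one of the retrieved branches, and the printed paths can be compared with all_branches.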
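3. Path, value and Gini of each branch
The rules can be obtained from the feature and threshold values of clf.tree_ at every split node on the path (the comparison is <= when the path continues to the node's left child and > when it continues to the right child), while the impurity clf.tree_.impurity and the value clf.tree_.value are read at the leaf node.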
for index, branch in enumerate(all_branches):
    leaf_index = branch[-1]
    print(f'Branch: {index}, Path: {branch}')
    print(f'Gini {impurity[leaf_index]} at leaf node {leaf_index}')
    print(f'Value: {value[leaf_index]}')
    # One rule per split node: "<=" if the path goes to its left child, ">" if right.
    # The leaf itself is skipped (its feature and threshold are the placeholder -2).
    rules = [
        f"if X[:, {feature[node]}] {'<=' if child == children_left[node] else '>'} {threshold[node]}"
        for node, child in zip(branch[:-1], branch[1:])
    ]
    print(f'Decision Rules: {rules}')
    print("---------------------------------------------------------------------------------------\n")
>>>
Branch: 0, Path: [0, 1]
Gini 0.0 at leaf node 1
Value: [[37. 0. 0.]]
Decision Rules: ['if X[:, 3] <= 0.800000011920929']
---------------------------------------------------------------------------------------
Branch: 1, Path: [0, 2, 3, 5]
Gini 0.0 at leaf node 5
Value: [[ 0. 32. 0.]]
Decision Rules: ['if X[:, 3] > 0.800000011920929', 'if X[:, 2] <= 4.950000047683716', 'if X[:, 3] <= 1.6500000357627869']
---------------------------------------------------------------------------------------
Branch: 2, Path: [0, 2, 3, 6, 7]
Gini 0.0 at leaf node 7
Value: [[0. 0. 3.]]
Decision Rules: ['if X[:, 3] > 0.800000011920929', 'if X[:, 2] <= 4.950000047683716', 'if X[:, 3] > 1.6500000357627869', 'if X[:, 1] <= 3.100000023841858']
---------------------------------------------------------------------------------------
Branch: 3, Path: [0, 2, 3, 6, 8]
Gini 0.0 at leaf node 8
Value: [[0. 1. 0.]]
Decision Rules: ['if X[:, 3] > 0.800000011920929', 'if X[:, 2] <= 4.950000047683716', 'if X[:, 3] > 1.6500000357627869', 'if X[:, 1] > 3.100000023841858']
---------------------------------------------------------------------------------------
Branch: 4, Path: [0, 2, 4, 9]
Gini 0.375 at leaf node 9
Value: [[0. 1. 3.]]
Decision Rules: ['if X[:, 3] > 0.800000011920929', 'if X[:, 2] > 4.950000047683716', 'if X[:, 2] <= 5.049999952316284']
---------------------------------------------------------------------------------------
Branch: 5, Path: [0, 2, 4, 10]
Gini 0.0 at leaf node 10
Value: [[ 0. 0. 35.]]
Decision Rules: ['if X[:, 3] > 0.800000011920929', 'if X[:, 2] > 4.950000047683716', 'if X[:, 2] > 5.049999952316284']
---------------------------------------------------------------------------------------
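To get exactly the one-line-per-branch format asked for in the question (conditions, class and Gini score, rounded to three decimals), the pieces above can be combined. A minimal sketch under these assumptions: `describe_branches` is a hypothetical helper, the predicted class is taken as the arg-max of the leaf's value row (the same rule `predict` uses), and a hand-built toy tree with clf.tree_-like arrays replaces the real model so the example is self-contained:

```python
from types import SimpleNamespace


def describe_branches(tree, class_names, decimals=3):
    """Hypothetical helper: one line per root-to-leaf branch, in the form
    'branch k: x[f] <= t, x[g] > u, class=C, gini_score=G'.

    `tree` can be clf.tree_ or any object exposing the same arrays
    (children_left, children_right, feature, threshold, impurity, value).
    """
    lines = []
    stack = [(0, [])]  # (node id, conditions collected on the way down)
    while stack:
        node, conditions = stack.pop()
        left, right = int(tree.children_left[node]), int(tree.children_right[node])
        if left == -1 and right == -1:  # leaf: report class and Gini
            weights = list(tree.value[node][0])  # per-class weights, single output
            predicted = class_names[weights.index(max(weights))]
            gini = round(float(tree.impurity[node]), decimals)
            lines.append(", ".join(conditions + [f"class={predicted}", f"gini_score={gini}"]))
            continue
        feat = int(tree.feature[node])
        thr = round(float(tree.threshold[node]), decimals)
        # Push the right child first so the left (<=) branch is reported first
        stack.append((right, conditions + [f"x[{feat}] > {thr}"]))
        stack.append((left, conditions + [f"x[{feat}] <= {thr}"]))
    return [f"branch {k}: {line}" for k, line in enumerate(lines)]


# Hypothetical toy tree with clf.tree_-like arrays: 0 -> (1, 2), 2 -> (3, 4)
toy = SimpleNamespace(
    children_left=[1, -1, 3, -1, -1],
    children_right=[2, -1, 4, -1, -1],
    feature=[3, -2, 2, -2, -2],
    threshold=[0.800000011920929, -2.0, 4.950000047683716, -2.0, -2.0],
    impurity=[0.537, 0.0, 0.142, 0.375, 0.0],  # only the leaf values are used
    value=[[[37, 3, 36]], [[37, 0, 0]], [[0, 3, 36]], [[0, 3, 1]], [[0, 0, 35]]],
)

for line in describe_branches(toy, ["setosa", "versicolor", "virginica"]):
    print(line)
# branch 0: x[3] <= 0.8, class=setosa, gini_score=0.0
# branch 1: x[3] > 0.8, x[2] <= 4.95, class=versicolor, gini_score=0.375
# branch 2: x[3] > 0.8, x[2] > 4.95, class=virginica, gini_score=0.0
```

With the real model, `describe_branches(clf.tree_, iris.target_names)` should print one line per leaf. Depending on the scikit-learn version, tree_.value holds either class counts or class fractions; the arg-max, and therefore the reported class, is the same either way.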