如何显示测试样本的决策树路径?
How to display the path of a Decision Tree for test samples?
我正在使用 DecisionTreeClassifier from scikit-learn to classify some multiclass data. I found many posts describing how to display the decision tree path, like , here, and . However, all of them describe how to display the tree for the trained data. It makes sense, because export_graphviz
只需要一个合适的模型。
我的问题是如何在 测试样本 上可视化树(最好是 export_graphviz
)。 IE。在用 clf.fit(X[train], y[train])
拟合模型,然后用 clf.predict(X[test])
预测测试数据的结果后,我想可视化用于预测样本 X[test]
的决策路径。有办法吗?
编辑:
我看到可以使用 decision_path 打印路径。如果有一种方法可以从 export_graphviz
开始显示 DOT
输出,那就太好了。
为了获得决策树中特定样本所采用的路径,您可以使用 decision_path
。它 returns 一个稀疏矩阵,其中包含所提供样本的决策路径。
然后可以将这些决策路径用于 color/label 通过 pydot
生成的树。这需要覆盖颜色和标签(这会导致一些难看的代码)。
备注
decision_path
可以从训练集中取样本或新值
- 您可以随心所欲地使用颜色,并根据样本数量或可能需要的任何其他可视化效果更改颜色
例子
在下面的示例中,已访问的节点为绿色,所有其他节点为白色。
import pydotplus
from sklearn.datasets import load_iris
from sklearn import tree
clf = tree.DecisionTreeClassifier(random_state=42)
iris = load_iris()
clf = clf.fit(iris.data, iris.target)
dot_data = tree.export_graphviz(clf, out_file=None,
feature_names=iris.feature_names,
class_names=iris.target_names,
filled=True, rounded=True,
special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
# empty all nodes, i.e.set color to white and number of samples to zero
for node in graph.get_node_list():
if node.get_attributes().get('label') is None:
continue
if 'samples = ' in node.get_attributes()['label']:
labels = node.get_attributes()['label'].split('<br/>')
for i, label in enumerate(labels):
if label.startswith('samples = '):
labels[i] = 'samples = 0'
node.set('label', '<br/>'.join(labels))
node.set_fillcolor('white')
samples = iris.data[129:130]
decision_paths = clf.decision_path(samples)
for decision_path in decision_paths:
for n, node_value in enumerate(decision_path.toarray()[0]):
if node_value == 0:
continue
node = graph.get_node(str(n))[0]
node.set_fillcolor('green')
labels = node.get_attributes()['label'].split('<br/>')
for i, label in enumerate(labels):
if label.startswith('samples = '):
labels[i] = 'samples = {}'.format(int(label.split('=')[1]) + 1)
node.set('label', '<br/>'.join(labels))
filename = 'tree.png'
graph.write_png(filename)
我正在使用 DecisionTreeClassifier from scikit-learn to classify some multiclass data. I found many posts describing how to display the decision tree path, like export_graphviz
只需要一个合适的模型。
我的问题是如何在 测试样本 上可视化树(最好是 export_graphviz
)。 IE。在用 clf.fit(X[train], y[train])
拟合模型,然后用 clf.predict(X[test])
预测测试数据的结果后,我想可视化用于预测样本 X[test]
的决策路径。有办法吗?
编辑:
我看到可以使用 decision_path 打印路径。如果有一种方法可以从 export_graphviz
开始显示 DOT
输出,那就太好了。
为了获得决策树中特定样本所采用的路径,您可以使用 decision_path
。它 returns 一个稀疏矩阵,其中包含所提供样本的决策路径。
然后可以将这些决策路径用于 color/label 通过 pydot
生成的树。这需要覆盖颜色和标签(这会导致一些难看的代码)。
备注
decision_path
可以从训练集中取样本或新值- 您可以随心所欲地使用颜色,并根据样本数量或可能需要的任何其他可视化效果更改颜色
例子
在下面的示例中,已访问的节点为绿色,所有其他节点为白色。
import pydotplus
from sklearn.datasets import load_iris
from sklearn import tree
clf = tree.DecisionTreeClassifier(random_state=42)
iris = load_iris()
clf = clf.fit(iris.data, iris.target)
dot_data = tree.export_graphviz(clf, out_file=None,
feature_names=iris.feature_names,
class_names=iris.target_names,
filled=True, rounded=True,
special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
# empty all nodes, i.e.set color to white and number of samples to zero
for node in graph.get_node_list():
if node.get_attributes().get('label') is None:
continue
if 'samples = ' in node.get_attributes()['label']:
labels = node.get_attributes()['label'].split('<br/>')
for i, label in enumerate(labels):
if label.startswith('samples = '):
labels[i] = 'samples = 0'
node.set('label', '<br/>'.join(labels))
node.set_fillcolor('white')
samples = iris.data[129:130]
decision_paths = clf.decision_path(samples)
for decision_path in decision_paths:
for n, node_value in enumerate(decision_path.toarray()[0]):
if node_value == 0:
continue
node = graph.get_node(str(n))[0]
node.set_fillcolor('green')
labels = node.get_attributes()['label'].split('<br/>')
for i, label in enumerate(labels):
if label.startswith('samples = '):
labels[i] = 'samples = {}'.format(int(label.split('=')[1]) + 1)
node.set('label', '<br/>'.join(labels))
filename = 'tree.png'
graph.write_png(filename)