Getting the leaf probabilities of a tree model in Spark
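I'm trying to refactor a trained Spark tree-based model (RandomForest or GBT classifiers) so that it can be exported to environments without Spark. The toDebugString method is a good starting point. However, in the case of RandomForestClassifier, the string shows only the predicted class for each tree, without the relative probabilities. So if you average the predictions of all the trees, you get the wrong result.
An example. We have a DecisionTree represented this way: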
DecisionTreeClassificationModel (uid=dtc_884dc2111789) of depth 2 with 5 nodes
  If (feature 21 in {1.0})
   Predict: 0.0
  Else (feature 21 not in {1.0})
   If (feature 10 in {0.0})
    Predict: 0.0
   Else (feature 10 not in {0.0})
    Predict: 1.0
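As we can see, following the nodes, the prediction always appears to be either 0 or 1. However, if I apply this single tree to a feature vector, I get probabilities like [0.1007, 0.8993], and they make perfect sense: in the training set, the proportion of negatives/positives that end up in the same leaf as the example vector matches the output probabilities.
My questions: where are these probabilities stored? Is there a way to extract them? If so, how? A pyspark solution would be better.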
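Considering the growing number of tools designed for real-time serving of Spark (and other) models, this is probably reinventing the wheel.
However, if you want to access the model internals from plain Python, it is best to load its serialized form.
Let's say you have: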
from pyspark.ml.classification import RandomForestClassificationModel
rf_model: RandomForestClassificationModel
path: str # Absolute path
and save the model:
rf_model.write().save(path)
You can load it back using any Parquet reader that supports mixed struct and list types. The model writer writes both node data:
node_data = spark.read.parquet("{}/data".format(path))
node_data.printSchema()
root
 |-- treeID: integer (nullable = true)
 |-- nodeData: struct (nullable = true)
 |    |-- id: integer (nullable = true)
 |    |-- prediction: double (nullable = true)
 |    |-- impurity: double (nullable = true)
 |    |-- impurityStats: array (nullable = true)
 |    |    |-- element: double (containsNull = true)
 |    |-- rawCount: long (nullable = true)
 |    |-- gain: double (nullable = true)
 |    |-- leftChild: integer (nullable = true)
 |    |-- rightChild: integer (nullable = true)
 |    |-- split: struct (nullable = true)
 |    |    |-- featureIndex: integer (nullable = true)
 |    |    |-- leftCategoriesOrThreshold: array (nullable = true)
 |    |    |    |-- element: double (containsNull = true)
 |    |    |-- numCategories: integer (nullable = true)
and tree metadata:
tree_meta = spark.read.parquet("{}/treesMetadata".format(path))
tree_meta.printSchema()
root
 |-- treeID: integer (nullable = true)
 |-- metadata: string (nullable = true)
 |-- weights: double (nullable = true)
The former provides all the information you need, since the prediction process is basically an aggregation of impurityStats*.
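As a minimal sketch of that aggregation for a single leaf: it assumes that, for a classifier, impurityStats holds the per-class training counts of the node (my reading of the classification case; the schema itself does not say so), so normalizing the array yields the class probabilities. The function name leaf_probabilities and the counts are made up for illustration.

```python
def leaf_probabilities(impurity_stats):
    """Normalize a leaf's per-class counts into class probabilities."""
    total = sum(impurity_stats)
    if total == 0:
        # No training rows reached this leaf; return the stats unchanged.
        return list(impurity_stats)
    return [count / total for count in impurity_stats]


# Hypothetical leaf that received 1 negative and 9 positive training rows.
probs = leaf_probabilities([1.0, 9.0])
print(probs)
```

The prediction field stored next to impurityStats is only the index of the largest entry, which is why toDebugString shows a hard 0.0 or 1.0.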
You can also access this data directly through the underlying Java objects:
from collections import namedtuple
import numpy as np

LeafNode = namedtuple("LeafNode", ("prediction", "impurity"))
InternalNode = namedtuple(
    "InternalNode", ("left", "right", "prediction", "impurity", "split"))
CategoricalSplit = namedtuple("CategoricalSplit", ("feature_index", "categories"))
ContinuousSplit = namedtuple("ContinuousSplit", ("feature_index", "threshold"))

def jtree_to_python(jtree):
    def jsplit_to_python(jsplit):
        if jsplit.getClass().toString().endswith(".ContinuousSplit"):
            return ContinuousSplit(jsplit.featureIndex(), jsplit.threshold())
        else:
            jcat = jsplit.toOld().categories()
            return CategoricalSplit(
                jsplit.featureIndex(),
                [jcat.apply(i) for i in range(jcat.length())])

    def jnode_to_python(jnode):
        prediction = jnode.prediction()
        stats = np.array(list(jnode.impurityStats().stats()))

        if jnode.numDescendants() != 0:  # InternalNode
            left = jnode_to_python(jnode.leftChild())
            right = jnode_to_python(jnode.rightChild())
            split = jsplit_to_python(jnode.split())

            return InternalNode(left, right, prediction, stats, split)

        else:
            return LeafNode(prediction, stats)

    return jnode_to_python(jtree.rootNode())
This can be applied to the RandomForestClassificationModel like this:
nodes = [jtree_to_python(t) for t in rf_model._java_obj.trees()]
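Furthermore, such a structure can easily be used to make predictions, both for individual trees (warning: Python 3.7+ ahead; for legacy usage, refer to the functools documentation):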
from functools import singledispatch

@singledispatch
def should_go_left(split, vector): pass

@should_go_left.register
def _(split: CategoricalSplit, vector):
    return vector[split.feature_index] in split.categories

@should_go_left.register
def _(split: ContinuousSplit, vector):
    return vector[split.feature_index] <= split.threshold

@singledispatch
def predict(node, vector): pass

@predict.register
def _(node: LeafNode, vector):
    return node.prediction, node.impurity

@predict.register
def _(node: InternalNode, vector):
    return predict(
        node.left if should_go_left(node.split, vector) else node.right,
        vector
    )
and for forests:
from typing import Iterable, Union

def predict_probability(nodes: Iterable[Union[InternalNode, LeafNode]], vector):
    total = np.array([
        v / v.sum() for _, v in (
            predict(node, vector) for node in nodes
        )
    ]).sum(axis=0)
    return total / total.sum()
That, however, depends on the internal API (and on the weakness of Scala's package-scoped access modifiers) and may break in the future.
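To illustrate why averaging the hard per-tree predictions from toDebugString gives a different answer than the forest-level aggregation above, here is a small pure-Python sketch. The helper names hard_vote and soft_vote and the three leaf count arrays are made up for illustration; soft_vote mirrors the normalize-then-sum logic of predict_probability without NumPy.

```python
def hard_vote(leaf_stats):
    """Fraction of trees whose majority class is each class."""
    votes = [0.0] * len(leaf_stats[0])
    for stats in leaf_stats:
        votes[stats.index(max(stats))] += 1.0
    return [v / len(leaf_stats) for v in votes]


def soft_vote(leaf_stats):
    """Sum of per-tree normalized leaf statistics, renormalized."""
    totals = [0.0] * len(leaf_stats[0])
    for stats in leaf_stats:
        leaf_total = sum(stats)
        for i, count in enumerate(stats):
            totals[i] += count / leaf_total
    norm = sum(totals)
    return [t / norm for t in totals]


# Hypothetical leaves reached by one vector in three trees: [negatives, positives].
leaves = [[6.0, 4.0], [7.0, 3.0], [1.0, 9.0]]
hard = hard_vote(leaves)   # two of three trees predict class 0
soft = soft_vote(leaves)   # but the averaged probabilities favor class 1
print(hard, soft)
```

With these made-up counts the two schemes even disagree on the winning class, which is the discrepancy described in the question.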
* A DataFrame loaded from the data path can easily be transformed into a structure compatible with the predict and predict_probability functions defined above.
from pyspark.sql.dataframe import DataFrame
from itertools import groupby
from operator import itemgetter

def model_data_to_tree(tree_data: DataFrame):
    def dict_to_tree(node_id, nodes):
        node = nodes[node_id]
        prediction = node.prediction
        impurity = np.array(node.impurityStats)

        if node.leftChild == -1 and node.rightChild == -1:
            return LeafNode(prediction, impurity)
        else:
            left = dict_to_tree(node.leftChild, nodes)
            right = dict_to_tree(node.rightChild, nodes)
            feature_index = node.split.featureIndex
            left_value = node.split.leftCategoriesOrThreshold

            split = (
                CategoricalSplit(feature_index, left_value)
                if node.split.numCategories != -1
                else ContinuousSplit(feature_index, left_value[0])
            )

            return InternalNode(left, right, prediction, impurity, split)

    tree_id = itemgetter("treeID")
    rows = tree_data.collect()
    return ([
        dict_to_tree(0, {node.nodeData.id: node.nodeData for node in nodes})
        for tree, nodes in groupby(sorted(rows, key=tree_id), key=tree_id)
    ] if "treeID" in tree_data.columns
    else [dict_to_tree(0, {node.id: node for node in rows})])
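For an environment with no Spark at all, the same walk can be done directly over the flat node records. The following is only a sketch: it assumes each record has already been read into a plain dict whose keys follow the nodeData schema shown earlier (how a particular Parquet reader exposes the rows is not covered here), that the root has id 0 and leaves have leftChild == -1, as in the code above. The function name predict_from_rows, the node ids, and the impurityStats counts are hand-written for illustration; the splits mimic the tree from the question.

```python
def predict_from_rows(rows, vector):
    """Walk one tree stored as flat node dicts; return the leaf's class probabilities."""
    nodes = {row["id"]: row for row in rows}
    node = nodes[0]  # assumption: the root node has id 0
    while node["leftChild"] != -1:  # -1 marks a leaf
        split = node["split"]
        value = vector[split["featureIndex"]]
        if split["numCategories"] != -1:
            # Categorical split: go left if the value is in the left categories.
            go_left = value in split["leftCategoriesOrThreshold"]
        else:
            # Continuous split: the single array element is the threshold.
            go_left = value <= split["leftCategoriesOrThreshold"][0]
        node = nodes[node["leftChild"] if go_left else node["rightChild"]]
    stats = node["impurityStats"]
    total = sum(stats)
    return [s / total for s in stats]


def leaf(node_id, stats):
    """Hypothetical helper that builds a leaf record."""
    return {"id": node_id, "impurityStats": stats, "leftChild": -1,
            "rightChild": -1, "split": None}


def categorical(node_id, feature, categories, left, right):
    """Hypothetical helper that builds an internal record with a categorical split."""
    return {"id": node_id, "impurityStats": [], "leftChild": left, "rightChild": right,
            "split": {"featureIndex": feature,
                      "leftCategoriesOrThreshold": categories,
                      "numCategories": 2}}


rows = [
    categorical(0, 21, [1.0], 1, 2),
    leaf(1, [8.0, 2.0]),
    categorical(2, 10, [0.0], 3, 4),
    leaf(3, [5.0, 0.0]),
    leaf(4, [1.0, 9.0]),
]

vector = [0.0] * 22
vector[10] = 1.0  # feature 21 is 0.0 and feature 10 is 1.0: ends in leaf 4
print(predict_from_rows(rows, vector))
```

For a forest, the records would first be grouped by treeID and the per-tree results combined as in predict_probability.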