Obtain importance of individual trees in a RandomForest

Is there a way to extract the variable importance of each individual CART model from a randomForest object?

rf_mod$forest doesn't seem to contain this information, and the docs don't mention it either.


In R's randomForest package, the average variable importance across the whole forest of CART models is given by importance(rf_mod).

library(randomForest)

df <- mtcars

set.seed(1)
rf_mod = randomForest(mpg ~ ., 
                      data = df, 
                      importance = TRUE, 
                      ntree = 200)

importance(rf_mod)

       %IncMSE IncNodePurity
cyl  6.0927875     111.65028
disp 8.7730959     261.06991
hp   7.8329831     212.74916
drat 2.9529334      79.01387
wt   7.9015687     246.32633
qsec 0.7741212      26.30662
vs   1.6908975      31.95701
am   2.5298261      13.33669
gear 1.5512788      17.77610
carb 3.2346351      35.69909

We can also extract the structure of an individual tree with getTree. Here is the first tree:

head(getTree(rf_mod, k = 1, labelVar = TRUE))
  left daughter right daughter split var split point status prediction
1             2              3        wt        2.15     -3   18.91875
2             0              0      <NA>        0.00     -1   31.56667
3             4              5        wt        3.16     -3   17.61034
4             6              7      drat        3.66     -3   21.26667
5             8              9      carb        3.50     -3   15.96500
6             0              0      <NA>        0.00     -1   19.70000

One workaround is to grow many single-tree CARTs (i.e., ntree = 1), get each tree's variable importance, and average the resulting %IncMSE:

# number of trees to grow
nn <- 200

# function to run nn CART models 
run_rf <- function(rand_seed){
  set.seed(rand_seed)
  one_tr = randomForest(mpg ~ ., 
                        data = df, 
                        importance = TRUE, 
                        ntree = 1)
  return(one_tr)
}

# list to store output of each model
l <- vector("list", length = nn)
l <- lapply(1:nn, run_rf)

Steps to extract, average, and compare:

# extract importance of each CART model 
library(dplyr); library(purrr)
map(l, importance) %>% 
  map(as.data.frame) %>% 
  map( ~ { .$var = rownames(.); rownames(.) <- NULL; return(.) } ) %>% 
  bind_rows() %>% 
  group_by(var) %>% 
  summarise(`%IncMSE` = mean(`%IncMSE`)) %>% 
  arrange(-`%IncMSE`)

# A tibble: 10 x 2
   var   `%IncMSE`
   <chr>     <dbl>
 1 wt        8.52 
 2 cyl       7.75 
 3 disp      7.74 
 4 hp        5.53 
 5 drat      1.65 
 6 carb      1.52 
 7 vs        0.938
 8 qsec      0.824
 9 gear      0.495
10 am        0.355

# compare to the RF model above
importance(rf_mod)

       %IncMSE IncNodePurity
cyl  6.0927875     111.65028
disp 8.7730959     261.06991
hp   7.8329831     212.74916
drat 2.9529334      79.01387
wt   7.9015687     246.32633
qsec 0.7741212      26.30662
vs   1.6908975      31.95701
am   2.5298261      13.33669
gear 1.5512788      17.77610
carb 3.2346351      35.69909

I would like to be able to extract the variable importance of each tree directly from the randomForest object, without this roundabout method that involves completely re-running the RF, so that I can reproduce cumulative variable importance plots like this one, and the one shown below for mtcars. Minimal example here.

I know that the variable importance of a single tree is not statistically meaningful, and I have no intention of interpreting trees in isolation. I want per-tree importances in order to visualise and communicate how the variable importance measures jump around before stabilising as trees are added to the forest.
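
For concreteness, here is a rough sketch (not part of the minimal example; it reuses the list l built above and assumes dplyr, purrr, and ggplot2 are loaded, with illustrative object and column names) of how such a cumulative plot could be drawn from the single-tree workaround:

library(dplyr); library(purrr); library(ggplot2)

# per-tree %IncMSE from each single-tree model, with the tree index attached
cum_imp <- imap_dfr(l, function(tree_mod, i) {
  imp <- importance(tree_mod)
  tibble(tree = i, var = rownames(imp), IncMSE = imp[, "%IncMSE"])
})

# running mean of %IncMSE as trees are added, plotted per variable
cum_imp %>% 
  group_by(var) %>% 
  arrange(tree, .by_group = TRUE) %>% 
  mutate(cum_mean = cummean(IncMSE)) %>% 
  ggplot(aes(tree, cum_mean, colour = var)) +
  geom_line() +
  labs(x = "number of trees", y = "running mean %IncMSE")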

We can simplify this with:

library(tidyverse)
out <- map(seq_len(nn),  ~ 
          run_rf(.x) %>% 
          importance) %>%
       reduce(`+`) %>% 
       magrittr::divide_by(nn)

Disclaimer: this is not really an answer, but it was too long to post as a comment. I'll delete it if it's considered inappropriate.

While I (think I) understand your question, to be honest I am not sure it makes sense from a statistics/ML point of view. What follows is based on my admittedly limited understanding of RF and CART. Perhaps my comment-as-post brings some insight.

Let's start with some general random forest (RF) theory on variable importance from Hastie, Tibshirani, Friedman, The Elements of Statistical Learning, p. 593 (bold-face mine):

At each split in each tree, the improvement in the split-criterion is the importance measure attributed to the splitting variable, and is accumulated over all the trees in the forest separately for each variable. [...] Random forests also use the oob samples to construct a different variable-importance measure, apparently to measure the prediction strength of each variable.

So the variable importance measure in RF is defined as a measure accumulated over all trees.


In a classical single classification tree (CART), variable importance is characterised through the Gini index measuring node impurity (see How to measure/rank "variable importance" when using CART? (specifically using {rpart} from R) and Carolin Strobl's PhD thesis).

More complex measures for characterising variable importance exist in CART-like models; for example, in rpart:

An overall measure of variable importance is the sum of the goodness of split measures for each split for which it was the primary variable, plus goodness * (adjusted agreement) for all splits in which it was a surrogate. In the printout these are scaled to sum to 100 and the rounded values are shown, omitting any variable whose proportion is less than 1%.
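
As a small illustrative aside (a sketch, not part of the quoted documentation): this measure can be inspected directly on a fitted rpart tree via its variable.importance component.

library(rpart)

# fit a single CART on the same data and inspect its importance measure
cart_mod <- rpart(mpg ~ ., data = mtcars)
cart_mod$variable.importance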


So the bottom line here is the following: at the very least, it is not straightforward (and in the worst case meaningless) to compare variable measures from a single classification tree with the variable importance measures used in ensemble-based methods like RF.

Which leads me to ask: why do you want to extract variable importance measures for individual trees from an RF model? Even if you came up with a way to calculate variable importances for individual trees, I believe they would not be very meaningful, and they would not have to "converge" to the ensemble-accumulated values.

When a randomForest model is trained, the importance scores are computed for the forest as a whole and stored directly inside the object. Tree-specific scores are not retained, so they cannot be retrieved directly from the randomForest object.
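
A quick way to see this (a sketch against the rf_mod object fitted in the question) is to inspect what the object actually stores: one importance matrix for the whole forest, plus per-tree split structure, but no per-tree importance.

dim(rf_mod$importance)   # one row per predictor: whole-forest scores only
names(rf_mod$forest)     # split/node structure of the trees; no importance component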

Unfortunately, you are correct that the forest has to be built incrementally. The good news is that a randomForest object is self-contained, and you don't need to implement your own run_rf. Instead, you can use stats::update to re-fit the random forest model with a single tree, and randomForest::grow to add further trees one at a time:

## Starting with a random forest having a single tree,
##   grow it 9 times, one tree at a time
rfs <- purrr::accumulate( .init = update(rf_mod, ntree=1),
                          rep(1,9), randomForest::grow )

## Retrieve the importance scores from each random forest
imp <- purrr::map( rfs, ~importance(.x)[,"%IncMSE"] )

## Combine all results into a single data frame
dplyr::bind_rows( !!!imp )
# # A tibble: 10 x 10
#      cyl  disp    hp  drat    wt   qsec    vs     am    gear  carb
#    <dbl> <dbl> <dbl> <dbl> <dbl>  <dbl> <dbl>  <dbl>   <dbl> <dbl>
#  1 0      18.8  8.63 1.05   0     1.17  0     0       0      0.194
#  2 0      10.0 46.4  0.561  0    -0.299 0     0       0.543  2.05 
#  3 0      22.4 31.2  0.955  0    -0.199 0     0       0.362  5.1
#  4 1.55   24.1 23.4  0.717  0    -0.150 0     0       0.272  5.28
#  5 1.24   22.8 23.6  0.573  0    -0.178 0     0      -0.0259 4.98
#  6 1.03   26.2 22.3  0.478  1.25  0.775 0     0      -0.0216 4.1
#  7 0.887  22.5 22.5  0.406  1.79 -0.101 0     0      -0.0185 3.56
#  8 0.776  19.7 21.3  0.944  1.70  0.105 0     0.0225 -0.0162 3.11
#  9 0.690  18.4 19.1  0.839  1.51  1.24  1.01  0.02   -0.0144 2.77
# 10 0.621  18.4 21.2  0.937  1.32  1.11  0.910 0.0725 -0.114  2.49

The data frame shows how feature importance changes with each additional tree. This corresponds to the right panel of your plot example. The trees themselves (for the left panel) can be retrieved from the final forest, which is given by dplyr::last(rfs).
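
As a possible follow-up (a sketch under the same setup, additionally assuming tidyr and ggplot2 are available): the data frame above can be reshaped and plotted to give those right-panel trajectories, and an individual tree can be pulled from the final forest for the left panel.

library(dplyr); library(tidyr); library(ggplot2)

# reshape the per-forest importance scores to long format and plot them
bind_rows(!!!imp) %>% 
  mutate(ntree = row_number()) %>% 
  pivot_longer(-ntree, names_to = "var", values_to = "IncMSE") %>% 
  ggplot(aes(ntree, IncMSE, colour = var)) +
  geom_line() +
  labs(x = "number of trees", y = "%IncMSE")

# first tree of the final (10-tree) forest, for the left panel
randomForest::getTree(dplyr::last(rfs), k = 1, labelVar = TRUE)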