多个模型:如何 select 最佳模型并进行预测

multiple models: how to select best model and make prediction

我的任务是创建许多模型,选择预测最好的模型并将数据传递给该模型进行预测。示例灵感来自 R for data science

library(modelr)
library(tidyverse)
library(gapminder)
gapminder

country_model1 <- function(df) {lm(lifeExp ~ year, data = df)}

country_model2 <- function(df) {lm(lifeExp ~ year+gdpPercap, data = df)}

country_model3 <- function(df) {lm(lifeExp ~ year+gdpPercap+pop, data = df)}

by_country <- gapminder %>% 
  group_by(country, continent) %>% 
  nest() %>% 
  mutate(model1 = map(data, country_model1),
         model2 = map(data, country_model2),
         model3 = map(data, country_model3))

所以我为每个国家/地区准备了 3 个模型。 我可以找到每个模型的 r squared,但停在这里:(

r_sq <- by_country %>% 
  mutate(glance1 = map(model1, broom::glance),
         glance2 = map(model2, broom::glance),
         glance3 = map(model3, broom::glance)) %>% 
  unnest(glance1:glance3, .drop = TRUE) %>% 
  select(country, continent, starts_with('r.sq')) 

如何整洁:

  1. select 3 个中哪一个对每个特定国家/地区的预测更好?
  2. 将新数据传递给所选模型并返回预测?

我们可以像这样为每个国家/地区确定具有最高 r^2 的模型:

best_fits <- r_sq %>%
  pivot_longer(-c(country, continent), names_to = "r_sq_version") %>%
  group_by(country, continent) %>%
  slice_max(value) %>%
  ungroup() 

毫不奇怪,第三个模型(此处称为 r.squared2 来自其在 r_sq 中的名称)始终提供最高的相关性,因为该模型需要更多的输入并具有更多的自由度。

让我们制作一些新数据,使用原始数据但将日期增加 100 年。

by_country_new <- gapminder %>%
  group_by(country, continent) %>%
  mutate(year = year + 100,
         gdpPercap = gdpPercap,
         pop = pop) %>%
  select(-lifeExp) %>%    # Presumably we don't know this and are trying to predict using known data
  nest()

然后我们可以将每个国家/地区的最佳模型应用于新数据:(感谢@mrflick

  best_fits %>%
    left_join(by_country) %>%
    left_join(by_country_new, by = c("country", "continent")) %>%
    mutate(best_model = case_when(
      r_sq_version == "r.squared2" ~ model3,
      r_sq_version == "r.squared1" ~ model2,
      r_sq_version == "r.squared"  ~ model1,
    )) %>%
    select(-c(model1:model3)) %>%
    mutate(prediction = map2(best_model, data.y,
                             ~broom::augment(.x, newdata = .y))) -> new_fits

然后我们可以看到这些预测如何看起来像是原始数据中建立的时间趋势的延续(由于我们新数据中人口和 gdp 的变化,还有一些其他变化)。

new_predictions <-   new_fits %>%
    filter(country == "Afghanistan") %>%
    select(prediction) %>%
    unnest_wider(prediction) %>%
    flatten_dfr() %>%
  rename(lifeExp = ".fitted")

gapminder %>%
  filter(country == "Afghanistan") %>%
  bind_rows(new_predictions) %>%
  ggplot(aes(year, lifeExp)) +
  geom_point() +
  labs(title = "Afghanistan extrapolated lifeExp")