使用workflow一次完成多个模型的评价和比较

2022-11-15 10:51:38 浏览数 (1)

前面给大家介绍了使用tidymodels搞定二分类资料的模型评价和比较。

简介的语法、统一的格式、优雅的操作,让人欲罢不能!

但是太费事儿了,同样的流程来了4遍,那要是选择10个模型,就得来10遍!无聊,非常的无聊。

所以个大家介绍简便方法,不用重复写代码,一次搞定多个模型!

本期目录:

  • 加载数据和R包
  • 数据预处理
  • 选择模型
  • 选择重抽样方法
  • 构建workflow
  • 运行模型
  • 查看结果
  • 可视化结果
  • 选择最好的模型用于测试集

加载数据和R包

首先还是加载数据和R包,和前面的一模一样的操作,数据也没变。

代码语言:javascript复制
suppressPackageStartupMessages(library(tidyverse))
suppressPackageStartupMessages(library(tidymodels))
library(kknn)
tidymodels_prefer()

all_plays <- read_rds("../000files/all_plays.rds")

set.seed(20220520)

split_pbp <- initial_split(all_plays, 0.75, strata = play_type)

train_data <- training(split_pbp)
test_data <- testing(split_pbp)

数据预处理

代码语言:javascript复制
pbp_rec <- recipe(play_type ~ ., data = train_data)  %>%
  step_rm(half_seconds_remaining,yards_gained, game_id) %>% 
  step_string2factor(posteam, defteam) %>%  
  step_corr(all_numeric(), threshold = 0.7) %>% 
  step_center(all_numeric()) %>%  
  step_zv(all_predictors())  

选择模型

直接选择4个模型,你想选几个都是可以的。

代码语言:javascript复制
lm_mod <- logistic_reg(mode = "classification",engine = "glm")
knn_mod <- nearest_neighbor(mode = "classification", engine = "kknn")
rf_mod <- rand_forest(mode = "classification", engine = "ranger")
tree_mod <- decision_tree(mode = "classification",engine = "rpart")

选择重抽样方法

代码语言:javascript复制
set.seed(20220520)

folds <- vfold_cv(train_data, v = 10)
folds
## #  10-fold cross-validation 
## # A tibble: 10 × 2
##    splits               id    
##    <list>               <chr> 
##  1 <split [62082/6899]> Fold01
##  2 <split [62083/6898]> Fold02
##  3 <split [62083/6898]> Fold03
##  4 <split [62083/6898]> Fold04
##  5 <split [62083/6898]> Fold05
##  6 <split [62083/6898]> Fold06
##  7 <split [62083/6898]> Fold07
##  8 <split [62083/6898]> Fold08
##  9 <split [62083/6898]> Fold09
## 10 <split [62083/6898]> Fold10

构建workflow

这一步就是不用重复写代码的关键,把所有模型和数据预处理步骤自动连接起来。

代码语言:javascript复制
library(workflowsets)

four_mods <- workflow_set(list(rec = pbp_rec), 
                          list(lm = lm_mod,
                               knn = knn_mod,
                               rf = rf_mod,
                               tree = tree_mod
                               ),
                          cross = T
                          )
four_mods
## # A workflow set/tibble: 4 × 4
##   wflow_id info             option    result    
##   <chr>    <list>           <list>    <list>    
## 1 rec_lm   <tibble [1 × 4]> <opts[0]> <list [0]>
## 2 rec_knn  <tibble [1 × 4]> <opts[0]> <list [0]>
## 3 rec_rf   <tibble [1 × 4]> <opts[0]> <list [0]>
## 4 rec_tree <tibble [1 × 4]> <opts[0]> <list [0]>

运行模型

首先是一些运行过程中的参数设置:

代码语言:javascript复制
keep_pred <- control_resamples(save_pred = T, verbose = T)

然后就是运行4个模型(目前一直是在训练集中),我们给它加速一下:

代码语言:javascript复制
library(doParallel) 
## Loading required package: foreach
## 
## Attaching package: 'foreach'
## The following objects are masked from 'package:purrr':
## 
##     accumulate, when
## Loading required package: iterators
## Loading required package: parallel

cl <- makePSOCKcluster(12) # 加速,用12个线程
registerDoParallel(cl)

four_fits <- four_mods %>% 
  workflow_map("fit_resamples",
               seed = 0520,
               verbose = T,
               resamples = folds,
               control = keep_pred
               )
## i 1 of 4 resampling: rec_lm
## ✔ 1 of 4 resampling: rec_lm (18.4s)
## i 2 of 4 resampling: rec_knn
## ✔ 2 of 4 resampling: rec_knn (3m 51.9s)
## i 3 of 4 resampling: rec_rf
## ✔ 3 of 4 resampling: rec_rf (1m 15.6s)
## i 4 of 4 resampling: rec_tree
## ✔ 4 of 4 resampling: rec_tree (6.1s)

four_fits
## # A workflow set/tibble: 4 × 4
##   wflow_id info             option    result   
##   <chr>    <list>           <list>    <list>   
## 1 rec_lm   <tibble [1 × 4]> <opts[2]> <rsmp[ ]>
## 2 rec_knn  <tibble [1 × 4]> <opts[2]> <rsmp[ ]>
## 3 rec_rf   <tibble [1 × 4]> <opts[2]> <rsmp[ ]>
## 4 rec_tree <tibble [1 × 4]> <opts[2]> <rsmp[ ]>

stopCluster(cl)

需要很长时间!大家笔记本如果内存不够可能会失败哦~

查看结果

查看模型在训练集中的表现:

代码语言:javascript复制
collect_metrics(four_fits)
## # A tibble: 8 × 9
##   wflow_id .config          preproc model .metric .estimator  mean     n std_err
##   <chr>    <chr>            <chr>   <chr> <chr>   <chr>      <dbl> <int>   <dbl>
## 1 rec_lm   Preprocessor1_M… recipe  logi… accura… binary     0.724    10 1.91e-3
## 2 rec_lm   Preprocessor1_M… recipe  logi… roc_auc binary     0.781    10 1.88e-3
## 3 rec_knn  Preprocessor1_M… recipe  near… accura… binary     0.671    10 7.31e-4
## 4 rec_knn  Preprocessor1_M… recipe  near… roc_auc binary     0.716    10 1.28e-3
## 5 rec_rf   Preprocessor1_M… recipe  rand… accura… binary     0.732    10 1.48e-3
## 6 rec_rf   Preprocessor1_M… recipe  rand… roc_auc binary     0.799    10 1.90e-3
## 7 rec_tree Preprocessor1_M… recipe  deci… accura… binary     0.720    10 1.97e-3
## 8 rec_tree Preprocessor1_M… recipe  deci… roc_auc binary     0.704    10 2.01e-3

查看每一个预测结果,这个就不运行了,毕竟好几万行,太多了。。。

代码语言:javascript复制
collect_predictions(four_fits)

可视化结果

直接可视化4个模型的结果,感觉比ROC曲线更好看,还给出了可信区间。

这个图可以自己用ggplot2语法修改。

代码语言:javascript复制
four_fits %>% autoplot(metric = "roc_auc") theme_bw()

image-20220704145235120

选择最好的模型用于测试集

选择表现最好的应用于测试集:

代码语言:javascript复制
rand_res <- last_fit(rf_mod,pbp_rec,split_pbp)

查看在测试集的模型表现:

代码语言:javascript复制
collect_metrics(rand_res) # test 中的模型表现

image-20220704144956748

使用其他指标查看模型表现:

代码语言:javascript复制
metricsets <- metric_set(accuracy, mcc, f_meas, j_index)

collect_predictions(rand_res) %>% 
  metricsets(truth = play_type, estimate = .pred_class)

image-20220704145017664

可视化结果,喜闻乐见的混淆矩阵:

代码语言:javascript复制
collect_predictions(rand_res) %>% 
  conf_mat(play_type,.pred_class) %>% 
  autoplot()

image-20220704145028522

喜闻乐见的ROC曲线:

代码语言:javascript复制
collect_predictions(rand_res) %>% 
  roc_curve(play_type,.pred_pass) %>% 
  autoplot()

image-20220704145041578

还有非常多曲线和评价指标可选,大家可以看我之前的介绍推文~

是不是很神奇呢,完美符合一次挑选多个模型的要求,且步骤清稀,代码美观,非常适合进行多个模型的比较。

0 人点赞