tidymodels用于机器学习的一些使用细节

2022-11-15 11:29:31 浏览数 (1)

R语言做机器学习的当红辣子鸡R包:mlr3tidymodels,之前用十几篇推文详细介绍过mlr3

mlr3:开篇

mlr3:基础使用

mlr3:模型评价

mlr3:模型比较

mlr3:超参数调优

mlr3:嵌套重抽样

mlr3:特征选择

mlr3:pipelines

mlr3:技术细节

mlr3:模型解释

mlr3实战:决策树和xgboost预测房价

今天学习下tidymodels的使用,其实之前在介绍临床预测模型时已经用过这个包了:使用tidymodels搞定二分类资料多个模型评价和比较

但是对于很多没接触过这个包的朋友来说有些地方还是不好理解,所以今天专门写一篇推文介绍下tidymodels的一些使用细节,帮助大家更上一层楼。

不得不说,比mlr3简单多了!

目录:

  • 设计理念
  • 安装
  • 基本使用
    • 探索数据
    • 模型选择
    • 数据划分
    • 数据预处理
    • 建立workflow
    • 选择重抽样方法
    • 训练模型(无重抽样)
    • 训练模型(有重抽样)
    • 用于测试集
  • 进阶
  • 总结

设计理念

tidymodels是max kuhn加入rstudio之后和Julia silge等人共同开发的机器学习R包,类似于mlr3caret,也是一个整合包,只提供统一的API,让大家可以通过统一的语法调用R语言里各种现成的机器学习算法R包,并不发明新的算法。

这样做对用户来说最大的好处是不用记那么多R包的用法了,只需要记住tidymodels一个包的用法及参数就够了。同时得益于tidyverse系列的加持,在tidymodels中进行的各种操作以及产生的各种结果都是遵循tidy系列的设计理念的。所以非常有规律,很容易记住!

tidymodels类似于tidyverse,是一系列R包的合集,其中主要的包括:

  • parsnip:提供统一的语法来选择模型(算法)
  • recipes:数据预处理
  • rsample:重抽样
  • dials:设置超参数
  • tune:调整超参数
  • yardstick:评价模型
  • broom:可以把各种模型的结果以整洁tibble格式返回,支持R语言所有内置模型!还有大部分第三方R包的模型!
  • infer:统计推断
  • workflows:联合数据预处理和算法

除此之外,还包括ggplot2/purrr/dplyr/tibble等R包。

真正在用的时候并不需要刻意记住,只需加载tidymodels就可得到全部~

因为和tidyverse系列是一脉相承的,所以也是支持管道符的,这样的操作看起来非常的流畅,也比较容易理解,对于初学者来说比mlr3那种面向对象的编程,简单多了。

但是一个很大的问题是速度,因为底层也是基于tibble,所以速度没那么快,尤其是在调参的时候,非常慢,运算量一大就得好久时间才能出结果!

安装

目前发展还是很快,经常变更版本,所以时不时会遇到一些小问题,但总体来说瑕不掩瑜,学了不吃亏。

代码语言:javascript复制
# 2选1
install.packages("tidymodels")

library("devtools")
install_github("tidymodels/tidymodels")

基本使用

基本使用步骤和大家像想象中的差不多:

  • 选择算法(模型)
  • 数据预处理
  • 训练集建模
  • 测试集看效果

在建模的过程中可能会同时出现重抽样、超参数调整等步骤,但基本步骤就是这样的。

代码语言:javascript复制
library(tidyverse)
## ── Attaching packages ───────────────────────────── tidyverse 1.3.2 ──
## ✔ ggplot2 3.3.6     ✔ purrr   0.3.4
## ✔ tibble  3.1.7     ✔ dplyr   1.0.9
## ✔ tidyr   1.2.0     ✔ stringr 1.4.0
## ✔ readr   2.1.2     ✔ forcats 0.5.1
## ── Conflicts ──────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
library(tidymodels)
## ── Attaching packages ──────────────────────────── tidymodels 1.0.0 ──
## ✔ broom        1.0.0     ✔ rsample      1.0.0
## ✔ dials        1.0.0     ✔ tune         1.0.0
## ✔ infer        1.0.2     ✔ workflows    1.0.0
## ✔ modeldata    1.0.0     ✔ workflowsets 1.0.0
## ✔ parsnip      1.0.0     ✔ yardstick    1.0.0
## ✔ recipes      1.0.1     
## ── Conflicts ─────────────────────────────── tidymodels_conflicts() ──
## ✖ scales::discard() masks purrr::discard()
## ✖ dplyr::filter()   masks stats::filter()
## ✖ recipes::fixed()  masks stringr::fixed()
## ✖ dplyr::lag()      masks stats::lag()
## ✖ yardstick::spec() masks readr::spec()
## ✖ recipes::step()   masks stats::step()
## • Use tidymodels_prefer() to resolve common conflicts.
tidymodels_prefer() # 防止函数冲突

探索数据

我们用一个结果变量是二分类变量的数据集来做一个简单的演示。这个数据集是关于大人住旅馆会不会带孩子一起。。。

代码语言:javascript复制
rm(list = ls())
load(file = "../datasets/hotels_df.rdata")

简单查看一下数据:

代码语言:javascript复制
hotels_df |> glimpse()
## Rows: 75,166
## Columns: 10
## $ children                    <fct> none, none, none, none, none, none, none, …
## $ hotel                       <fct> Resort Hotel, Resort Hotel, Resort Hotel, …
## $ arrival_date_month          <fct> July, July, July, July, July, July, July, …
## $ meal                        <fct> BB, BB, BB, BB, BB, BB, BB, FB, HB, BB, HB…
## $ adr                         <dbl> 0.00, 0.00, 75.00, 75.00, 98.00, 98.00, 10…
## $ adults                      <dbl> 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, …
## $ required_car_parking_spaces <fct> none, none, none, none, none, none, none, …
## $ total_of_special_requests   <dbl> 0, 0, 0, 0, 1, 1, 0, 1, 0, 3, 1, 0, 3, 0, …
## $ stays_in_week_nights        <dbl> 0, 0, 1, 1, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4, …
## $ stays_in_weekend_nights     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …

这个数据一共10列,75166行,其中children这一列是结果变量,是二分类的,其余9列都是预测变量。

我们的目的是用9列预测变量预测结果变量(感觉好绕啊)。。

代码语言:javascript复制
hotels_df |> count(children)
## # A tibble: 2 × 2
##   children     n
##   <fct>    <int>
## 1 children  6073
## 2 none     69093

可以看到结果变量两种分类很不均衡,差了10倍多!

模型选择

模型选择的部分需要大家记住tidymodels里面的一些名字,例如,对于决策树就是decision_tree(),大家可以去这个网址[1]查看所有支持的模型以及它们在tidymodels中的名字。

模型选择在这里就是3步走:

  • 选择模型
  • 使用哪个R包
  • 回归还是分类(还有其他的自己看)
代码语言:javascript复制
tree_spec <- decision_tree() |> 
  set_engine("rpart") |> 
  set_mode("classification")

当然你如果没有其他需求,可以这样写:

代码语言:javascript复制
tree_spec <- decision_tree(engine = "rpart",mode = "classification")

效果一模一样!

大家都知道很多算法都是有超参数的,R里面有很多R包都可以实现同一种算法,但是支持的超参数却不一样!

所以,对于一些R包都有的超参数,大家可以把超参数写在选择模型这一步,对于一些R包特有的超参数(算法本身有但是其他包不支持)就要写在set_engine()这里面

就像下面这样:

代码语言:javascript复制
rf_spec <- rand_forest(trees = 1000, min_n = 5) |> # 这两个参数大家都有
  set_engine("ranger", verbose = TRUE) |> # verbose参数只有ranger有,其他做随机森林的R包没有
  set_mode("classification")

数据划分

tidymodels中数据划分非常简单。

代码语言:javascript复制
set.seed(12) # 划分数据是随机的,设置种子数方便复现

hotel_split <- hotels_df |>  
  initial_split(prop = 0.7) # 需要根据某个变量分层只要加 strata = xxx即可

hotel_train <- training(hotel_split) # 训练集
hotel_test <- testing(hotel_split) # 测试集

这个是最常用的划分方法,还有很多,包括时间序列的划分等,大家可以自行学习。

划分好的数据长这样:

代码语言:javascript复制
hotel_train |> glimpse()
## Rows: 52,616
## Columns: 10
## $ children                    <fct> none, none, none, none, none, none, none, …
## $ hotel                       <fct> Resort Hotel, City Hotel, City Hotel, City…
## $ arrival_date_month          <fct> December, March, April, July, August, June…
## $ meal                        <fct> BB, HB, SC, BB, SC, SC, HB, BB, BB, HB, BB…
## $ adr                         <dbl> 43.57, 53.50, 0.00, 139.51, 94.50, 72.25, …
## $ adults                      <dbl> 1, 2, 0, 2, 2, 2, 1, 2, 1, 2, 1, 2, 2, 1, …
## $ required_car_parking_spaces <fct> none, none, none, none, none, none, none, …
## $ total_of_special_requests   <dbl> 2, 1, 1, 1, 1, 0, 0, 1, 2, 0, 1, 0, 0, 1, …
## $ stays_in_week_nights        <dbl> 3, 1, 3, 3, 2, 3, 5, 5, 0, 1, 2, 3, 2, 1, …
## $ stays_in_weekend_nights     <dbl> 0, 2, 2, 0, 0, 1, 2, 0, 2, 0, 0, 0, 0, 0, …

数据预处理

为了让结果更准确,所以我们需要一些数据预处理步骤。

首先就是这个结果变量的类不平衡,我们可以用downsample的方式解决,然后对于预测变量,我们需要对分类变量做哑变量处理,去除近零方差变量,还要对数值型变量标准化!

这个事情在tidymodels中是这样操作的:

代码语言:javascript复制
hotel_rec <- recipe(children ~ ., data = hotel_train) |> 
  themis::step_downsample(children) |>
  step_dummy(all_nominal(), -all_outcomes()) |>
  step_zv(all_numeric()) |>
  step_normalize(all_numeric()) |>
  prep() # 最后一定别忘记这个

看起来非常舒服,简单易懂,一步一步的下来即可,并且采用了tidyselect的做法,支持all_nominal()这种选择语法,非常方便的选择想要执行操作的列。

数据预处理之后,其实你不用把处理过的数据单独拿出来,就像之前介绍过的mlr3一样,可以直接进行到下一步训练模型,但是考虑到有些人就是要看到数据,你可以这样操作:

代码语言:javascript复制
# 提取处理好的训练集和测试集
train_proc <- bake(hotel_rec, new_data = NULL) # 训练集
test_proc <- bake(hotel_rec, new_data = hotel_test) # 测试集

train_proc |> glimpse()
## Rows: 8,486
## Columns: 23
## $ adr                                 <dbl> -0.27844916, 1.95719142, 0.4760089…
## $ adults                              <dbl> 0.2314058, 0.2314058, 0.2314058, 0…
## $ total_of_special_requests           <dbl> 0.108877, 0.108877, 0.108877, -0.9…
## $ stays_in_week_nights                <dbl> 1.3026010, -0.8268292, -1.3591868,…
## $ stays_in_weekend_nights             <dbl> -0.99994313, -0.99994313, 1.038374…
## $ children                            <fct> children, children, children, chil…
## $ hotel_Resort.Hotel                  <dbl> -0.830067, -0.830067, -0.830067, -…
## $ arrival_date_month_August           <dbl> -0.4522368, -0.4522368, -0.4522368…
## $ arrival_date_month_December         <dbl> -0.2502043, -0.2502043, -0.2502043…
## $ arrival_date_month_February         <dbl> -0.2735849, -0.2735849, -0.2735849…
## $ arrival_date_month_January          <dbl> -0.2244361, -0.2244361, -0.2244361…
## $ arrival_date_month_July             <dbl> -0.4044283, 2.4723349, 2.4723349, …
## $ arrival_date_month_June             <dbl> -0.2853468, -0.2853468, -0.2853468…
## $ arrival_date_month_March            <dbl> -0.2785286, -0.2785286, -0.2785286…
## $ arrival_date_month_May              <dbl> -0.2951308, -0.2951308, -0.2951308…
## $ arrival_date_month_November         <dbl> -0.222693, -0.222693, -0.222693, -…
## $ arrival_date_month_October          <dbl> -0.2948949, -0.2948949, -0.2948949…
## $ arrival_date_month_September        <dbl> -0.2760646, -0.2760646, -0.2760646…
## $ meal_FB                             <dbl> -0.08920343, -0.08920343, -0.08920…
## $ meal_HB                             <dbl> -0.4054139, -0.4054139, -0.4054139…
## $ meal_SC                             <dbl> -0.2440307, -0.2440307, -0.2440307…
## $ meal_Undefined                      <dbl> -0.09998221, -0.09998221, -0.09998…
## $ required_car_parking_spaces_parking <dbl> -0.4002771, 2.4979750, -0.4002771,…

建立workflow

“这一步并不是必须要,建议对于有数据预处理步骤的,用workflow,如果没有数据预处理步骤,不用这一步更简单!

tidymodels中增加了一个workflow函数,可以把模型选择和数据预处理这两部连接起来,形成一个对象,这个类似于mlr3的pipeline,但是只做这一件事!

代码语言:javascript复制
tree_wf <- workflow() |> 
  add_recipe(hotel_rec) |> 
  add_model(tree_spec) 

这里有多种方式构造workflow,但是一定要记住,add_model(xxxx)这一步是必须的!

初次用这个的时候碰到很多问题,后来才发现,顺序、formula、variable等都是随便加就行,唯独add_model(xxxx)这一步必不可少!

如果你熟练以后也可以这样写:

代码语言:javascript复制
tree_wf <- workflow(preprocessor = hotel_rec,
                    spec = tree_spec
                    )

这个workflow对象里面很多东西都是可以通过extract_xxx()提取的,但其实没啥用,一般情况下我们都知道自己前面干了什么。。

代码语言:javascript复制
tree_wf |> 
  extract_preprocessor()
## Recipe
## 
## Inputs:
## 
##       role #variables
##    outcome          1
##  predictor          9
## 
## Training data contained 52616 data points and no missing data.
## 
## Operations:
## 
## Down-sampling based on children [trained]
## Dummy variables from hotel, arrival_date_month, meal, required_car_parking_spaces [trained]
## Zero variance filter removed <none> [trained]
## Centering and scaling for adr, adults, total_of_special_requests, stays_i... [trained]

tree_wf |> 
  extract_spec_parsnip()
## Decision Tree Model Specification (classification)
## 
## Computational engine: rpart

选择重抽样方法

也是支持非常多的方法,常见的交叉验证,重复交叉验证,留一法,bootstrap,蒙特卡洛等,都是支持的。

所有支持的重抽样方法可以在这里[2]查看。

我们就选择一个简单的,10折交叉验证:

代码语言:javascript复制
set.seed(123)

cv <- vfold_cv(hotel_train, v = 10)

训练模型(无重抽样)

如果没有任何重抽样方法,那就非常简单了,直接fit(),然后再predict()就行了。

给大家演示下:

代码语言:javascript复制
## 建模
tree_fit <- tree_wf |> 
  fit(hotel_train)

# 测试集预测
tree_pred <- predict(tree_fit, hotel_test)

# 查看结果
head(tree_pred)
## # A tibble: 6 × 1
##   .pred_class
##   <fct>      
## 1 none       
## 2 none       
## 3 children   
## 4 none       
## 5 none       
## 6 none

如果是崭新的、没有结果变量的数据集,也是可以通过这种方式预测的:

代码语言:javascript复制
# 构造一个没有结果变量的数据集
tmp <- hotel_test |> 
  select(-children) |> 
  slice_sample(n=5)

glimpse(tmp)
## Rows: 5
## Columns: 9
## $ hotel                       <fct> City Hotel, Resort Hotel, City Hotel, Reso…
## $ arrival_date_month          <fct> May, October, February, June, May
## $ meal                        <fct> BB, BB, BB, HB, BB
## $ adr                         <dbl> 130.0, 46.5, 88.4, 88.7, 132.6
## $ adults                      <dbl> 1, 2, 2, 2, 1
## $ required_car_parking_spaces <fct> none, none, none, none, none
## $ total_of_special_requests   <dbl> 0, 0, 0, 0, 0
## $ stays_in_week_nights        <dbl> 2, 0, 4, 8, 4
## $ stays_in_weekend_nights     <dbl> 0, 1, 0, 2, 2

预测结果只需要添加new_data = tmp即可:

代码语言:javascript复制
predict(tree_fit, new_data = tmp)
## # A tibble: 5 × 1
##   .pred_class
##   <fct>      
## 1 children   
## 2 none       
## 3 none       
## 4 none       
## 5 children

得益于tidy系列的理念,这个predict()函数进行了很多优化。比如:

现在很多R包的predict()用到的参数是不一样的:

所以用起来就很烦,经常不知道写什么,tidymodels也进行了统一,对于二分类变量来说,就是两个选项:

  • type = "prob"算概率
  • type = "class"算类别

预测的结果也是有规律的:

  • 如果是数值型变量,那预测结果列名必定是.pred
  • 如果是二分类变量,那预测结果列名必定是.pred_class
    • 如果你选择了计算概率(prob),那结果列名就是.pred_你的第一个分类.pred_你的第二个分类

有了这个规律,用起来就方便多了。所以对于这种预测结果的评价,一般是和原来的真实结果结合起来,然后进行各种操作:

代码语言:javascript复制
tree_pred <- select(hotel_test, children) %>% 
  bind_cols(predict(tree_fit, hotel_test, type = "prob")) %>% 
  bind_cols(predict(tree_fit, hotel_test))

head(tree_pred)
## # A tibble: 6 × 4
##   children .pred_children .pred_none .pred_class
##   <fct>             <dbl>      <dbl> <fct>      
## 1 none              0.251      0.749 none       
## 2 none              0.251      0.749 none       
## 3 none              0.583      0.417 children   
## 4 none              0.251      0.749 none       
## 5 none              0.251      0.749 none       
## 6 none              0.251      0.749 none

得到这个结果之后,就可以进行各种模型评价了:

查看模型表现的操作也是非常遵循tidy理念的,模型评价是通过yardstick包实现的。比如下面这个AUC:

代码语言:javascript复制
tree_pred %>% roc_auc(truth = children, estimate = .pred_children)
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 roc_auc binary         0.739

想要看什么指标直接写名字,一般都能自动补全出来,所有支持的指标可以在这里[3]查看。

yardstick的第一个参数永远是你的数据集(tree_pred),第二个参数永远是真实结果,第三个参数永远是预测结果!

可以说是非常的有规律了!

代码语言:javascript复制
# 选择多种评价指标
metricsets <- metric_set(accuracy, mcc, f_meas, j_index)

tree_pred %>% metricsets(truth = children, estimate = .pred_class)
## # A tibble: 4 × 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.692
## 2 mcc      binary         0.260
## 3 f_meas   binary         0.288
## 4 j_index  binary         0.455

可视化结果也是一模一样的设计理念:

代码语言:javascript复制
tree_pred %>% roc_curve(truth = children, estimate = .pred_children) %>% 
  autoplot()

ROC

训练模型(有重抽样)

不过我们是有交叉验证这一步的,下面就来演示~

在训练集中训练模型,因为这个算法不复杂,我们也没进行特别复杂的操作,所以还是很快的,在我电脑上大概2秒钟。。。

代码语言:javascript复制
# 控制计算过程的一些设置
keep_pred <- control_resamples(save_pred = T, verbose = T)

set.seed(456)

library(doParallel) 
cl <- makePSOCKcluster(12) # 加速,用12个线程
registerDoParallel(cl)

tree_res <- fit_resamples(tree_wf, 
                          resamples = cv, 
                          control = keep_pred)
## i Fold01: preprocessor 1/1
## ✓ Fold01: preprocessor 1/1
## i Fold01: preprocessor 1/1, model 1/1
## ✓ Fold01: preprocessor 1/1, model 1/1
## i Fold01: preprocessor 1/1, model 1/1 (predictions)
## i Fold02: preprocessor 1/1
## ✓ Fold02: preprocessor 1/1
## i Fold02: preprocessor 1/1, model 1/1
## ✓ Fold02: preprocessor 1/1, model 1/1
## i Fold02: preprocessor 1/1, model 1/1 (predictions)
## i Fold03: preprocessor 1/1
## ✓ Fold03: preprocessor 1/1
## i Fold03: preprocessor 1/1, model 1/1
## ✓ Fold03: preprocessor 1/1, model 1/1
## i Fold03: preprocessor 1/1, model 1/1 (predictions)
## i Fold04: preprocessor 1/1
## ✓ Fold04: preprocessor 1/1
## i Fold04: preprocessor 1/1, model 1/1
## ✓ Fold04: preprocessor 1/1, model 1/1
## i Fold04: preprocessor 1/1, model 1/1 (predictions)
## i Fold05: preprocessor 1/1
## ✓ Fold05: preprocessor 1/1
## i Fold05: preprocessor 1/1, model 1/1
## ✓ Fold05: preprocessor 1/1, model 1/1
## i Fold05: preprocessor 1/1, model 1/1 (predictions)
## i Fold06: preprocessor 1/1
## ✓ Fold06: preprocessor 1/1
## i Fold06: preprocessor 1/1, model 1/1
## ✓ Fold06: preprocessor 1/1, model 1/1
## i Fold06: preprocessor 1/1, model 1/1 (predictions)
## i Fold07: preprocessor 1/1
## ✓ Fold07: preprocessor 1/1
## i Fold07: preprocessor 1/1, model 1/1
## ✓ Fold07: preprocessor 1/1, model 1/1
## i Fold07: preprocessor 1/1, model 1/1 (predictions)
## i Fold08: preprocessor 1/1
## ✓ Fold08: preprocessor 1/1
## i Fold08: preprocessor 1/1, model 1/1
## ✓ Fold08: preprocessor 1/1, model 1/1
## i Fold08: preprocessor 1/1, model 1/1 (predictions)
## i Fold09: preprocessor 1/1
## ✓ Fold09: preprocessor 1/1
## i Fold09: preprocessor 1/1, model 1/1
## ✓ Fold09: preprocessor 1/1, model 1/1
## i Fold09: preprocessor 1/1, model 1/1 (predictions)
## i Fold10: preprocessor 1/1
## ✓ Fold10: preprocessor 1/1
## i Fold10: preprocessor 1/1, model 1/1
## ✓ Fold10: preprocessor 1/1, model 1/1
## i Fold10: preprocessor 1/1, model 1/1 (predictions)
stop(cl)

查看模型表现,不管你换什么模型、什么数据集,结果的列名都是这几个,比如.metric.estimator这些,这也是tidy的理念~

代码语言:javascript复制
tree_res |>  
  collect_metrics()
## # A tibble: 2 × 6
##   .metric  .estimator  mean     n std_err .config             
##   <chr>    <chr>      <dbl> <int>   <dbl> <chr>               
## 1 accuracy binary     0.727    10 0.00830 Preprocessor1_Model1
## 2 roc_auc  binary     0.739    10 0.00322 Preprocessor1_Model1

想要查看每一折的表现也是可以的:

代码语言:javascript复制
tree_res |> 
  collect_metrics(summarize=F)
## # A tibble: 20 × 5
##    id     .metric  .estimator .estimate .config             
##    <chr>  <chr>    <chr>          <dbl> <chr>               
##  1 Fold01 accuracy binary         0.740 Preprocessor1_Model1
##  2 Fold01 roc_auc  binary         0.743 Preprocessor1_Model1
##  3 Fold02 accuracy binary         0.757 Preprocessor1_Model1
##  4 Fold02 roc_auc  binary         0.732 Preprocessor1_Model1
##  5 Fold03 accuracy binary         0.730 Preprocessor1_Model1
##  6 Fold03 roc_auc  binary         0.727 Preprocessor1_Model1
##  7 Fold04 accuracy binary         0.675 Preprocessor1_Model1
##  8 Fold04 roc_auc  binary         0.723 Preprocessor1_Model1
##  9 Fold05 accuracy binary         0.719 Preprocessor1_Model1
## 10 Fold05 roc_auc  binary         0.755 Preprocessor1_Model1
## 11 Fold06 accuracy binary         0.720 Preprocessor1_Model1
## 12 Fold06 roc_auc  binary         0.747 Preprocessor1_Model1
## 13 Fold07 accuracy binary         0.719 Preprocessor1_Model1
## 14 Fold07 roc_auc  binary         0.747 Preprocessor1_Model1
## 15 Fold08 accuracy binary         0.753 Preprocessor1_Model1
## 16 Fold08 roc_auc  binary         0.743 Preprocessor1_Model1
## 17 Fold09 accuracy binary         0.758 Preprocessor1_Model1
## 18 Fold09 roc_auc  binary         0.743 Preprocessor1_Model1
## 19 Fold10 accuracy binary         0.702 Preprocessor1_Model1
## 20 Fold10 roc_auc  binary         0.731 Preprocessor1_Model1

查看具体的结果,这个结果的列名也是很有规律的:

  • 第一列永远是id
  • 第二列是.pred_你的第一个分类
  • 第三列是.pred_你的第二个分类
  • 第四列是.pred_xxx,其中xxx是你的结果变量的列名。
代码语言:javascript复制
tree_res |> 
  collect_predictions()
## # A tibble: 52,616 × 7
##    id     .pred_children .pred_none  .row .pred_class children .config          
##    <chr>           <dbl>      <dbl> <int> <fct>       <fct>    <chr>            
##  1 Fold01          0.267      0.733    19 none        none     Preprocessor1_Mo…
##  2 Fold01          0.267      0.733    23 none        none     Preprocessor1_Mo…
##  3 Fold01          0.267      0.733    29 none        none     Preprocessor1_Mo…
##  4 Fold01          0.756      0.244    39 children    children Preprocessor1_Mo…
##  5 Fold01          0.267      0.733    51 none        none     Preprocessor1_Mo…
##  6 Fold01          0.267      0.733    69 none        none     Preprocessor1_Mo…
##  7 Fold01          0.267      0.733    79 none        none     Preprocessor1_Mo…
##  8 Fold01          0.267      0.733    86 none        none     Preprocessor1_Mo…
##  9 Fold01          0.607      0.393    91 children    none     Preprocessor1_Mo…
## 10 Fold01          0.756      0.244   112 children    none     Preprocessor1_Mo…
## # … with 52,606 more rows
## # ℹ Use `print(n = ...)` to see more rows

如果你有调参的过程,这里又会多好几步,主要是用来选择合适的超参数,但是我们没有这一步。

用于测试集

注意这里不是直接predict()哦,而是用last_fit()这个函数,而且它的第二个参数不是测试集,而是hotel_split

代码语言:javascript复制
tree_pred <- last_fit(tree_wf, hotel_split)

你想探索这个测试集的模型表现,也是和上面一样的:

代码语言:javascript复制
tree_pred |> collect_metrics()
## # A tibble: 2 × 4
##   .metric  .estimator .estimate .config             
##   <chr>    <chr>          <dbl> <chr>               
## 1 accuracy binary         0.692 Preprocessor1_Model1
## 2 roc_auc  binary         0.739 Preprocessor1_Model1
代码语言:javascript复制
test_pred <- tree_pred |> collect_predictions()
head(test_pred)
## # A tibble: 6 × 7
##   id               .pred_children .pred_none  .row .pred_class children .config 
##   <chr>                     <dbl>      <dbl> <int> <fct>       <fct>    <chr>   
## 1 train/test split          0.251      0.749     4 none        none     Preproc…
## 2 train/test split          0.251      0.749     8 none        none     Preproc…
## 3 train/test split          0.583      0.417    10 children    none     Preproc…
## 4 train/test split          0.251      0.749    12 none        none     Preproc…
## 5 train/test split          0.251      0.749    18 none        none     Preproc…
## 6 train/test split          0.251      0.749    21 none        none     Preproc…
roc_auc(test_pred, truth = children, .estimate = .pred_children)
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 roc_auc binary         0.739

进阶

以上是关于tidymodels的基础使用,大家在实际使用中经常会遇到更加复杂的情况,比如:多个模型的比较,多个模型在多个数据集并配合不同的预处理步骤,超参数调优等等。

关于多个模型比较的部分大家可以翻看我之前的推文:

使用tidymodels搞定二分类资料多个模型评价和比较 使用workflow一次完成多个模型的评价和比较

另外,还可以去我的个人博客:https://www.yuque.com/ayueme , 查看更多内容,我的博客里给出了非常多tidymodels使用的例子,这些内容目前还没有搬到公众号上来,可以帮助大家更进一步了解这个包。

总结

总体来看,tidymodels在统一使用方式方面做的非常棒,各个步骤中都有tidy理念的影子,这样一旦你熟悉了其基本语法,使用起来是很舒服的,因为代码基本不用变,连列名都是固定的!

有点难度的地方是数据预处理步骤,因为太多了,所有的预处理步骤大家可以去这里[4]看。

另外,对于超参数调优的部分感觉不如mlr3做得好,很多超参数的名字、类型、取值等很难记住,并且没有明确给出查看这些信息的函数,经常要不断的用?xxx来看帮助文档。。。

还有一个就是速度,基于tibble,并且各种fit_xxx()函数也是基于purrr包,这就导致它速度一般。但是目前我还没接触到需要好几个小时的数据,一般也就顶多半小时!

如果你是新手,建议你先学tidymodels,因为简单,mlr3的R6语法太反人类了。。。

今日示例数据已上传QQ群,需要的加群自取即可↓

参考资料

[1]

支持的模型: https://www.tidymodels.org/find/parsnip/

[2]

重抽样方法: https://rsample.tidymodels.org/reference/index.html

[3]

评价指标: https://yardstick.tidymodels.org/reference/index.html

[4]

预处理: https://recipes.tidymodels.org/reference/index.html

0 人点赞