R语言做机器学习的当红辣子鸡R包:mlr3
和tidymodels
,之前用十几篇推文详细介绍过mlr3
mlr3:开篇
mlr3:基础使用
mlr3:模型评价
mlr3:模型比较
mlr3:超参数调优
mlr3:嵌套重抽样
mlr3:特征选择
mlr3:pipelines
mlr3:技术细节
mlr3:模型解释
mlr3实战:决策树和xgboost预测房价
今天学习下tidymodels
的使用,其实之前在介绍临床预测模型时已经用过这个包了:使用tidymodels搞定二分类资料多个模型评价和比较
但是对于很多没接触过这个包的朋友来说有些地方还是不好理解,所以今天专门写一篇推文介绍下tidymodels
的一些使用细节,帮助大家更上一层楼。
不得不说,比mlr3
简单多了!
目录:
- 设计理念
- 安装
- 基本使用
- 探索数据
- 模型选择
- 数据划分
- 数据预处理
- 建立workflow
- 选择重抽样方法
- 训练模型(无重抽样)
- 训练模型(有重抽样)
- 用于测试集
- 进阶
- 总结
设计理念
tidymodels
是max kuhn加入rstudio之后和Julia silge等人共同开发的机器学习R包,类似于mlr3
和caret
,也是一个整合包,只提供统一的API,让大家可以通过统一的语法调用R语言里各种现成的机器学习算法R包,并不发明新的算法。
这样做对用户来说最大的好处是不用记那么多R包的用法了,只需要记住tidymodels
一个包的用法及参数就够了。同时得益于tidyverse
系列的加持,在tidymodels
中进行的各种操作以及产生的各种结果都是遵循tidy
系列的设计理念的。所以非常有规律,很容易记住!
tidymodels
类似于tidyverse
,是一系列R包的合集,其中主要的包括:
parsnip
:提供统一的语法来选择模型(算法)recipes
:数据预处理rsample
:重抽样dials
:设置超参数tune
:调整超参数yardstick
:评价模型broom
:可以把各种模型的结果以整洁tibble格式返回,支持R语言所有内置模型!还有大部分第三方R包的模型!infer
:统计推断workflows
:联合数据预处理和算法
除此之外,还包括ggplot2/purrr/dplyr/tibble
等R包。
真正在用的时候并不需要刻意记住,只需加载tidymodels
就可得到全部~
因为和tidyverse
系列是一脉相承的,所以也是支持管道符的,这样的操作看起来非常的流畅,也比较容易理解,对于初学者来说比mlr3
那种面向对象的编程,简单多了。
但是一个很大的问题是速度,因为底层也是基于tibble
,所以速度没那么快,尤其是在调参的时候,非常慢,运算量一大就得好久时间才能出结果!
安装
目前发展还是很快,经常变更版本,所以时不时会遇到一些小问题,但总体来说瑕不掩瑜,学了不吃亏。
代码语言:javascript复制# 2选1
install.packages("tidymodels")
library("devtools")
install_github("tidymodels/tidymodels")
基本使用
基本使用步骤和大家像想象中的差不多:
- 选择算法(模型)
- 数据预处理
- 训练集建模
- 测试集看效果
在建模的过程中可能会同时出现重抽样、超参数调整等步骤,但基本步骤就是这样的。
代码语言:javascript复制library(tidyverse)
## ── Attaching packages ───────────────────────────── tidyverse 1.3.2 ──
## ✔ ggplot2 3.3.6 ✔ purrr 0.3.4
## ✔ tibble 3.1.7 ✔ dplyr 1.0.9
## ✔ tidyr 1.2.0 ✔ stringr 1.4.0
## ✔ readr 2.1.2 ✔ forcats 0.5.1
## ── Conflicts ──────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
library(tidymodels)
## ── Attaching packages ──────────────────────────── tidymodels 1.0.0 ──
## ✔ broom 1.0.0 ✔ rsample 1.0.0
## ✔ dials 1.0.0 ✔ tune 1.0.0
## ✔ infer 1.0.2 ✔ workflows 1.0.0
## ✔ modeldata 1.0.0 ✔ workflowsets 1.0.0
## ✔ parsnip 1.0.0 ✔ yardstick 1.0.0
## ✔ recipes 1.0.1
## ── Conflicts ─────────────────────────────── tidymodels_conflicts() ──
## ✖ scales::discard() masks purrr::discard()
## ✖ dplyr::filter() masks stats::filter()
## ✖ recipes::fixed() masks stringr::fixed()
## ✖ dplyr::lag() masks stats::lag()
## ✖ yardstick::spec() masks readr::spec()
## ✖ recipes::step() masks stats::step()
## • Use tidymodels_prefer() to resolve common conflicts.
tidymodels_prefer() # 防止函数冲突
探索数据
我们用一个结果变量是二分类变量的数据集来做一个简单的演示。这个数据集是关于大人住旅馆会不会带孩子一起。。。
代码语言:javascript复制rm(list = ls())
load(file = "../datasets/hotels_df.rdata")
简单查看一下数据:
代码语言:javascript复制hotels_df |> glimpse()
## Rows: 75,166
## Columns: 10
## $ children <fct> none, none, none, none, none, none, none, …
## $ hotel <fct> Resort Hotel, Resort Hotel, Resort Hotel, …
## $ arrival_date_month <fct> July, July, July, July, July, July, July, …
## $ meal <fct> BB, BB, BB, BB, BB, BB, BB, FB, HB, BB, HB…
## $ adr <dbl> 0.00, 0.00, 75.00, 75.00, 98.00, 98.00, 10…
## $ adults <dbl> 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, …
## $ required_car_parking_spaces <fct> none, none, none, none, none, none, none, …
## $ total_of_special_requests <dbl> 0, 0, 0, 0, 1, 1, 0, 1, 0, 3, 1, 0, 3, 0, …
## $ stays_in_week_nights <dbl> 0, 0, 1, 1, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4, …
## $ stays_in_weekend_nights <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
这个数据一共10列,75166行,其中children
这一列是结果变量,是二分类的,其余9列都是预测变量。
我们的目的是用9列预测变量预测结果变量(感觉好绕啊)。。
代码语言:javascript复制hotels_df |> count(children)
## # A tibble: 2 × 2
## children n
## <fct> <int>
## 1 children 6073
## 2 none 69093
可以看到结果变量两种分类很不均衡,差了10倍多!
模型选择
模型选择的部分需要大家记住tidymodels
里面的一些名字,例如,对于决策树就是decision_tree()
,大家可以去这个网址[1]查看所有支持的模型以及它们在tidymodels
中的名字。
模型选择在这里就是3步走:
- 选择模型
- 使用哪个R包
- 回归还是分类(还有其他的自己看)
tree_spec <- decision_tree() |>
set_engine("rpart") |>
set_mode("classification")
当然你如果没有其他需求,可以这样写:
代码语言:javascript复制tree_spec <- decision_tree(engine = "rpart",mode = "classification")
效果一模一样!
大家都知道很多算法都是有超参数的,R里面有很多R包都可以实现同一种算法,但是支持的超参数却不一样!
所以,对于一些R包都有的超参数,大家可以把超参数写在选择模型这一步,对于一些R包特有的超参数(算法本身有但是其他包不支持)就要写在set_engine()
这里面
就像下面这样:
代码语言:javascript复制rf_spec <- rand_forest(trees = 1000, min_n = 5) |> # 这两个参数大家都有
set_engine("ranger", verbose = TRUE) |> # verbose参数只有ranger有,其他做随机森林的R包没有
set_mode("classification")
数据划分
在tidymodels
中数据划分非常简单。
set.seed(12) # 划分数据是随机的,设置种子数方便复现
hotel_split <- hotels_df |>
initial_split(prop = 0.7) # 需要根据某个变量分层只要加 strata = xxx即可
hotel_train <- training(hotel_split) # 训练集
hotel_test <- testing(hotel_split) # 测试集
这个是最常用的划分方法,还有很多,包括时间序列的划分等,大家可以自行学习。
划分好的数据长这样:
代码语言:javascript复制hotel_train |> glimpse()
## Rows: 52,616
## Columns: 10
## $ children <fct> none, none, none, none, none, none, none, …
## $ hotel <fct> Resort Hotel, City Hotel, City Hotel, City…
## $ arrival_date_month <fct> December, March, April, July, August, June…
## $ meal <fct> BB, HB, SC, BB, SC, SC, HB, BB, BB, HB, BB…
## $ adr <dbl> 43.57, 53.50, 0.00, 139.51, 94.50, 72.25, …
## $ adults <dbl> 1, 2, 0, 2, 2, 2, 1, 2, 1, 2, 1, 2, 2, 1, …
## $ required_car_parking_spaces <fct> none, none, none, none, none, none, none, …
## $ total_of_special_requests <dbl> 2, 1, 1, 1, 1, 0, 0, 1, 2, 0, 1, 0, 0, 1, …
## $ stays_in_week_nights <dbl> 3, 1, 3, 3, 2, 3, 5, 5, 0, 1, 2, 3, 2, 1, …
## $ stays_in_weekend_nights <dbl> 0, 2, 2, 0, 0, 1, 2, 0, 2, 0, 0, 0, 0, 0, …
数据预处理
为了让结果更准确,所以我们需要一些数据预处理步骤。
首先就是这个结果变量的类不平衡,我们可以用downsample
的方式解决,然后对于预测变量,我们需要对分类变量做哑变量处理,去除近零方差变量,还要对数值型变量标准化!
这个事情在tidymodels
中是这样操作的:
hotel_rec <- recipe(children ~ ., data = hotel_train) |>
themis::step_downsample(children) |>
step_dummy(all_nominal(), -all_outcomes()) |>
step_zv(all_numeric()) |>
step_normalize(all_numeric()) |>
prep() # 最后一定别忘记这个
看起来非常舒服,简单易懂,一步一步的下来即可,并且采用了tidyselect
的做法,支持all_nominal()
这种选择语法,非常方便的选择想要执行操作的列。
数据预处理之后,其实你不用把处理过的数据单独拿出来,就像之前介绍过的mlr3
一样,可以直接进行到下一步训练模型,但是考虑到有些人就是要看到数据,你可以这样操作:
# 提取处理好的训练集和测试集
train_proc <- bake(hotel_rec, new_data = NULL) # 训练集
test_proc <- bake(hotel_rec, new_data = hotel_test) # 测试集
train_proc |> glimpse()
## Rows: 8,486
## Columns: 23
## $ adr <dbl> -0.27844916, 1.95719142, 0.4760089…
## $ adults <dbl> 0.2314058, 0.2314058, 0.2314058, 0…
## $ total_of_special_requests <dbl> 0.108877, 0.108877, 0.108877, -0.9…
## $ stays_in_week_nights <dbl> 1.3026010, -0.8268292, -1.3591868,…
## $ stays_in_weekend_nights <dbl> -0.99994313, -0.99994313, 1.038374…
## $ children <fct> children, children, children, chil…
## $ hotel_Resort.Hotel <dbl> -0.830067, -0.830067, -0.830067, -…
## $ arrival_date_month_August <dbl> -0.4522368, -0.4522368, -0.4522368…
## $ arrival_date_month_December <dbl> -0.2502043, -0.2502043, -0.2502043…
## $ arrival_date_month_February <dbl> -0.2735849, -0.2735849, -0.2735849…
## $ arrival_date_month_January <dbl> -0.2244361, -0.2244361, -0.2244361…
## $ arrival_date_month_July <dbl> -0.4044283, 2.4723349, 2.4723349, …
## $ arrival_date_month_June <dbl> -0.2853468, -0.2853468, -0.2853468…
## $ arrival_date_month_March <dbl> -0.2785286, -0.2785286, -0.2785286…
## $ arrival_date_month_May <dbl> -0.2951308, -0.2951308, -0.2951308…
## $ arrival_date_month_November <dbl> -0.222693, -0.222693, -0.222693, -…
## $ arrival_date_month_October <dbl> -0.2948949, -0.2948949, -0.2948949…
## $ arrival_date_month_September <dbl> -0.2760646, -0.2760646, -0.2760646…
## $ meal_FB <dbl> -0.08920343, -0.08920343, -0.08920…
## $ meal_HB <dbl> -0.4054139, -0.4054139, -0.4054139…
## $ meal_SC <dbl> -0.2440307, -0.2440307, -0.2440307…
## $ meal_Undefined <dbl> -0.09998221, -0.09998221, -0.09998…
## $ required_car_parking_spaces_parking <dbl> -0.4002771, 2.4979750, -0.4002771,…
建立workflow
“这一步并不是必须要,建议对于有数据预处理步骤的,用workflow,如果没有数据预处理步骤,不用这一步更简单!
tidymodels
中增加了一个workflow
函数,可以把模型选择和数据预处理这两部连接起来,形成一个对象,这个类似于mlr3
的pipeline,但是只做这一件事!
tree_wf <- workflow() |>
add_recipe(hotel_rec) |>
add_model(tree_spec)
这里有多种方式构造workflow
,但是一定要记住,add_model(xxxx)
这一步是必须的!
初次用这个的时候碰到很多问题,后来才发现,顺序、formula、variable等都是随便加就行,唯独add_model(xxxx)
这一步必不可少!
如果你熟练以后也可以这样写:
代码语言:javascript复制tree_wf <- workflow(preprocessor = hotel_rec,
spec = tree_spec
)
这个workflow
对象里面很多东西都是可以通过extract_xxx()
提取的,但其实没啥用,一般情况下我们都知道自己前面干了什么。。
tree_wf |>
extract_preprocessor()
## Recipe
##
## Inputs:
##
## role #variables
## outcome 1
## predictor 9
##
## Training data contained 52616 data points and no missing data.
##
## Operations:
##
## Down-sampling based on children [trained]
## Dummy variables from hotel, arrival_date_month, meal, required_car_parking_spaces [trained]
## Zero variance filter removed <none> [trained]
## Centering and scaling for adr, adults, total_of_special_requests, stays_i... [trained]
tree_wf |>
extract_spec_parsnip()
## Decision Tree Model Specification (classification)
##
## Computational engine: rpart
选择重抽样方法
也是支持非常多的方法,常见的交叉验证,重复交叉验证,留一法,bootstrap,蒙特卡洛等,都是支持的。
所有支持的重抽样方法可以在这里[2]查看。
我们就选择一个简单的,10折交叉验证:
代码语言:javascript复制set.seed(123)
cv <- vfold_cv(hotel_train, v = 10)
训练模型(无重抽样)
如果没有任何重抽样方法,那就非常简单了,直接fit()
,然后再predict()
就行了。
给大家演示下:
代码语言:javascript复制## 建模
tree_fit <- tree_wf |>
fit(hotel_train)
# 测试集预测
tree_pred <- predict(tree_fit, hotel_test)
# 查看结果
head(tree_pred)
## # A tibble: 6 × 1
## .pred_class
## <fct>
## 1 none
## 2 none
## 3 children
## 4 none
## 5 none
## 6 none
如果是崭新的、没有结果变量的数据集,也是可以通过这种方式预测的:
代码语言:javascript复制# 构造一个没有结果变量的数据集
tmp <- hotel_test |>
select(-children) |>
slice_sample(n=5)
glimpse(tmp)
## Rows: 5
## Columns: 9
## $ hotel <fct> City Hotel, Resort Hotel, City Hotel, Reso…
## $ arrival_date_month <fct> May, October, February, June, May
## $ meal <fct> BB, BB, BB, HB, BB
## $ adr <dbl> 130.0, 46.5, 88.4, 88.7, 132.6
## $ adults <dbl> 1, 2, 2, 2, 1
## $ required_car_parking_spaces <fct> none, none, none, none, none
## $ total_of_special_requests <dbl> 0, 0, 0, 0, 0
## $ stays_in_week_nights <dbl> 2, 0, 4, 8, 4
## $ stays_in_weekend_nights <dbl> 0, 1, 0, 2, 2
预测结果只需要添加new_data = tmp
即可:
predict(tree_fit, new_data = tmp)
## # A tibble: 5 × 1
## .pred_class
## <fct>
## 1 children
## 2 none
## 3 none
## 4 none
## 5 children
得益于tidy
系列的理念,这个predict()
函数进行了很多优化。比如:
现在很多R包的predict()
用到的参数是不一样的:
所以用起来就很烦,经常不知道写什么,tidymodels
也进行了统一,对于二分类变量来说,就是两个选项:
type = "prob"
算概率type = "class"
算类别
预测的结果也是有规律的:
- 如果是数值型变量,那预测结果列名必定是
.pred
- 如果是二分类变量,那预测结果列名必定是
.pred_class
,- 如果你选择了计算概率(prob),那结果列名就是
.pred_你的第一个分类
,.pred_你的第二个分类
。
- 如果你选择了计算概率(prob),那结果列名就是
有了这个规律,用起来就方便多了。所以对于这种预测结果的评价,一般是和原来的真实结果结合起来,然后进行各种操作:
代码语言:javascript复制tree_pred <- select(hotel_test, children) %>%
bind_cols(predict(tree_fit, hotel_test, type = "prob")) %>%
bind_cols(predict(tree_fit, hotel_test))
head(tree_pred)
## # A tibble: 6 × 4
## children .pred_children .pred_none .pred_class
## <fct> <dbl> <dbl> <fct>
## 1 none 0.251 0.749 none
## 2 none 0.251 0.749 none
## 3 none 0.583 0.417 children
## 4 none 0.251 0.749 none
## 5 none 0.251 0.749 none
## 6 none 0.251 0.749 none
得到这个结果之后,就可以进行各种模型评价了:
查看模型表现的操作也是非常遵循tidy
理念的,模型评价是通过yardstick
包实现的。比如下面这个AUC:
tree_pred %>% roc_auc(truth = children, estimate = .pred_children)
## # A tibble: 1 × 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 roc_auc binary 0.739
想要看什么指标直接写名字,一般都能自动补全出来,所有支持的指标可以在这里[3]查看。
yardstick
的第一个参数永远是你的数据集(tree_pred),第二个参数永远是真实结果,第三个参数永远是预测结果!
可以说是非常的有规律了!
代码语言:javascript复制# 选择多种评价指标
metricsets <- metric_set(accuracy, mcc, f_meas, j_index)
tree_pred %>% metricsets(truth = children, estimate = .pred_class)
## # A tibble: 4 × 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 accuracy binary 0.692
## 2 mcc binary 0.260
## 3 f_meas binary 0.288
## 4 j_index binary 0.455
可视化结果也是一模一样的设计理念:
代码语言:javascript复制tree_pred %>% roc_curve(truth = children, estimate = .pred_children) %>%
autoplot()
ROC
训练模型(有重抽样)
不过我们是有交叉验证这一步的,下面就来演示~
在训练集中训练模型,因为这个算法不复杂,我们也没进行特别复杂的操作,所以还是很快的,在我电脑上大概2秒钟。。。
代码语言:javascript复制# 控制计算过程的一些设置
keep_pred <- control_resamples(save_pred = T, verbose = T)
set.seed(456)
library(doParallel)
cl <- makePSOCKcluster(12) # 加速,用12个线程
registerDoParallel(cl)
tree_res <- fit_resamples(tree_wf,
resamples = cv,
control = keep_pred)
## i Fold01: preprocessor 1/1
## ✓ Fold01: preprocessor 1/1
## i Fold01: preprocessor 1/1, model 1/1
## ✓ Fold01: preprocessor 1/1, model 1/1
## i Fold01: preprocessor 1/1, model 1/1 (predictions)
## i Fold02: preprocessor 1/1
## ✓ Fold02: preprocessor 1/1
## i Fold02: preprocessor 1/1, model 1/1
## ✓ Fold02: preprocessor 1/1, model 1/1
## i Fold02: preprocessor 1/1, model 1/1 (predictions)
## i Fold03: preprocessor 1/1
## ✓ Fold03: preprocessor 1/1
## i Fold03: preprocessor 1/1, model 1/1
## ✓ Fold03: preprocessor 1/1, model 1/1
## i Fold03: preprocessor 1/1, model 1/1 (predictions)
## i Fold04: preprocessor 1/1
## ✓ Fold04: preprocessor 1/1
## i Fold04: preprocessor 1/1, model 1/1
## ✓ Fold04: preprocessor 1/1, model 1/1
## i Fold04: preprocessor 1/1, model 1/1 (predictions)
## i Fold05: preprocessor 1/1
## ✓ Fold05: preprocessor 1/1
## i Fold05: preprocessor 1/1, model 1/1
## ✓ Fold05: preprocessor 1/1, model 1/1
## i Fold05: preprocessor 1/1, model 1/1 (predictions)
## i Fold06: preprocessor 1/1
## ✓ Fold06: preprocessor 1/1
## i Fold06: preprocessor 1/1, model 1/1
## ✓ Fold06: preprocessor 1/1, model 1/1
## i Fold06: preprocessor 1/1, model 1/1 (predictions)
## i Fold07: preprocessor 1/1
## ✓ Fold07: preprocessor 1/1
## i Fold07: preprocessor 1/1, model 1/1
## ✓ Fold07: preprocessor 1/1, model 1/1
## i Fold07: preprocessor 1/1, model 1/1 (predictions)
## i Fold08: preprocessor 1/1
## ✓ Fold08: preprocessor 1/1
## i Fold08: preprocessor 1/1, model 1/1
## ✓ Fold08: preprocessor 1/1, model 1/1
## i Fold08: preprocessor 1/1, model 1/1 (predictions)
## i Fold09: preprocessor 1/1
## ✓ Fold09: preprocessor 1/1
## i Fold09: preprocessor 1/1, model 1/1
## ✓ Fold09: preprocessor 1/1, model 1/1
## i Fold09: preprocessor 1/1, model 1/1 (predictions)
## i Fold10: preprocessor 1/1
## ✓ Fold10: preprocessor 1/1
## i Fold10: preprocessor 1/1, model 1/1
## ✓ Fold10: preprocessor 1/1, model 1/1
## i Fold10: preprocessor 1/1, model 1/1 (predictions)
stop(cl)
查看模型表现,不管你换什么模型、什么数据集,结果的列名都是这几个,比如.metric.estimator
这些,这也是tidy
的理念~
tree_res |>
collect_metrics()
## # A tibble: 2 × 6
## .metric .estimator mean n std_err .config
## <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 accuracy binary 0.727 10 0.00830 Preprocessor1_Model1
## 2 roc_auc binary 0.739 10 0.00322 Preprocessor1_Model1
想要查看每一折的表现也是可以的:
代码语言:javascript复制tree_res |>
collect_metrics(summarize=F)
## # A tibble: 20 × 5
## id .metric .estimator .estimate .config
## <chr> <chr> <chr> <dbl> <chr>
## 1 Fold01 accuracy binary 0.740 Preprocessor1_Model1
## 2 Fold01 roc_auc binary 0.743 Preprocessor1_Model1
## 3 Fold02 accuracy binary 0.757 Preprocessor1_Model1
## 4 Fold02 roc_auc binary 0.732 Preprocessor1_Model1
## 5 Fold03 accuracy binary 0.730 Preprocessor1_Model1
## 6 Fold03 roc_auc binary 0.727 Preprocessor1_Model1
## 7 Fold04 accuracy binary 0.675 Preprocessor1_Model1
## 8 Fold04 roc_auc binary 0.723 Preprocessor1_Model1
## 9 Fold05 accuracy binary 0.719 Preprocessor1_Model1
## 10 Fold05 roc_auc binary 0.755 Preprocessor1_Model1
## 11 Fold06 accuracy binary 0.720 Preprocessor1_Model1
## 12 Fold06 roc_auc binary 0.747 Preprocessor1_Model1
## 13 Fold07 accuracy binary 0.719 Preprocessor1_Model1
## 14 Fold07 roc_auc binary 0.747 Preprocessor1_Model1
## 15 Fold08 accuracy binary 0.753 Preprocessor1_Model1
## 16 Fold08 roc_auc binary 0.743 Preprocessor1_Model1
## 17 Fold09 accuracy binary 0.758 Preprocessor1_Model1
## 18 Fold09 roc_auc binary 0.743 Preprocessor1_Model1
## 19 Fold10 accuracy binary 0.702 Preprocessor1_Model1
## 20 Fold10 roc_auc binary 0.731 Preprocessor1_Model1
查看具体的结果,这个结果的列名也是很有规律的:
- 第一列永远是
id
, - 第二列是
.pred_你的第一个分类
, - 第三列是
.pred_你的第二个分类
, - 第四列是
.pred_xxx
,其中xxx
是你的结果变量的列名。
tree_res |>
collect_predictions()
## # A tibble: 52,616 × 7
## id .pred_children .pred_none .row .pred_class children .config
## <chr> <dbl> <dbl> <int> <fct> <fct> <chr>
## 1 Fold01 0.267 0.733 19 none none Preprocessor1_Mo…
## 2 Fold01 0.267 0.733 23 none none Preprocessor1_Mo…
## 3 Fold01 0.267 0.733 29 none none Preprocessor1_Mo…
## 4 Fold01 0.756 0.244 39 children children Preprocessor1_Mo…
## 5 Fold01 0.267 0.733 51 none none Preprocessor1_Mo…
## 6 Fold01 0.267 0.733 69 none none Preprocessor1_Mo…
## 7 Fold01 0.267 0.733 79 none none Preprocessor1_Mo…
## 8 Fold01 0.267 0.733 86 none none Preprocessor1_Mo…
## 9 Fold01 0.607 0.393 91 children none Preprocessor1_Mo…
## 10 Fold01 0.756 0.244 112 children none Preprocessor1_Mo…
## # … with 52,606 more rows
## # ℹ Use `print(n = ...)` to see more rows
如果你有调参的过程,这里又会多好几步,主要是用来选择合适的超参数,但是我们没有这一步。
用于测试集
注意这里不是直接predict()
哦,而是用last_fit()
这个函数,而且它的第二个参数不是测试集,而是hotel_split
!
tree_pred <- last_fit(tree_wf, hotel_split)
你想探索这个测试集的模型表现,也是和上面一样的:
代码语言:javascript复制tree_pred |> collect_metrics()
## # A tibble: 2 × 4
## .metric .estimator .estimate .config
## <chr> <chr> <dbl> <chr>
## 1 accuracy binary 0.692 Preprocessor1_Model1
## 2 roc_auc binary 0.739 Preprocessor1_Model1
代码语言:javascript复制test_pred <- tree_pred |> collect_predictions()
head(test_pred)
## # A tibble: 6 × 7
## id .pred_children .pred_none .row .pred_class children .config
## <chr> <dbl> <dbl> <int> <fct> <fct> <chr>
## 1 train/test split 0.251 0.749 4 none none Preproc…
## 2 train/test split 0.251 0.749 8 none none Preproc…
## 3 train/test split 0.583 0.417 10 children none Preproc…
## 4 train/test split 0.251 0.749 12 none none Preproc…
## 5 train/test split 0.251 0.749 18 none none Preproc…
## 6 train/test split 0.251 0.749 21 none none Preproc…
roc_auc(test_pred, truth = children, .estimate = .pred_children)
## # A tibble: 1 × 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 roc_auc binary 0.739
进阶
以上是关于tidymodels
的基础使用,大家在实际使用中经常会遇到更加复杂的情况,比如:多个模型的比较,多个模型在多个数据集并配合不同的预处理步骤,超参数调优等等。
关于多个模型比较的部分大家可以翻看我之前的推文:
使用tidymodels搞定二分类资料多个模型评价和比较 使用workflow一次完成多个模型的评价和比较
另外,还可以去我的个人博客:https://www.yuque.com/ayueme , 查看更多内容,我的博客里给出了非常多tidymodels
使用的例子,这些内容目前还没有搬到公众号上来,可以帮助大家更进一步了解这个包。
总结
总体来看,tidymodels
在统一使用方式方面做的非常棒,各个步骤中都有tidy
理念的影子,这样一旦你熟悉了其基本语法,使用起来是很舒服的,因为代码基本不用变,连列名都是固定的!
有点难度的地方是数据预处理步骤,因为太多了,所有的预处理步骤大家可以去这里[4]看。
另外,对于超参数调优的部分感觉不如mlr3
做得好,很多超参数的名字、类型、取值等很难记住,并且没有明确给出查看这些信息的函数,经常要不断的用?xxx
来看帮助文档。。。
还有一个就是速度,基于tibble
,并且各种fit_xxx()
函数也是基于purrr
包,这就导致它速度一般。但是目前我还没接触到需要好几个小时的数据,一般也就顶多半小时!
如果你是新手,建议你先学tidymodels
,因为简单,mlr3
的R6语法太反人类了。。。
今日示例数据已上传QQ群,需要的加群自取即可↓
参考资料
[1]
支持的模型: https://www.tidymodels.org/find/parsnip/
[2]
重抽样方法: https://rsample.tidymodels.org/reference/index.html
[3]
评价指标: https://yardstick.tidymodels.org/reference/index.html
[4]
预处理: https://recipes.tidymodels.org/reference/index.html