mlr3_重抽样

2021-02-05 16:39:39 浏览数 (1)

mlr3_重抽样

概述

mlr3中包含的重抽样方法

  • cross validation ("cv"):交叉验证
  • leave-one-out cross validation ("loo"):留一验证
  • repeated cross validation ("repeated_cv") :重复交叉验证
  • bootstrapping ("bootstrap"):bootstrap
  • subsampling ("subsampling"):下采样
  • holdout ("holdout"):相当于3:7的分割方式
  • in-sample resampling ("insample")
  • custom resampling ("custom"):自定义重抽样

设置任务

代码语言:javascript复制
task = tsk("iris")
learner = lrn("classif.rpart")
# 查看mlr的重抽样方法有哪些
as.data.table(mlr_resamplings)
##            key        params iters
## 1:   bootstrap repeats,ratio    30
## 2:      custom                   0
## 3:          cv         folds    10
## 4:     holdout         ratio     1
## 5:    insample                   1
## 6:         loo                  NA
## 7: repeated_cv repeats,folds   100
## 8: subsampling repeats,ratio    30

# 通过rsmp函数提取采样方法
resampling = rsmp("holdout")
print(resampling)
## <ResamplingHoldout> with 1 iterations
## * Instantiated: FALSE
## * Parameters: ratio=0.6667

这里$is_instantiated是false,这表示,我们没有将采样方法设置再数据集中。同时这里默认的采样比例是0.6667,可以通过下面两种方式更改

代码语言:javascript复制
resampling$param_set$values = list(ratio = 0.8)
rsmp("holdout", ratio = 0.8)

实例化

通过instantiate函数对任务进行分组

代码语言:javascript复制
resampling = rsmp("cv", folds = 3L)
resampling$instantiate(task)
resampling$iters
## [1] 3
# 查看训练和测试集的id号
str(resampling$train_set(1))
##  int [1:100] 2 3 4 5 10 12 14 15 18 19 ...
str(resampling$test_set(1))
##  int [1:50] 7 9 13 16 17 21 22 25 35 37 ...

执行重抽样

将task、learner和resample组合起来形成一个新的对象,

代码语言:javascript复制
task = tsk("pima")
learner = lrn("classif.rpart", maxdepth = 3, predict_type = "prob")
resampling = rsmp("cv", folds = 3L)
# 将三者组合起来
rr = resample(task, learner, resampling, store_models = TRUE)
print(rr)

## <ResampleResult> of 3 iterations
## * Task: pima
## * Learner: classif.rpart
## * Warnings: 0 in 0 iterations
## * Errors: 0 in 0 iterations

#通过aggregate函数将多个结果平均
rr$aggregate(msr("classif.ce"))
## classif.ce 
##     0.2721


# 查看每个模型的性能
rr$score(msr("classif.ce"))
##                 task task_id                   learner    learner_id
## 1: <TaskClassif[45]>    pima <LearnerClassifRpart[34]> classif.rpart
## 2: <TaskClassif[45]>    pima <LearnerClassifRpart[34]> classif.rpart
## 3: <TaskClassif[45]>    pima <LearnerClassifRpart[34]> classif.rpart
##            resampling resampling_id iteration              prediction
## 1: <ResamplingCV[19]>            cv         1 <PredictionClassif[19]>
## 2: <ResamplingCV[19]>            cv         2 <PredictionClassif[19]>
## 3: <ResamplingCV[19]>            cv         3 <PredictionClassif[19]>
##    classif.ce
## 1:     0.3164
## 2:     0.2617
## 3:     0.2383

查看迭代结果

代码语言:javascript复制
# 查看错误和警告
rr$warnings
rr$errors
# 查看抽样策略
rr$resampling
# 产看迭代次数
rr$resampling$iters
# 查看第一测试集和训练集
str(rr$resampling$test_set(1))
str(rr$resampling$train_set(1))

# 查看指定的学习器
lrn = rr$learners[[1]]
lrn$model

# 提取预测结果;这里将所有预测整合再一个表中
rr$prediction() 
# 提取第一次迭代结果
rr$predictions()[[1]]

自定义抽样

自己选择样本的编号,进行抽样,傻子才这样做

代码语言:javascript复制
resampling = rsmp("custom")
resampling$instantiate(task,
  train = list(c(1:10, 51:60, 101:110)),
  test = list(c(11:20, 61:70, 111:120))
)
resampling$iters

绘制结果

代码语言:javascript复制
library("mlr3viz")
autoplot(rr)

hu

绘制roc曲线

autoplot(rr, type = "roc")

结束语

对于重抽样的操作,建议在高性能的服务器上进行,或者测试数据较少或者特征较少的数据集。

love&peace

0 人点赞