mlr3_重抽样
概述
mlr3中包含的重抽样方法
- cross validation ("cv"):交叉验证
- leave-one-out cross validation ("loo"):留一验证
- repeated cross validation ("repeated_cv") :重复交叉验证
- bootstrapping ("bootstrap"):bootstrap
- subsampling ("subsampling"):下采样
- holdout ("holdout"):相当于3:7的分割方式
- in-sample resampling ("insample")
- custom resampling ("custom"):自定义重抽样
设置任务
代码语言:javascript复制task = tsk("iris")
learner = lrn("classif.rpart")
# 查看mlr的重抽样方法有哪些
as.data.table(mlr_resamplings)
## key params iters
## 1: bootstrap repeats,ratio 30
## 2: custom 0
## 3: cv folds 10
## 4: holdout ratio 1
## 5: insample 1
## 6: loo NA
## 7: repeated_cv repeats,folds 100
## 8: subsampling repeats,ratio 30
# 通过rsmp函数提取采样方法
resampling = rsmp("holdout")
print(resampling)
## <ResamplingHoldout> with 1 iterations
## * Instantiated: FALSE
## * Parameters: ratio=0.6667
这里$is_instantiated
是false,这表示,我们没有将采样方法设置再数据集中。同时这里默认的采样比例是0.6667,可以通过下面两种方式更改
resampling$param_set$values = list(ratio = 0.8)
rsmp("holdout", ratio = 0.8)
实例化
通过instantiate函数对任务进行分组
代码语言:javascript复制resampling = rsmp("cv", folds = 3L)
resampling$instantiate(task)
resampling$iters
## [1] 3
# 查看训练和测试集的id号
str(resampling$train_set(1))
## int [1:100] 2 3 4 5 10 12 14 15 18 19 ...
str(resampling$test_set(1))
## int [1:50] 7 9 13 16 17 21 22 25 35 37 ...
执行重抽样
将task、learner和resample组合起来形成一个新的对象,
代码语言:javascript复制task = tsk("pima")
learner = lrn("classif.rpart", maxdepth = 3, predict_type = "prob")
resampling = rsmp("cv", folds = 3L)
# 将三者组合起来
rr = resample(task, learner, resampling, store_models = TRUE)
print(rr)
## <ResampleResult> of 3 iterations
## * Task: pima
## * Learner: classif.rpart
## * Warnings: 0 in 0 iterations
## * Errors: 0 in 0 iterations
#通过aggregate函数将多个结果平均
rr$aggregate(msr("classif.ce"))
## classif.ce
## 0.2721
# 查看每个模型的性能
rr$score(msr("classif.ce"))
## task task_id learner learner_id
## 1: <TaskClassif[45]> pima <LearnerClassifRpart[34]> classif.rpart
## 2: <TaskClassif[45]> pima <LearnerClassifRpart[34]> classif.rpart
## 3: <TaskClassif[45]> pima <LearnerClassifRpart[34]> classif.rpart
## resampling resampling_id iteration prediction
## 1: <ResamplingCV[19]> cv 1 <PredictionClassif[19]>
## 2: <ResamplingCV[19]> cv 2 <PredictionClassif[19]>
## 3: <ResamplingCV[19]> cv 3 <PredictionClassif[19]>
## classif.ce
## 1: 0.3164
## 2: 0.2617
## 3: 0.2383
查看迭代结果
代码语言:javascript复制# 查看错误和警告
rr$warnings
rr$errors
# 查看抽样策略
rr$resampling
# 产看迭代次数
rr$resampling$iters
# 查看第一测试集和训练集
str(rr$resampling$test_set(1))
str(rr$resampling$train_set(1))
# 查看指定的学习器
lrn = rr$learners[[1]]
lrn$model
# 提取预测结果;这里将所有预测整合再一个表中
rr$prediction()
# 提取第一次迭代结果
rr$predictions()[[1]]
自定义抽样
自己选择样本的编号,进行抽样,傻子才这样做
代码语言:javascript复制resampling = rsmp("custom")
resampling$instantiate(task,
train = list(c(1:10, 51:60, 101:110)),
test = list(c(11:20, 61:70, 111:120))
)
resampling$iters
绘制结果
代码语言:javascript复制library("mlr3viz")
autoplot(rr)
hu
绘制roc曲线
autoplot(rr, type = "roc")
结束语
对于重抽样的操作,建议在高性能的服务器上进行,或者测试数据较少或者特征较少的数据集。
love&peace