mlr3_训练和测试
概述
之前的章节中,我们已经建立了task和learner,接下来利用这两个R6对象,建立模型,并使用新的数据集对模型进行评估
建立task和learner
这里使用简单的tsk和lrn方法建立
代码语言:javascript复制task = tsk("sonar")
learner = lrn("classif.rpart")
设置训练和测试数据
这里设置的其实是task里面数据的行数目
代码语言:javascript复制train_set = sample(task$nrow, 0.8 * task$nrow)
test_set = setdiff(seq_len(task$nrow), train_set)
训练learner
$model
是learner中用来存储训练好的模型
# 可以看到目前是没有模型训练好的
learner$model
## NULL
接下来使用任务来训练learner
代码语言:javascript复制# 这里使用row_ids选择训练数据
learner$train(task, row_ids = train_set)
# 训练完成后查看模型
print(learner$model)
预测
使用剩余的数据进行预测
predict
# 返回每一个个案的预测结果
prediction = learner$predict(task, row_ids = test_set)
## <PredictionClassif> for 42 observations:
## row_id truth response
## 2 R R
## 6 R R
## 12 R M
## ---
## 191 M M
## 199 M M
## 204 M M
# 为了提取预测后的数据,最好的办法是转换为data.table
head(as.data.table(prediction))
# 同时,我们需要计算混淆矩阵
prediction$confusion
## truth
## response M R
## M 15 3
## R 8 16
改变预测的类型
这个部分主要是计算每一种类型的概率,有时候用于roc曲线的绘制
代码语言:javascript复制learner$predict_type = "prob"
# 重新训练
learner$train(task, row_ids = train_set)
# 重新预测
prediction = learner$predict(task, row_ids = test_set)
# 查看结果
head(as.data.table(prediction))
## row_id truth response prob.M prob.R
## 1: 2 R R 0.2222 0.77778
## 2: 6 R R 0.2222 0.77778
## 3: 12 R M 0.9375 0.06250
## 4: 13 R R 0.1429 0.85714
## 5: 30 R R 0.2222 0.77778
## 6: 31 R M 0.9535 0.04651
可以看到,里面出现了新的两列,用于描述各自的概率大小
绘制预测图
代码语言:javascript复制library("mlr3viz")
task = tsk("sonar")
learner = lrn("classif.rpart", predict_type = "prob")
learner$train(task)
prediction = learner$predict(task)
# 绘制默认图
autoplot(prediction)
# 绘制roc图
autoplot(prediction, type = "roc")
对于回归任务
代码语言:javascript复制library("mlr3viz")
library("mlr3learners")
task = tsk("mtcars")
learner = lrn("regr.lm")
learner$train(task)
prediction = learner$predict(task)
autoplot(prediction)
模型评估
mlr3 自带一系列的评估方法,如
代码语言:javascript复制mlr_measures
## <DictionaryMeasure> with 54 stored values
## Keys: classif.acc, classif.auc, classif.bacc, classif.bbrier,
## classif.ce, classif.costs, classif.dor, classif.fbeta, classif.fdr,
## classif.fn, classif.fnr, classif.fomr, classif.fp, classif.fpr,
## classif.logloss, classif.mbrier, classif.mcc, classif.npv,
## classif.ppv, classif.prauc, classif.precision, classif.recall,
## classif.sensitivity, classif.specificity, classif.tn, classif.tnr,
## classif.tp, classif.tpr, debug, oob_error, regr.bias, regr.ktau,
## regr.mae, regr.mape, regr.maxae, regr.medae, regr.medse, regr.mse,
## regr.msle, regr.pbias, regr.rae, regr.rmse, regr.rmsle, regr.rrse,
## regr.rse, regr.rsq, regr.sae, regr.smape, regr.srho, regr.sse,
## selected_features, time_both, time_predict, time_train
# 使用msr获取评估的方法,这里是准确率
measure = msr("classif.acc")
prediction$score(measure)
## classif.acc
## 0.875
结束语
到这里基本上mlr3的主要内容都已经更新完毕,后面涉及冲抽样,模型优化等问题
love&peace