mlr3_训练和测试

2021-02-05 16:39:11 浏览数 (1)

mlr3_训练和测试

概述

之前的章节中,我们已经建立了task和learner,接下来利用这两个R6对象,建立模型,并使用新的数据集对模型进行评估

建立task和learner

这里使用简单的tsk和lrn方法建立

代码语言:javascript复制
task = tsk("sonar")
learner = lrn("classif.rpart")

设置训练和测试数据

这里设置的其实是task里面数据的行数目

代码语言:javascript复制
train_set = sample(task$nrow, 0.8 * task$nrow)
test_set = setdiff(seq_len(task$nrow), train_set)

训练learner

$model是learner中用来存储训练好的模型

代码语言:javascript复制
# 可以看到目前是没有模型训练好的
learner$model
## NULL

接下来使用任务来训练learner

代码语言:javascript复制
# 这里使用row_ids选择训练数据
learner$train(task, row_ids = train_set)
# 训练完成后查看模型

print(learner$model)

预测

使用剩余的数据进行预测 predict

代码语言:javascript复制
# 返回每一个个案的预测结果
prediction = learner$predict(task, row_ids = test_set)
## <PredictionClassif> for 42 observations:
##     row_id truth response
##          2     R        R
##          6     R        R
##         12     R        M
## ---                      
##        191     M        M
##        199     M        M
##        204     M        M

# 为了提取预测后的数据,最好的办法是转换为data.table
head(as.data.table(prediction))

# 同时,我们需要计算混淆矩阵

prediction$confusion
##         truth
## response  M  R
##        M 15  3
##        R  8 16

改变预测的类型

这个部分主要是计算每一种类型的概率,有时候用于roc曲线的绘制

代码语言:javascript复制
learner$predict_type = "prob"
# 重新训练
learner$train(task, row_ids = train_set)

# 重新预测
prediction = learner$predict(task, row_ids = test_set)
# 查看结果
head(as.data.table(prediction))
##    row_id truth response prob.M  prob.R
## 1:      2     R        R 0.2222 0.77778
## 2:      6     R        R 0.2222 0.77778
## 3:     12     R        M 0.9375 0.06250
## 4:     13     R        R 0.1429 0.85714
## 5:     30     R        R 0.2222 0.77778
## 6:     31     R        M 0.9535 0.04651

可以看到,里面出现了新的两列,用于描述各自的概率大小

绘制预测图

代码语言:javascript复制
library("mlr3viz")
task = tsk("sonar")
learner = lrn("classif.rpart", predict_type = "prob")
learner$train(task)
prediction = learner$predict(task)
# 绘制默认图
autoplot(prediction)
# 绘制roc图
autoplot(prediction, type = "roc")

对于回归任务

代码语言:javascript复制
library("mlr3viz")
library("mlr3learners")
task = tsk("mtcars")
learner = lrn("regr.lm")
learner$train(task)
prediction = learner$predict(task)
autoplot(prediction)

模型评估

mlr3 自带一系列的评估方法,如

代码语言:javascript复制
mlr_measures
## <DictionaryMeasure> with 54 stored values
## Keys: classif.acc, classif.auc, classif.bacc, classif.bbrier,
##   classif.ce, classif.costs, classif.dor, classif.fbeta, classif.fdr,
##   classif.fn, classif.fnr, classif.fomr, classif.fp, classif.fpr,
##   classif.logloss, classif.mbrier, classif.mcc, classif.npv,
##   classif.ppv, classif.prauc, classif.precision, classif.recall,
##   classif.sensitivity, classif.specificity, classif.tn, classif.tnr,
##   classif.tp, classif.tpr, debug, oob_error, regr.bias, regr.ktau,
##   regr.mae, regr.mape, regr.maxae, regr.medae, regr.medse, regr.mse,
##   regr.msle, regr.pbias, regr.rae, regr.rmse, regr.rmsle, regr.rrse,
##   regr.rse, regr.rsq, regr.sae, regr.smape, regr.srho, regr.sse,
##   selected_features, time_both, time_predict, time_train

# 使用msr获取评估的方法,这里是准确率
measure = msr("classif.acc")
prediction$score(measure)

## classif.acc 
##       0.875

结束语

到这里基本上mlr3的主要内容都已经更新完毕,后面涉及冲抽样,模型优化等问题 love&peace

0 人点赞