之前已经给大家介绍了临床预测模型和机器学习中特征选择(变量选择)常见的方法分类:
- 机器学习中的特征选择(变量筛选)方法简介
今天就给大家演示过滤法在caret中的实现。
首先要理解过滤法,其实很简单,就是在建立模型前先根据一些标准把一些变量过滤掉,然后再建模。
举个简单的例子,假如你的结果变量是二分类,自变量是数值型,那么对于每一个自变量,我们都可以以结果变量为分组变量,对自变量做方差分析,如果一个自变量在两个类别(也就是两个组别)中没有统计学差异,那这个变量就可以删掉了,因为它在两种类别中没有差别,并不能帮我们判断一个样本到底属于哪种类别。
类似的还有t检验、卡方检验、等等,这些方法的选择在这里主要是根据预测变量和结果变量的类型。比如预测变量是二分类,结果变量也是二分类,此时就可以用卡方检验或者Fisher精确概率法等,如果预测变量是数值型而结果变量是二分类,就可以用方差分析、t检验等。对于学过医学统计学的人来说应该不会很难理解!
除此之外,还有其他一些过滤法,这些都在之前的推文中有介绍:机器学习中的特征选择(变量筛选)方法简介
在caret中通过sbf函数实现交叉验证的过滤法。
单变量过滤法(Univariate Filters)
在caret
中使用sbf()
函数实现。
基本使用语法:
代码语言:javascript复制sbf(predictors, outcome, sbfControl = sbfControl(), ...)
## or
sbf(formula, data, sbfControl = sbfControl(), ...)
sbf()
的参数解释如下:
functions
:用于设置模型拟合、预测和特征选择的一系列函数,可以是lmSBF(线性回归),rfSBF(随机森林),treebagSBF(袋装决策树),ldaSBF(线性判别分析法),nbSBF(朴素贝叶斯)和caretSBF(自定义函数)。method
:指定抽样方法,可以是boot(BootStrap抽样),cv(交叉验证抽样),LOOCV(留一交叉验证法)和LGOCV(留组交叉验证法)。saveDetails
:是否保存特征选择过程中的预测值和变量重要性,默认为FALSE。number
:指定折数或者重抽样迭代次数,当method为cv或repeatedcv时,则默认从总体中抽取10份样本并迭代10次,否则抽取25份并迭代25次。repeats
:指定抽样组数,默认抽取一组样本。verbose
:是否返回每次重抽样的详细信息,默认为FALSE。returnResamp
:返回重抽样的汇总信息。p
:如果指定method为LGOCV时,该参数起作用,指定训练集的比重。seeds
:为抽样设置随机种子。allowParallel
:在并行后台已加载和允许的情况下,是否允许并行运算。
下面是演示,使用随机森林,10折交叉验证,筛选变量
代码语言:javascript复制library(caret)
代码语言:javascript复制## Loading required package: ggplot2
代码语言:javascript复制## Warning: package 'ggplot2' was built under R version 4.2.3
代码语言:javascript复制## Loading required package: lattice
代码语言:javascript复制# 加载后会在当前环境下出现自变量数据框bbbDescr,因变量是logBBB
data(BloodBrain)
dim(bbbDescr)
代码语言:javascript复制## [1] 208 134
可以看到这个数据有208行,134列!也就是有134个自变量!
下面我们用过滤法去掉一部分:
代码语言:javascript复制sbfControl <- sbfControl(functions = rfSBF, # 选择随机森林
verbose = FALSE,
seeds = c(1:11),# 需要重抽样次数 1个整数
method = "cv")
set.seed(1)
RFwithGAM <- sbf(bbbDescr, logBBB,
sbfControl = sbfControl
)
RFwithGAM
代码语言:javascript复制##
## Selection By Filter
##
## Outer resampling method: Cross-Validated (10 fold)
##
## Resampling performance:
##
## RMSE Rsquared MAE RMSESD RsquaredSD MAESD
## 0.5032 0.588 0.3802 0.08017 0.1058 0.05726
##
## Using the training set, 88 variables were selected:
## tpsa, vsa_hyd, a_aro, peoe_vsa.1, peoe_vsa.3...
##
## During resampling, the top 5 selected variables (out of a possible 99):
## a_acc (100%), a_acid (100%), a_aro (100%), achg (100%), adistd (100%)
##
## On average, 88.7 variables were selected (min = 86, max = 91)
查看筛选出的变量:
代码语言:javascript复制RFwithGAM$optVariables
代码语言:javascript复制## [1] "tpsa" "vsa_hyd" "a_aro"
## [4] "peoe_vsa.1" "peoe_vsa.3" "peoe_vsa.5"
## [7] "peoe_vsa.1.1" "peoe_vsa.5.1" "peoe_vsa.6.1"
## [10] "a_acc" "a_acid" "vsa_acc"
## [13] "vsa_acid" "vsa_base" "vsa_don"
## [16] "vsa_other" "slogp_vsa0" "slogp_vsa1"
## [19] "slogp_vsa2" "slogp_vsa6" "slogp_vsa7"
## [22] "slogp_vsa8" "smr_vsa0" "smr_vsa2"
## [25] "smr_vsa4" "smr_vsa5" "tpsa.1"
## [28] "logp.o.w." "frac.cation7." "andrewbind"
## [31] "rotatablebonds" "mlogp" "clogp"
## [34] "nocount" "hbdnr" "rule.of.5violations"
## [37] "prx" "pol" "inthb"
## [40] "adistm" "adistd" "polar_area"
## [43] "nonpolar_area" "psa_npsa" "tcsa"
## [46] "tcpa" "tcnp" "most_negative_charge"
## [49] "most_positive_charge" "sum_absolute_charge" "dipole_moment"
## [52] "ppsa3" "pnsa2" "pnsa3"
## [55] "fpsa2" "fpsa3" "fnsa2"
## [58] "fnsa3" "wpsa3" "wnsa3"
## [61] "dpsa3" "rpcg" "wncs"
## [64] "sadh1" "sadh2" "sadh3"
## [67] "chdh1" "chdh2" "chdh3"
## [70] "scdh1" "scdh2" "scdh3"
## [73] "saaa1" "saaa2" "saaa3"
## [76] "chaa1" "chaa3" "scaa1"
## [79] "scaa2" "scaa3" "ctdh"
## [82] "ctaa" "mchg" "achg"
## [85] "rdta" "n_sp2" "n_sp3"
## [88] "o_sp2"
一下子就从134个自变量晒到了88个!
以上就是caret
中过滤法简单的演示,更多的使用方法大家自己探索,但是说实话不是很好用......