R语言机器学习caret-08:过滤法

2023-08-30 13:09:31 浏览数 (1)

之前已经给大家介绍了临床预测模型和机器学习中特征选择(变量选择)常见的方法分类:

  • 机器学习中的特征选择(变量筛选)方法简介

今天就给大家演示过滤法在caret中的实现。

首先要理解过滤法,其实很简单,就是在建立模型前先根据一些标准把一些变量过滤掉,然后再建模。

举个简单的例子,假如你的结果变量是二分类,自变量是数值型,那么对于每一个自变量,我们都可以以结果变量为分组变量,对自变量做方差分析,如果一个自变量在两个类别(也就是两个组别)中没有统计学差异,那这个变量就可以删掉了,因为它在两种类别中没有差别,并不能帮我们判断一个样本到底属于哪种类别。

类似的还有t检验、卡方检验、等等,这些方法的选择在这里主要是根据预测变量和结果变量的类型。比如预测变量是二分类,结果变量也是二分类,此时就可以用卡方检验或者Fisher精确概率法等,如果预测变量是数值型而结果变量是二分类,就可以用方差分析、t检验等。对于学过医学统计学的人来说应该不会很难理解!

除此之外,还有其他一些过滤法,这些都在之前的推文中有介绍:机器学习中的特征选择(变量筛选)方法简介

在caret中通过sbf函数实现交叉验证的过滤法。

单变量过滤法(Univariate Filters)

caret中使用sbf()函数实现。

基本使用语法:

代码语言:javascript复制
sbf(predictors, outcome, sbfControl = sbfControl(), ...)
## or
sbf(formula, data, sbfControl = sbfControl(), ...)

sbf()的参数解释如下:

  • functions:用于设置模型拟合、预测和特征选择的一系列函数,可以是lmSBF(线性回归),rfSBF(随机森林),treebagSBF(袋装决策树),ldaSBF(线性判别分析法),nbSBF(朴素贝叶斯)和caretSBF(自定义函数)。
  • method:指定抽样方法,可以是boot(BootStrap抽样),cv(交叉验证抽样),LOOCV(留一交叉验证法)和LGOCV(留组交叉验证法)。
  • saveDetails:是否保存特征选择过程中的预测值和变量重要性,默认为FALSE。
  • number:指定折数或者重抽样迭代次数,当method为cv或repeatedcv时,则默认从总体中抽取10份样本并迭代10次,否则抽取25份并迭代25次。
  • repeats:指定抽样组数,默认抽取一组样本。
  • verbose:是否返回每次重抽样的详细信息,默认为FALSE。
  • returnResamp:返回重抽样的汇总信息。
  • p:如果指定method为LGOCV时,该参数起作用,指定训练集的比重。
  • seeds:为抽样设置随机种子。
  • allowParallel:在并行后台已加载和允许的情况下,是否允许并行运算。

下面是演示,使用随机森林,10折交叉验证,筛选变量

代码语言:javascript复制
library(caret)
代码语言:javascript复制
## Loading required package: ggplot2
代码语言:javascript复制
## Warning: package 'ggplot2' was built under R version 4.2.3
代码语言:javascript复制
## Loading required package: lattice
代码语言:javascript复制
# 加载后会在当前环境下出现自变量数据框bbbDescr,因变量是logBBB
data(BloodBrain)
dim(bbbDescr)
代码语言:javascript复制
## [1] 208 134

可以看到这个数据有208行,134列!也就是有134个自变量!

下面我们用过滤法去掉一部分:

代码语言:javascript复制
sbfControl <- sbfControl(functions = rfSBF, # 选择随机森林
                         verbose = FALSE,
                         seeds = c(1:11),# 需要重抽样次数 1个整数
                         method = "cv")
set.seed(1)
RFwithGAM <- sbf(bbbDescr, logBBB,
                 sbfControl = sbfControl
                 )
RFwithGAM
代码语言:javascript复制
## 
## Selection By Filter
## 
## Outer resampling method: Cross-Validated (10 fold) 
## 
## Resampling performance:
## 
##    RMSE Rsquared    MAE  RMSESD RsquaredSD   MAESD
##  0.5032    0.588 0.3802 0.08017     0.1058 0.05726
## 
## Using the training set, 88 variables were selected:
##    tpsa, vsa_hyd, a_aro, peoe_vsa.1, peoe_vsa.3...
## 
## During resampling, the top 5 selected variables (out of a possible 99):
##    a_acc (100%), a_acid (100%), a_aro (100%), achg (100%), adistd (100%)
## 
## On average, 88.7 variables were selected (min = 86, max = 91)

查看筛选出的变量:

代码语言:javascript复制
RFwithGAM$optVariables
代码语言:javascript复制
##  [1] "tpsa"                 "vsa_hyd"              "a_aro"               
##  [4] "peoe_vsa.1"           "peoe_vsa.3"           "peoe_vsa.5"          
##  [7] "peoe_vsa.1.1"         "peoe_vsa.5.1"         "peoe_vsa.6.1"        
## [10] "a_acc"                "a_acid"               "vsa_acc"             
## [13] "vsa_acid"             "vsa_base"             "vsa_don"             
## [16] "vsa_other"            "slogp_vsa0"           "slogp_vsa1"          
## [19] "slogp_vsa2"           "slogp_vsa6"           "slogp_vsa7"          
## [22] "slogp_vsa8"           "smr_vsa0"             "smr_vsa2"            
## [25] "smr_vsa4"             "smr_vsa5"             "tpsa.1"              
## [28] "logp.o.w."            "frac.cation7."        "andrewbind"          
## [31] "rotatablebonds"       "mlogp"                "clogp"               
## [34] "nocount"              "hbdnr"                "rule.of.5violations" 
## [37] "prx"                  "pol"                  "inthb"               
## [40] "adistm"               "adistd"               "polar_area"          
## [43] "nonpolar_area"        "psa_npsa"             "tcsa"                
## [46] "tcpa"                 "tcnp"                 "most_negative_charge"
## [49] "most_positive_charge" "sum_absolute_charge"  "dipole_moment"       
## [52] "ppsa3"                "pnsa2"                "pnsa3"               
## [55] "fpsa2"                "fpsa3"                "fnsa2"               
## [58] "fnsa3"                "wpsa3"                "wnsa3"               
## [61] "dpsa3"                "rpcg"                 "wncs"                
## [64] "sadh1"                "sadh2"                "sadh3"               
## [67] "chdh1"                "chdh2"                "chdh3"               
## [70] "scdh1"                "scdh2"                "scdh3"               
## [73] "saaa1"                "saaa2"                "saaa3"               
## [76] "chaa1"                "chaa3"                "scaa1"               
## [79] "scaa2"                "scaa3"                "ctdh"                
## [82] "ctaa"                 "mchg"                 "achg"                
## [85] "rdta"                 "n_sp2"                "n_sp3"               
## [88] "o_sp2"

一下子就从134个自变量晒到了88个!

以上就是caret中过滤法简单的演示,更多的使用方法大家自己探索,但是说实话不是很好用......

0 人点赞