Python+sklearn决策树算法使用入门

2019-05-29 00:37:52 浏览数 (1)

在学习决策树算法之前,首先介绍几个相关的基本概念。

决策树算法原理与sklearn实现

简单地说,决策树算法相等于一个多级嵌套的选择结构,通过回答一系列问题来不停地选择树上的路径,最终到达一个表示某个结论或类别的叶子节点,例如有无贷款意向、能够承担的理财风险等级、根据高考时各科成绩填报最合适的学校和专业、一个人的诚信度、商场是否应该引进某种商品、预测明天是晴天还是阴天。

决策树属于有监督学习算法,需要根据已知样本来训练并得到一个可以工作的模型,然后再使用该模型对未知样本进行分类。

在决策树算法中,构造一棵完整的树并用来分类的计算量和空间复杂度都非常高,可以采用剪枝算法在保证模型性能的前提下删除不必要的分支。剪枝有预先剪枝和后剪枝两大类方法,预先剪枝是在树的生长过程中设定一个指标,当达到指标时就停止生长,当前节点为叶子节点不再分裂,适合大样本集的情况,但有可能会导致模型的误差比较大。后剪枝算法可以充分利用全部训练集的信息,但计算量要大很多,一般用于小样本的情况。

决策树常见的实现有ID3(Iterative Dichotomiser 3)、C4.5、C5.0和CART,ID3、C4.5、C5.0是属于分类树,CART属于分类回归树。其中ID3以信息论为基础,以信息熵和信息增益为衡量标准,从而实现对数据的归纳分类。ID3算法从根节点开始,在每个节点上计算所有可能的特征的信息增益,选择信息增益最大的一个特征作为该节点的特征并分裂创建子节点,不断递归这个过程直到完成决策树的构建。ID3适合二分类问题,且仅能处理离散属性。

C4.5是对ID3的一种改进,根据信息增益率选择属性,在构造树的过程中进行剪枝操作,能够对连续属性进行离散化。该算法先将特征取值排序,以连续两个值中间值作为划分标准。尝试每一种划分,并计算修正后的信息增益,选择信息增益最大的分裂点作为该属性的分裂点。

分类与回归树CART(Classification And Regression Tree)以二叉树的形式给出,比传统的统计方法构建的代数预测准则更加准确,并且数据越复杂、变量越多,算法的优越性越显著。

扩展库sklearn.tree中使用CART算法的优化版本实现了分类决策树DecisionTreeClassifier和回归决策树DecisionTreeRegressor,官方在线帮助文档为https://scikit-learn.org/stable/modules/tree.html。本文重点介绍分类决策树DecisionTreeClassifier的用法,该类构造方法的语法为:

__init__(self, criterion='gini', splitter='best', max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None, min_impurity_decrease=0.0, min_impurity_split=None, class_weight=None, presort=False)

其中,常用参数及含义如下表所示。

表 DecisionTreeClassifier类构造方法参数及含义

参数名称

含义

criterion

用来执行衡量分裂(创建子节点)质量的函数,取值为'gini'时使用基尼值,为'entropy'时使用信息增益

splitter

用来指定在每个节点选择划分的策略,可以为'best'或'random'

max_depth

用来指定树的最大深度,如果不指定则一直扩展节点,直到所有叶子包含的样本数量少于min_samples_split,或者所有叶子节点都不再可分

min_samples_split

用来指定分裂节点时要求的样本数量最小值,值为实数时表示百分比

min_samples_leaf

叶子节点要求的样本数量最小值

max_features

用来指定在寻找最佳分裂时考虑的特征数量

max_leaf_nodes

用来设置叶子最大数量

min_impurity_decrease

如果一个节点分裂后可以使得不纯度减少的值大于等于min_impurity_decrease,则对该节点进行分裂

min_impurity_split

用来设置树的生长过程中早停的阈值,如果一个节点的不纯度高于这个阈值则进行分裂,否则为一个叶子不再分裂

presort

用来设置在拟合时是否对数据进行预排序来加速寻找最佳分裂的过程

该类对象常用方法如下表所示。

表 DecisionTreeClassifier类常用方法

方法

功能

fit(self, X, y, sample_weight=None, check_input=True, X_idx_sorted=None)

根据给定的训练集构建决策树分类器

predict_log_proba(self, X)

预测样本集X属于不同类别的对数概率

predict_proba(self, X, check_input=True)

预测样本集X属于不同类别的概率

apply(self, X, check_input=True)

返回每个样本被预测的叶子索引

decision_path(self, X, check_input=True)

返回树中的决策路径

predict(self, X, check_input=True)

返回样本集X的类别或回归值

score(self, X, y, sample_weight=None)

根据给定的数据和标签计算模型精度的平均值

另外,sklearn.tree模块的函数export_graphviz()可以用来把训练好的决策树数据导出,然后再使用扩展库graphviz中的功能绘制决策树图形,export_graphviz()函数语法为

export_graphviz(decision_tree, out_file="tree.dot", max_depth=None, feature_names=None, class_names=None, label='all', filled=False, leaves_parallel=False, impurity=True, node_ids=False, proportion=False, rotate=False, rounded=False, special_characters=False, precision=3)

为了能够绘制图形并输出文件,需要从下面的地址下载graphviz安装包,安装之后把安装路径的bin文件夹路径添加至系统环境变量Path。

https://graphviz.gitlab.io/_pages/Download/windows/graphviz-2.38.msi

然后执行下面的代码:

代码运行后生成的abc.pdf文件内容:

0 人点赞