快速入门Python机器学习(18)

2022-09-23 20:03:04 浏览数 (1)

9 决策树(Decision Tree)

9. 1 决策树原理

9.2 信息增益与基尼不纯度

信息熵(约翰·香农 1948《通信的数学原理》,一个问题不确定性越大,需要获取的信息就越多,信息熵就越大;一个问题不确定性越小,需要获取的信息就越少,信息熵就越小)

集合D中第k类样本的比率为pk,(k=1,2,…|y|)

信息增益(Information Gain):划分数据前后数据信息熵的差值。信息增益纯度越高,纯度提升越大;信息增益纯度越低,纯度提升越小。

基尼不纯度

基尼不纯度反映从集合D中随机取两个样本后,其类别不一致性的概率。

方法

算法

信息增益

ID3(改进C4.5)

基尼不纯度

CART

9.3 决策树分类(Decision Tree Classifier)

9.3.1类、属性和方法

代码语言:javascript复制
class sklearn.tree.DecisionTreeClassifier(*, criterion='gini', splitter='best', max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None, min_impurity_decrease=0.0, min_impurity_split=None, class_weight=None, ccp_alpha=0.0)

参数

属性

类型

解释

max_depth

int, default=None

树的最大深度。如果没有,则节点将展开,直到所有叶都是纯的,或者直到所有叶都包含少于min_samples_split samples的值。

criterion

{'gini', 'entropy'}, default='gini'

测量分割质量的函数。支持的标准是基尼杂质的'基尼'和信息增益的'熵'。

属性

属性

解释

classes_

ndarray of shape (n_classes,) or list of ndarray类标签(单输出问题)或类标签数组列表(多输出问题)。

feature_importances_

ndarray of shape (n_features,)返回功能重要性。

max_features_

intmax_features的推断值。

n_classes_

int or list of int类的数量(对于单个输出问题),或包含每个输出的类的数量的列表(对于多输出问题)。

n_features_

int执行拟合时的特征数。

n_outputs_

int执行拟合时的输出数。

tree_

Tree instance树实例基础树对象。请参阅帮助(sklearn.tree._tree.Tree)对于树对象的属性,了解决策树结构对于这些属性的基本用法。

方法

apply(X[, check_input])

返回每个样本预测为的叶的索引。

cost_complexity_pruning_path(X, y[, …])

在最小代价复杂度修剪过程中计算修剪路径。

decision_path(X[, check_input])

返回树中的决策路径。

fit(X, y[, sample_weight, check_input, …])

从训练集(X,y)构建决策树分类器。

get_depth()

返回决策树的深度。

get_n_leaves()

返回决策树的叶数。

get_params([deep])

获取此估计器的参数。

predict(X[, check_input])

预测X的类或回归值。

predict_log_proba(X)

预测输入样本X的类对数概率。

predict_proba(X[, check_input])

预测输入样本X的类概率。

score(X, y[, sample_weight])

返回给定测试数据和标签的平均精度。

set_params(**params)

设置此估计器的参数。

9.3.2用散点图来分析鸢尾花数据

代码语言:javascript复制
def iris_of_decision_tree():
       myutil = util()
       iris = datasets.load_iris()
       # 仅选前两个特征
       X = iris.data[:,:2]
       y = iris.target
       X_train,X_test,y_train,y_test = train_test_split(X, y)
       for max_depth in [1,3,5,7]:
              clf = DecisionTreeClassifier(max_depth=max_depth)
              clf.fit(X_train,y_train)
              title=u"鸢尾花数据测试集(max_depth=" str(max_depth) ")"
              myutil.print_scores(clf,X_train,y_train,X_test,y_test,title)
              myutil.draw_scatter(X,y,clf,title)
              myutil.plot_learning_curve(DecisionTreeClassifier(max_depth=max_depth),X,y,title)
              myutil.show_pic(title)

输出

代码语言:javascript复制
鸢尾花数据测试集(max_depth=1):
64.29%
鸢尾花数据测试集(max_depth=1):
57.89%
鸢尾花数据测试集(max_depth=3):
83.93%
鸢尾花数据测试集(max_depth=3):
71.05%
鸢尾花数据测试集(max_depth=5):
85.71%
鸢尾花数据测试集(max_depth=5):
73.68%
鸢尾花数据测试集(max_depth=7):
88.39%
鸢尾花数据测试集(max_depth=7):
65.79%

当max_depth=5的时候效果最好

9.3.3用散点图分析红酒数据

代码语言:javascript复制
def wine_of_decision_tree():
       myutil = util()
       wine = datasets.load_wine()
       # 仅选前两个特征
       X = wine.data[:,:2]
       y = wine.target
       X_train,X_test,y_train,y_test = train_test_split(X, y)
       for max_depth in [1,3,5]:
              clf = DecisionTreeClassifier(max_depth=max_depth)
              clf.fit(X_train,y_train)
              title=u"红酒数据测试集(max_depth=" str(max_depth) ")"
              myutil.print_scores(clf,X_train,y_train,X_test,y_test,title)
              myutil.draw_scatter(X,y,clf,title)

输出

代码语言:javascript复制
红酒数据测试集(max_depth=1):
69.17%
红酒数据测试集(max_depth=1):
64.44%
红酒数据测试集(max_depth=3):
87.97%
红酒数据测试集(max_depth=3):
80.00%
红酒数据测试集(max_depth=5):
90.98%
红酒数据测试集(max_depth=5):
80.00%
红酒数据测试集(max_depth=7):
96.99%
红酒数据测试集(max_depth=7):
73.33%

max_depth=5的时候效果最好;max_depth=7

9.3.4用散点图分析乳腺癌数据

代码语言:javascript复制
def wine_of_decision_tree():
       myutil = util()
       wine = datasets.load_wine()
       # 仅选前两个特征
       X = wine.data[:,:2]
       y = wine.target
       X_train,X_test,y_train,y_test = train_test_split(X, y)
       for max_depth in [1,3,5]:
              clf = DecisionTreeClassifier(max_depth=max_depth)
              clf.fit(X_train,y_train)
              title=u"乳腺癌数据测试集(max_depth=" str(max_depth) ")"
              myutil.print_scores(clf,X_train,y_train,X_test,y_test,title)
              myutil.draw_scatter(X,y,clf,title)

输出

代码语言:javascript复制
乳腺癌数据测试集(max_depth=1):
90.38%
乳腺癌数据测试集(max_depth=1):
85.31%
乳腺癌数据测试集(max_depth=3):
91.08%
乳腺癌数据测试集(max_depth=3):
86.01%
乳腺癌数据测试集(max_depth=5):
93.43%
乳腺癌数据测试集(max_depth=5):
86.71%
乳腺癌数据测试集(max_depth=7):
96.95%
乳腺癌数据测试集(max_depth=7):
86.01%

max_depth=5的时候效果最好

0 人点赞