9 决策树(Decision Tree)
9. 1 决策树原理
9.2 信息增益与基尼不纯度
信息熵(约翰·香农 1948《通信的数学原理》,一个问题不确定性越大,需要获取的信息就越多,信息熵就越大;一个问题不确定性越小,需要获取的信息就越少,信息熵就越小)
集合D中第k类样本的比率为pk,(k=1,2,…|y|)
信息增益(Information Gain):划分数据前后数据信息熵的差值。信息增益纯度越高,纯度提升越大;信息增益纯度越低,纯度提升越小。
基尼不纯度
基尼不纯度反映从集合D中随机取两个样本后,其类别不一致性的概率。
方法 | 算法 |
---|---|
信息增益 | ID3(改进C4.5) |
基尼不纯度 | CART |
9.3 决策树分类(Decision Tree Classifier)
9.3.1类、属性和方法
类
代码语言:javascript复制class sklearn.tree.DecisionTreeClassifier(*, criterion='gini', splitter='best', max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None, min_impurity_decrease=0.0, min_impurity_split=None, class_weight=None, ccp_alpha=0.0)
参数
属性 | 类型 | 解释 |
---|---|---|
max_depth | int, default=None | 树的最大深度。如果没有,则节点将展开,直到所有叶都是纯的,或者直到所有叶都包含少于min_samples_split samples的值。 |
criterion | {'gini', 'entropy'}, default='gini' | 测量分割质量的函数。支持的标准是基尼杂质的'基尼'和信息增益的'熵'。 |
属性
属性 | 解释 |
---|---|
classes_ | ndarray of shape (n_classes,) or list of ndarray类标签(单输出问题)或类标签数组列表(多输出问题)。 |
feature_importances_ | ndarray of shape (n_features,)返回功能重要性。 |
max_features_ | intmax_features的推断值。 |
n_classes_ | int or list of int类的数量(对于单个输出问题),或包含每个输出的类的数量的列表(对于多输出问题)。 |
n_features_ | int执行拟合时的特征数。 |
n_outputs_ | int执行拟合时的输出数。 |
tree_ | Tree instance树实例基础树对象。请参阅帮助(sklearn.tree._tree.Tree)对于树对象的属性,了解决策树结构对于这些属性的基本用法。 |
方法
apply(X[, check_input]) | 返回每个样本预测为的叶的索引。 |
---|---|
cost_complexity_pruning_path(X, y[, …]) | 在最小代价复杂度修剪过程中计算修剪路径。 |
decision_path(X[, check_input]) | 返回树中的决策路径。 |
fit(X, y[, sample_weight, check_input, …]) | 从训练集(X,y)构建决策树分类器。 |
get_depth() | 返回决策树的深度。 |
get_n_leaves() | 返回决策树的叶数。 |
get_params([deep]) | 获取此估计器的参数。 |
predict(X[, check_input]) | 预测X的类或回归值。 |
predict_log_proba(X) | 预测输入样本X的类对数概率。 |
predict_proba(X[, check_input]) | 预测输入样本X的类概率。 |
score(X, y[, sample_weight]) | 返回给定测试数据和标签的平均精度。 |
set_params(**params) | 设置此估计器的参数。 |
9.3.2用散点图来分析鸢尾花数据
代码语言:javascript复制def iris_of_decision_tree():
myutil = util()
iris = datasets.load_iris()
# 仅选前两个特征
X = iris.data[:,:2]
y = iris.target
X_train,X_test,y_train,y_test = train_test_split(X, y)
for max_depth in [1,3,5,7]:
clf = DecisionTreeClassifier(max_depth=max_depth)
clf.fit(X_train,y_train)
title=u"鸢尾花数据测试集(max_depth=" str(max_depth) ")"
myutil.print_scores(clf,X_train,y_train,X_test,y_test,title)
myutil.draw_scatter(X,y,clf,title)
myutil.plot_learning_curve(DecisionTreeClassifier(max_depth=max_depth),X,y,title)
myutil.show_pic(title)
输出
代码语言:javascript复制鸢尾花数据测试集(max_depth=1):
64.29%
鸢尾花数据测试集(max_depth=1):
57.89%
鸢尾花数据测试集(max_depth=3):
83.93%
鸢尾花数据测试集(max_depth=3):
71.05%
鸢尾花数据测试集(max_depth=5):
85.71%
鸢尾花数据测试集(max_depth=5):
73.68%
鸢尾花数据测试集(max_depth=7):
88.39%
鸢尾花数据测试集(max_depth=7):
65.79%
当max_depth=5的时候效果最好
9.3.3用散点图分析红酒数据
代码语言:javascript复制def wine_of_decision_tree():
myutil = util()
wine = datasets.load_wine()
# 仅选前两个特征
X = wine.data[:,:2]
y = wine.target
X_train,X_test,y_train,y_test = train_test_split(X, y)
for max_depth in [1,3,5]:
clf = DecisionTreeClassifier(max_depth=max_depth)
clf.fit(X_train,y_train)
title=u"红酒数据测试集(max_depth=" str(max_depth) ")"
myutil.print_scores(clf,X_train,y_train,X_test,y_test,title)
myutil.draw_scatter(X,y,clf,title)
输出
代码语言:javascript复制红酒数据测试集(max_depth=1):
69.17%
红酒数据测试集(max_depth=1):
64.44%
红酒数据测试集(max_depth=3):
87.97%
红酒数据测试集(max_depth=3):
80.00%
红酒数据测试集(max_depth=5):
90.98%
红酒数据测试集(max_depth=5):
80.00%
红酒数据测试集(max_depth=7):
96.99%
红酒数据测试集(max_depth=7):
73.33%
max_depth=5的时候效果最好;max_depth=7
9.3.4用散点图分析乳腺癌数据
代码语言:javascript复制def wine_of_decision_tree():
myutil = util()
wine = datasets.load_wine()
# 仅选前两个特征
X = wine.data[:,:2]
y = wine.target
X_train,X_test,y_train,y_test = train_test_split(X, y)
for max_depth in [1,3,5]:
clf = DecisionTreeClassifier(max_depth=max_depth)
clf.fit(X_train,y_train)
title=u"乳腺癌数据测试集(max_depth=" str(max_depth) ")"
myutil.print_scores(clf,X_train,y_train,X_test,y_test,title)
myutil.draw_scatter(X,y,clf,title)
输出
代码语言:javascript复制乳腺癌数据测试集(max_depth=1):
90.38%
乳腺癌数据测试集(max_depth=1):
85.31%
乳腺癌数据测试集(max_depth=3):
91.08%
乳腺癌数据测试集(max_depth=3):
86.01%
乳腺癌数据测试集(max_depth=5):
93.43%
乳腺癌数据测试集(max_depth=5):
86.71%
乳腺癌数据测试集(max_depth=7):
96.95%
乳腺癌数据测试集(max_depth=7):
86.01%
max_depth=5的时候效果最好