机器学习-决策树的优化

2019-08-29 16:24:08 浏览数 (1)

微信公众号:yale记 关注可了解更多的教程问题或建议,请公众号留言。

背景介绍

今天我们会使用真实的数据来建一棵决策树,编写代码,将其可视化,这样您即可明白决策树是如何在幕后工作的。这里我们使用sklearn中自带的数据集Iris flower data set,该数据集由来自三种鸢尾 ( Iris setosa , Iris virginica和Iris versicolor )中的每一种的50个样品组成。从每个样品测量四个特征 :萼片和花瓣的长度和宽度,以厘米为单位。基于这四个特征的组合,Fisher开发了一种线性判别模型,以区分物种。

入门示例

代码块:

代码语言:javascript复制
import numpy as np
from sklearn.datasets import load_iris
from sklearn import tree
# ### 加载Iris花数据集:
iris = load_iris()
print(iris.feature_names)
# ### 以上特征分别为 萼片长度 萼片宽度 花瓣长度 花瓣宽度
print(iris.target_names)
# ### 以上代表三种鸢尾花 索引分别为 0 1 2
# ### 打印第一种花的特征单位cm :
print(iris.data[0])
# ### 打印第一种花的标签 这里为0 是setosa :
print(iris.target[0])
for i in range(len(iris.target)):
    print("%d: 花名标签:%s,特征:%s" % 
        (i,iris.target[i],iris.data[i]))
# ### 定义一个索引位置列表 0 50 100 
# ### 分别代表三种花的特征和标签的索引
test_idx = [0,50,100]
#训练数据
train_target = np.delete(iris.target,test_idx)
train_data = np.delete(iris.data,test_idx,axis=0)
#测试数据
test_target = iris.target[test_idx]
test_data = iris.data[test_idx]
print(test_target)
print(test_data)
clf = tree.DecisionTreeClassifier()
clf.fit(train_data,train_target)
print(clf.predict(test_data))
#可视化决策树
import graphviz 
dot_data = tree.export_graphviz(clf, out_file=None, 
                      feature_names=iris.feature_names,  
                     class_names=iris.target_names,  
                      filled=True, rounded=True,  
                     special_characters=True)  
graph = graphviz.Source(dot_data)  
graph.render("iris-color") 

0 人点赞