决策树实战:预测隐形眼镜类型

2022-07-29 19:33:55 浏览数 (1)

数据集[1] 提取码:j50c

数据长这样:

完整代码:

代码语言:javascript复制
from sklearn.tree import DecisionTreeClassifier,export_graphviz
from sklearn.preprocessing import LabelBinarizer
from sklearn.feature_extraction import DictVectorizer
import pydotplus

labelBinarizer=LabelBinarizer()     #方便后面对标签二值化后的标签进行复原

def load_file():
    #读取数据集
    data = open('ensemble/lenses.txt')
    lenses = [];label = ['age','prescript','astigmatic','tearRate']
    feature= [];labels = []
    for line in data.readlines():
        lenses.append(line.strip().split('t'))

    for i in range(len(lenses)):
        row = {}
        for j in range(0,len(lenses[i])-1):
            row[label[j]] = lenses[i][j]
        feature.append(row)
        labels.append(lenses[i][len(lenses[i]) - 1])

    train_x = DictVectorizer().fit_transform(feature).toarray()  #特征提取
    train_y = labelBinarizer.fit_transform(labels)    #标签二值化
    re_label=train_y   #同样方便后面输出预测结果

    #前2/3(16个)做训练集,后1/3(8个)做测试集
    test_x = train_x[int(len(train_x)*2/3):len(train_x)]
    test_y = train_y[int(len(train_y)*2/3):len(train_y)]
    train_x = train_x[0:int(len(train_x)*2/3)]
    train_y = train_y[0:int(len(train_y)*2/3)]

    return train_x,train_y,test_x,test_y,re_label


def decision_Tree():
    train_x, train_y, test_x, test_y, relabel = load_file()
    clf = DecisionTreeClassifier()
    clf = clf.fit(train_x,train_y)   #训练模型

    #可视化
    dot_data = export_graphviz(clf, out_file=None,
                               feature_names = clf.feature_importances_,
                                filled = True, rounded = True, special_characters = True)
    graph = pydotplus.graph_from_dot_data(dot_data)
    graph.write_pdf('tree.pdf')

    #预测
    pred=clf.predict(test_x)
    original_data=labelBinarizer.inverse_transform(relabel)   #标签二值化后复原
    original_data=original_data[int(len(original_data)*2/3):len(original_data)]  #只截取测试集部分

    #输出
    for i in range(len(test_x)):
        print('正确类别:',original_data[i],'预测类别:',original_data[test_y.tolist().index(pred.tolist()[i])])
    print('分类正确率为:',clf.score(test_x,test_y))


if __name__ == '__main__':
    decision_Tree()

决策树:

输出结果:

References

[1] 数据集: https://pan.baidu.com/s/1DOqNzeeEAEG84OlhBTZfZg

0 人点赞