《机器学习实战》 - 决策树

2022-04-01 15:28:22 浏览数 (1)

简介

创建分支的伪代码函数createBranch()如下

代码语言:javascript复制
检测数据集中的每个子项是否属于同一分类:
    If so return 类标签;
    Else
        寻找划分数据集的最好特征
        划分数据集
        创建分支节点
            for 每个划分的子集
                调用函数createBranch并增加返回结果到分支节点中
        return 分支节点

决策树的一般流程

  1. 收集数据:可以使用任何方法。
  2. 准备数据:树构造算法只适用于标称型数据,因此数值型数据必须离散化。
  3. 分析数据:可以使用任何方法,构造树完成之后,我们应该检查图形是否符合预期。
  4. 训练算法:构造树的数据结构。
  5. 测试算法:使用经验树计算错误率。
  6. 使用算法:此步骤可以适用于任何监督学算法,而使用决策树可以更好地理解数据的内在含义。

一些决策树算法 采用 二分法 划分数据,本文并不采用此方法, 若依据某个属性划分数据将会产生4个可能的值,我们将把数据划分成4块,并创建4个不同的分支。 本文 将使用 ID3算法 划分数据集, 该算法 处理 如何划分数据集,何时停止划分数据集。 每次划分数据集时,我们只选取一个特征属性,若训练集中存在20个特征,第一次我们选择哪个特征作为划分的参考属性呢?

表3-1 海洋生物数据

不浮出水面是否可以生存

是否有脚蹼

属于鱼类

1

2

3

4

5

信息增益

划分数据集的大原则是:将无序的数据变得更加有序。 我们可以使用多种方法划分数据集,但是每种方法都有各自的优缺点。 组织杂乱无章数据的一种方法就是使用 信息论度量信息,信息论是量化处理信息的分支科学。 我们可以在划分数据前后使用信息论量化度量信息的内容。 信息增益(information gain)和熵(entropy)

在划分数据集之前之后 信息发生的变化 称为 信息增益

知道如何计算信息增益,我们就可以计算 每个特征值划分数据集获得的信息增益,获得信息增益最高的特征就是最好的选择。

集合信息的度量方式称为香农熵或者简称为熵,这个名字来源于信息论之父克劳德·香农

熵定义为信息的期望值,

信息的定义:

若待分类的事务可能划分在多个分类之中,则符号 xi信息 定义为:l(xi)=−log2⁡p(xi) 其中,p(x)选择该分类的概率

为了计算熵,我们需要计算 所有类别 所有可能值 包含的 信息期望值 ,通过下方公式得到:

H=−∑i=1np(xi)log2⁡p(xi)

其中,n 是分类数目

代码语言:javascript复制
from math import log

# 计算给定数据集的香农熵
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet: # 类别和它们的出现次数: 字典:key: 分类 value: 此分类出现次数
        currentLabel = featVec[-1] # 最后一列是分类标签
        if currentLabel not in labelCounts.keys():labelCounts[currentLabel] = 0
        labelCounts[currentLabel]  = 1
    shannonEnt = 0.0
    for key in labelCounts:
        # 此分类概率(此分类出现次数/总样本个数)
        prob = float(labelCounts[key])/numEntries
        # 计算 香农熵
        shannonEnt -= prob * log(prob, 2) # 以2为底 求对数

    return shannonEnt
代码语言:javascript复制
def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']

    return dataSet, labels
代码语言:javascript复制
if __name__ == "__main__":
    # 1.
    myDat, labels = createDataSet()
    print(myDat) # [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
    print(calcShannonEnt(myDat)) # 0.9709505944546686

熵越高,则混合的数据也越多,

我们可以在数据集中添加更多的分类,观察熵是如何变化的。

这里我们增加第三个名为maybe的分类,测试熵的变化:

代码语言:javascript复制
# 2.
   # 将第一行样本的分类改为 'maybe'
   myDat[0][-1] = 'maybe'
   print(calcShannonEnt(myDat)) # 1.3709505944546687

得到熵之后,就可以按照获取最大信息增益的方法划分数据集,

下面 将具体学如何划分数据集以及如何度量信息增益。

另一个度量集合无序程度的方法是基尼不纯度2 (Gini impurity), 简单地说就是从一个数据集中随机选取子项,度量其被错误分类到其他分组里的概率。 本文 不采用基尼不纯度方法,这里就不再做进一步的介绍

划分数据集

分类算法除了需要测量信息熵,还需要划分数据集, 度量划分数据集的熵,以便判断当前是否正确地划分了数据集。 我们将对每个特征划分数据集的结果计算一次信息熵,然后判断按照哪个特征划分数据集是最好的划分方式 想象一个分布在二维空间的数据散点图,需要在数据之间划条线,将它们分成两部分, 我们应该按照x轴还是y轴划线呢?答案就是本节讲述的内容。

代码语言:javascript复制
# 按照给定特征划分数据集
def splitDataSet(dataSet, axis, value):
    """

    :param dataSet: 待划分的数据集
    :param axis: 划分数据集的特征
    :param value: 需要返回的特征的值
    :return:
    """
    retDataSet = []
    # 遍历每一个样本 的特征向量
    for featVec in dataSet:
        # 某个特征值 为 目标值
        if featVec[axis] == value:
            # 取 此特征 前 特征向量
            reducedFeatVec = featVec[:axis]
            # 添加 此特征 后 特征向量
            reducedFeatVec.extend(featVec[axis 1:])
            # 添加到 要返回的数据集中
            retDataSet.append(reducedFeatVec)

    return retDataSet
代码语言:javascript复制
if __name__ == "__main__":
    # 3.
    myDat, labels = createDataSet()
    print(splitDataSet(myDat, 0, 1)) # [[1, 'yes'], [1, 'yes'], [0, 'no']]
    print(splitDataSet(myDat,0, 0)) # [[1, 'no'], [1, 'no']]
代码语言:javascript复制
print(splitDataSet(myDat, 0, 1)) # [[1, 'yes'], [1, 'yes'], [0, 'no']]

因为只有红色框 处(axis=0)满足 value=1,然后取到绿色框处

接下来我们将遍历整个数据集,循环计算香农熵和splitDataSet()函数,找到最好的特征划分方式。

熵计算将会告诉我们如何划分数据集是最好的数据组织方式。

代码语言:javascript复制
# 选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1 # numFeatures个特征, 最后一列:标签列
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures): # 遍历所有的特征
        # 从数据集 中 取出 所有样本 的 此i特征 组成 此特征向量
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList) # 去重
        newEntropy = 0.0
        for value in uniqueVals:
            # 按此特征 切分 数据集
            subDataSet = splitDataSet(dataSet, i, value)
            # 概率
            prob = len(subDataSet)/float(len(dataSet))
            # 累加
            newEntropy  = prob * calcShannonEnt(subDataSet)
        # 信息增益
        infoGain = baseEntropy - newEntropy
        # 只要 最好(最大) 的信息增益
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            # 标记 最大 信息增益 所对应的特征 索引
            bestFeature = i

    return bestFeature
代码语言:javascript复制
if __name__ == "__main__":
    # 4.
    myDat, labels = createDataSet()
    print(chooseBestFeatureToSplit(myDat)) # 0
    print(myDat) # [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]

代码运行结果告诉我们,索引为0的特征(即第一个特征)是__最好的用于划分数据集的特征__。

若按照 第一个特征属性划分数据 ,也就是说 第一个特征是1的放在一个组,第一个特征是0的放在另一个组,数据一致性如何?

按照上述的方法划分数据集,第一个特征为1的海洋生物分组将有两个属于鱼类,一个属于非鱼类;另一个分组则全部属于非鱼类。

递归构建决策树

目前已完成从数据集构造决策树算法所需要的子功能模块,其工作原理如下:

  1. 得到原始数据集
  2. 基于最好的属性值 划分数据集
  3. 由于特征值 可能多于2个,因此可能存在大于两个分支的数据集划分

第一次划分后,数据将被 向下传递到树分支 的下一个节点, 在这个节点上,我们可以额再次划分数据。 因此,我们采用递归 原则 处理数据

递归结束条件:(满足下方其一即可结束)

  1. 程序遍历完所有划分数据集的属性
  2. 每个分支下的所有实例都具有相同的分类

若 所有实例 具有相同的分类,则得到一个叶子节点或者终止块。 任何到达叶子节点的数据必然属于叶子节点的分类,如下图

第一个结束条件使得算法可以终止,我们甚至可以设置算法可以划分的最大分组数目。

后续还会介绍其他决策树算法,如 C4.5 和 CART,这些算法在运行时并不总是在每次划分分组时都会消耗特征。

由于特征数目并不是在每次划分数据分组时都减少,因此这些算法在实际使用时可能引起一定问题。

目前我们并不需要考虑这个问题,只需要在算法开始运行前计算列的数目,查看算法是否使用了所有属性即可。

若数据集已处理所有属性,但类标签依然不是唯一,此时我们需要决定如何定义该叶子节点,

在这种情况下,我们通常会采用__多数表决__的方法 决定该叶子节点的分类。

代码语言:javascript复制
import operator


def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote]  = 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)

    return sortedClassCount[0][0]

创建树的函数代码

代码语言:javascript复制
# 创建树的函数代码
def createTree(dataSet, labels):
    """
    创建树
    :param dataSet: 数据集
    :param labels: 标签列表
    :return:
    """
    # 取出每个样本的标签,组成 标签向量
    classList = [example[-1] for example in dataSet]
    # 若 第一个标签值数目 == 标签个数
    if classList.count(classList[0]) == len(classList):
        # 停止 划分 当 所有的类别相等时
        return classList[0]
    if len(dataSet[0]) == 1:
        # 停止 划分 当 数据集中没有更多的特征(只剩一个特征)
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        # 复制所有的标签,因此 树不会 搞混 存在的标签
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)

    return myTree
代码语言:javascript复制
if __name__ == "__main__":
    # 5.
    myDat, labels = createDataSet()
    myTree = createTree(myDat, labels)
    print(myTree) # {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

在Python中使用Matplotlib注解绘制树形图

treePlotter.py 决策树的主要优点: 直观易于理解

Matplotlib注解

代码语言:javascript复制
import matplotlib.pyplot as plt

# (以下三行)定义文本框和箭头格式
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")


# (以下两行)绘制带箭头的注释
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


def createPlot():
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False)  # ticks for demo puropses
    plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()


if __name__ == "__main__":
    # 1.
    createPlot()

构造注解树

代码语言:javascript复制
# 获取叶节点的数目
def getNumLeafs(myTree):
    # 叶节点 个数
    numLeafs = 0
    # 根节点
    firstStr = list(myTree)[0]
    # 根节点下 内容
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            # 是字典类型,说明不是叶子节点
            # 递归下去
            numLeafs  = getNumLeafs(secondDict[key])
        else:
            # 叶子节点 数 1
            numLeafs  = 1

    return numLeafs

# 获取树深度
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree)[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            # 字典类型:说明 不是叶子节点
            # 递归下去
            thisDepth = 1   getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            # 只要最大的深度
            maxDepth = thisDepth

    return maxDepth

# 检索树
def retrieveTree(i):
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                   ]

    return listOfTrees[i]
代码语言:javascript复制
if __name__ == "__main__":
    # 2.
    print(retrieveTree(1))
    myTree = retrieveTree(0) # {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
    print(getNumLeafs(myTree)) # 3
    print(getTreeDepth(myTree)) # 2

函数retrieveTree()主要用于测试,返回预定义的树结构

现在我们可以将前面学到的方法组合在一起,绘制一棵完整的树。

更新 createPlot

代码语言:javascript复制
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    # 存放 树宽
    plotTree.totalW = float(getNumLeafs(inTree))
    # 存放 树深度(高)
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

添加如下代码:

代码语言:javascript复制
# (以下四行)在父子节点间填充文本信息
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0   cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0   cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    # 对于这个节点 的 文本标签
    firstStr = list(myTree)[0]
    cntrPt = (plotTree.xOff   (1.0   float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            # 不是叶子节点,递归进去
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff   1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff   1.0/plotTree.totalD
代码语言:javascript复制
if __name__ == "__main__":
    # 3.1
    myTree = retrieveTree(0)
    createPlot(myTree)

接下来,变更字典,重新绘制树形图:

代码语言:javascript复制
if __name__ == "__main__":
    # 3.2
    myTree = retrieveTree(0)
    myTree['no surfacing'][3] = 'maybe'
    print(myTree)
    createPlot(myTree)

测试和存储分类器

测试算法:使用决策树执行分类

trees.py

代码语言:javascript复制
# 使用决策树的分类函数
def classify(inputTree, featLabels, testVec):
    # 根节点
    firstStr = list(inputTree)[0]
    secondDict = list(inputTree)[0]
    featIndex = featLabels.index(firstStr)
    key = testVec[featIndex]
    valueOfFeat = secondDict[key]
    if isinstance(valueOfFeat, dict):
        # 字典说明非叶子节点,递归下去
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else:
        classLabel = valueOfFeat

    return classLabel
代码语言:javascript复制
if __name__ == "__main__":
    # 6.
    import treePlotter

    myDat, labels = createDataSet()
    print(labels) # ['no surfacing', 'flippers']

    myTree = treePlotter.retrieveTree(0)
    print(myTree) # {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
    print(classify(myTree, labels, [1, 0])) # o
    print(classify(myTree, labels, [1, 1])) # o

使用算法:决策树的存储

构造决策树是很耗时的任务,即使处理很小的数据集,如前面的样本数据,也要花费几秒的时间,如果数据集很大,将会耗费很多计算时间。然而用创建好的决策树解决分类问题,则可以很快完成。 因此,为了节省计算时间,最好能够在每次执行分类时调用已经构造好的决策树。 为了解决这个问题,需要使用Python模块pickle序列化对象,参见程序清单3-9。序列化对象可以在磁盘上保存对象,并在需要的时候读取出来。任何对象都可以执行序列化操作,字典对象也不例外。 使用pickle模块存储决策树

代码语言:javascript复制
# 使用pickle模块存储决策树
def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, 'wb')
    pickle.dump(inputTree, fw)
    fw.close()

def grabTree(filename):
    import pickle
    fr = open(filename, 'rb')
    return pickle.load(fr)
代码语言:javascript复制
if __name__ == "__main__":
    # 7.
    import treePlotter
    myTree = treePlotter.retrieveTree(0)
    storeTree(myTree, 'classifierStorage.txt')
    print(grabTree('classifierStorage.txt')) # {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

决策树的优点之一:可以持久化分类器,而k-近邻算法就无法持久化分类器,每次都需要重新学

示例:使用决策树预测隐形眼镜类型

代码语言:javascript复制
if __name__ == "__main__":
    # 8. 示例:使用决策树预测隐形眼镜类型
    fr = open('lenses.txt')
    lenses = [inst.strip().split('t') for inst in fr.readlines()]
    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
    lensesTree = createTree(lenses, lensesLabels)
    # {'tearRate': {'normal': {'astigmatic': {'yes': {'prescript': {'hyper': {'age': {'young': 'hard', 'pre': 'no lenses', 'presbyopic': 'no lenses'}}, 'myope': 'hard'}}, 'no': {'age': {'young': 'soft', 'pre': 'soft', 'presbyopic': {'prescript': {'hyper': 'soft', 'myope': 'no lenses'}}}}}}, 'reduced': 'no lenses'}}
    print(lensesTree)

    import treePlotter
    treePlotter.createPlot(lensesTree)

图: 由ID3算法产生的决策树 图所示的决策树非常好地匹配了实验数据,然而 匹配选项过多,这种问题称为:过度匹配(overfitting)/过拟合。 为了减少过度匹配问题,可以裁剪决策树,去掉不必要叶子节点(若此叶子结点仅增加少许信息),将它并入其他叶子节点中,即合并相邻的无法产生大量信息增益的叶子节点。

小结

ID3算法 无法直接处理数值型数据,尽管可通过量化将数值型转为标称型数值,但若特征太多,ID3仍会面临其他问题。

ID3可划分标称型数值

构建决策树一般采用递归将数据集转为决策树,一般用字典存储树节点信息

测量集合中数据不一致性(熵),寻找最优方案划分数据集,直到数据集中的所有数据属于同一分类。

其他决策树算法,最流行:C4.5、CART

参考

感谢帮助!

  • 《机器学实战》美 Peter Harrington

本文作者: yiyun

本文链接: https://cloud.tencent.com/developer/article/1970711

版权声明: 本博客所有文章除特别声明外,均采用 BY-NC-SA 许可协议。转载请注明出处!

0 人点赞