你是否玩过20个问题的游戏? 游戏的规则很简单:参与游戏的一方在脑海里想着某个事物,其它参与者想他提问题,最多只允许提20个问题,问题的答案也只能用“对”或者“错”回答。问问题的人通过推断分解,逐步缩小猜测事物的范围。
决策树的工作原理与之相似,用户输入一系列数据,机器给出分类答案。下面的流程图就是一个简单的决策树。矩形代表判断节点。椭圆代表叶子节点,表示已得出结论,可以终止运行。从判断节点引出的左右箭头称作分支,它指向另一个判断节点或者叶子节点。
决策树适用于标称型数据,因此数值型数据必须先离散化。决策树的主要优势在于数据形式非常容易理解。它的一个重要任务是提取数据中所蕴含的知识信息。因此决策树可以使用不熟悉的数据集,并从中提取一系列规则,这就是机器学习的过程。
在构造决策树时,我们需要解决的第一个问题是,当前数据集上那个特征在划分数据分类时起决定作用,即先用那个特征进行分类效率最高。为了找到决定性的特征,划分出最好的结果,我们需评估每一个特征。之后,原始数据集就被划分为几个数据子集。这些数据子集会分布在第一个决策点的所有分支上。如果某个分支下的数据全部属于同一类型,在该分支已完成了分类,无需做进一步分割,否则就要重复 划分数据子集的过程(递归)。直到所有具有相同类型的数据均在一个数据子集内。
我们以下面这个简单的水中生物分类的数据集为例,介绍决策树算法的基本流程。
首先创建数据集:
代码语言:javascript复制def createDataset():
'''创建一个简单的数据集'''
dataset = [ ["yes", "yes", "fish"],
["yes", "yes","fish"],
["yes", "no", "nonfish"],
["no", "yes", "nonfish"],
["no", "yes", "nonfish"]]
featnames =["no surfacing", "flippers"]#特征 名
return dataset, featnames
之后我们需要划分数据集。但如何寻找划当前分数据集的最好的特征呢?标准是什么?划分数据集的最大原则是:将无序的数据变得更加有序。组织杂乱无章的数据的一种方法是 使用信息论度量信息。
集合信息的度量方式成为香农熵,或者简称为熵(Entropy), 这个名字来源于信息论支付 克劳德·香农。熵定义为信息的期望值,在明晰熵的定义之前,我们需直到信息的定义。如果待分类的事物可能划分在多个分类之中,则对应第i个分类的信息定义为:
,
其中,
为选择该分类的概率。
则香农熵为所有类别包含的信息的期望值:
例如,若只有一个分类,则概率为1,熵为0,此时熵最小。若有100个事物,类别各不相同,则分到每个类别的概率均为0.01,熵为 -100*0.01*log2(0.01), 约等于6.644。
计算数据集的熵的代码如下:
代码语言:javascript复制from math import log
def calcEntropy(dataset):
'''计算给定数据集的 香农熵'''
numSamples = len(dataset) #样本(特征向量)个数
labelCounts = dict()
for sample in dataset :
currentLabel = sample[-1]
#有则 1, 无则 0 1
labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) 1
entropy = 0.0
for key in labelCounts:
prob = float(labelCounts[key]) / numSamples #按不同分类标签的数量计算概率
entropy -= prob * log(prob, 2) #计算香农 熵
return entropy
经计算,上述水中生物分类的数据集的熵值为 0.97095。
划当前分数据集的最好的特征就是使信息增益(熵的减少量)最大的那个特征。下面的代码使用循环找出使信息增益最大的那个特征的索引:
代码语言:javascript复制def splitDataset(dataset, axis, value):
'''划分数据集
3个输入参数分别为:待划分的数据集、待划分特征的索引,用于划分的特征的值'''
retDataset = []
for featVec in dataset: #for 数据集中每个样本(特征向量)
if featVec [axis] == value:
reducedFeatVec = featVec[: axis]
reducedFeatVec.extend(featVec[axis 1 : ])
retDataset.append(reducedFeatVec)
return retDataset
def chooseBestFeatureToSplit(dataset):
'''选择最好的(最大化信息增益)数据集划分方式'''
numFeatures =len(dataset[0]) -1 # 特征个数(列数 减掉分类标签所占一列)
baseEntropy = calcEntropy(dataset)
bestInfoGain = 0.0
bestFeature = -1
for i in range(numFeatures):
featList = [sample[i] for sample in dataset]#第i个特征的所有特征的值
uniqueValues = set(featList) # 通过列表转集合去重,得到第i个特征的值的集合
newEntropy = 0.0
for value in uniqueValues :
subDataset = splitDataset(dataset, i, value)
prob = len(subDataset)/ float(len(dataset))
newEntropy = prob * calcEntropy(subDataset)
#print(value, prob, newEntropy)
infoGain = baseEntropy - newEntropy
if infoGain > bestInfoGain :
bestInfoGain = infoGain
bestFeatureAxis = i
return bestFeatureAxis
现在,我们依据最好的特征就可以依靠递归调用得出决策树的全部结构。本例中,决策树的数据结构用 嵌套的字典来表示。
代码语言:javascript复制def majorityCnt(classList):
classCount = dict()
for vote in classList:
classCount[vote] = classCount.get(vote, 0) 1 #有则 1, 无则 0 1
import operator
#对键值对组成的列表按值从大到小排序
sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse =True)
#返回投票最多的分类标签名 。(其实不必全部排序)
return sortedClassCount[0][0]
def createTree(dataset, featnames):
classList = [sample[-1] for sample in dataset] # 类别 列表
if classList.count(classList[0]) == len(classList) :#类别完全相同则停止划分
return classList[0]
if len(dataset[0]) == 1: # 遍历完所有特征,则返回出现次数最多的类别
return majorityCnt(classList)
bestFeatureAxis = chooseBestFeatureToSplit(dataset)
bestFeatureName = featnames[bestFeatureAxis]
myTree = {bestFeatureName: {}}
del featnames[bestFeatureAxis] #删除最佳特征名
featValues = [sample[bestFeatureAxis] for sample in dataset]
uniqueValues = set(featValues) #集合去重
for value in uniqueValues:
subFeatnames = featnames[:] #深拷贝
myTree[bestFeatureName][value] = createTree(splitDataset(dataset, bestFeatureAxis, value),
subFeatnames)
return myTree
调用createTree() 函数,即可得到本例数据集对应的决策树字典为:
不够直观对不对?下面的代码是用matplotlib画出决策树(入口函数是 createPlot()):
代码语言:javascript复制def getNumLeafs(myTree):
'''返回叶子节点的数目(树的最大宽度)'''
numLeafs = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == "dict":#数据类型为字典(还有子树)
numLeafs = getNumLeafs(secondDict[key])
else:
numLeafs = 1
return numLeafs
def getTreeDepth(myTree):
'''返回树的最大深度'''
maxDepth = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == "dict":#数据类型为字典(还有子树)
thisDepth = 1 getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth > maxDepth:
maxDepth = thisDepth
return maxDepth
import matplotlib.pyplot as plt
# maptlot annotate 的 bbox的 属性字典
decisionNode = dict(boxstyle ="sawtooth", facecolor = "orange",edgecolor = "orange")
leafNode = dict(boxstyle = "round4", facecolor = "lime")
arrow_args = dict(arrowstyle = "<-", color ="r")
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy =parentPt, xycoords = "axes fraction",
xytext = centerPt, textcoords ="axes fraction", va ="center",
ha = "center", bbox = nodeType, color ="black",weight ="bold",
arrowprops = arrow_args)
def plotMidText(cntrPt, parentPt, textString):
xMid = (parentPt[0] - cntrPt[0]) / 2.0 cntrPt[0]
yMid = (parentPt[1] - cntrPt[1]) / 2.0 cntrPt[1]
createPlot.ax1.text(xMid, yMid, textString)
def plotTree(myTree, parentPt, nodeTxt):
numLeafs = getNumLeafs(myTree)
depth = getTreeDepth(myTree)
firstStr = list(myTree.keys())[0]
cntrPt = (plotTree.xOff (1.0 numLeafs) /2.0 / plotTree.totalW, plotTree.yOff)
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr " ?", cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0 /plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__ == "dict":
plotTree(secondDict[key], cntrPt, str(key))
else:
plotTree.xOff = plotTree.xOff 1.0 / plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff 1.0 / plotTree.totalD
def createPlot(inTree):
fig = plt.figure(1, facecolor = 'white')
fig.clf()
axprops = dict(xticks = [], yticks = []) #不显示x轴和y轴的刻度
createPlot.ax1 = plt.subplot(111, frameon= False, ** axprops)
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5 / plotTree.totalW
plotTree.yOff = 1.0
plotTree(inTree, (0.5,1.0), '')
plt.show()