回归树(一)

2019-08-14 17:29:16 浏览数 (1)

线性回归模型需要拟合全部的样本点(局部加权线性回归除外)。当数据拥有众多特征并且特征之间的关系十分复杂时,构建全局模型的想法就不切实际。一种可行的方法是将数据集切分成很多份容易建模的数据,然后再用线性回归技术来建模。如果切分后任然难以用线性模型拟合就继续切分。在这种切分方式下,递归和树结构就相当有用。

本篇介绍一个叫做CART(Classfication And Regression Trees,分类回归树)的算法。先介绍一种简单的回归树,在每个叶子节点使用y的均值做预测。

首先加载一个200x3的数据集:

代码语言:python代码运行次数:0复制
def loadDataSet(fileName):      #general function to parse tab -delimited floats
    dataMat = []                #assume last column is target value
    fr = open(fileName)
    for line in fr.readlines():
        curLine = line.strip().split('t')
        fltLine = list(map(float,curLine) )#map all elements to float() # py36
        dataMat.append(fltLine)
    return dataMat

数据集的大小为200x3,前两列为x0(恒为1)和x1的值,最后一列为y的值。x1和y的二维图如下:

回归树使用二元切分来处理连续型变量。具体的处理方法是:如果特征值大于给定的阈值就走左子树,否则就进入右子树。

代码语言:javascript复制
def binSplitDataSet(dataSet, feature, value):
    #mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:][0]#原文错误
    matLeft = dataSet[nonzero(dataSet[:,feature] > value)[0],:]
    #mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:][0]
    matRight = dataSet[nonzero(dataSet[:,feature] <= value)[0],:]
    return matLeft, matRight

递归构建回归树:

代码语言:javascript复制
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):#assume dataSet is NumPy Mat so we can array filtering
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)#选择最好的特征
    if feat == None: return val #if the splitting hit a stop condition return val
    retTree = {}
    retTree['spInd'] = feat #根据哪个特征划分
    retTree['spVal'] = val #根据和哪个值的比较结果进行划分
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree

树的数据结构使用嵌套的字典实现,字典有4个键值,分别是

"spInd" : 特征的索引

"spVal" : 特征的阈值

"left" : 左子树,若是叶子节点则是该组样本y的均值

"right" : 右子树,若是叶子节点则是该组样本y的均值

使用叶子节点对应的y值的平均值作为预测值:

代码语言:javascript复制
def regLeaf(dataSet):#returns the value used for each leaf    
    return mean(dataSet[:,-1])

这里用平方误差的总和作为误差函数:

代码语言:javascript复制
def regErr(dataSet):
    return var(dataSet[:,-1]) * shape(dataSet)[0]

下面给出如何找到最好的划分特征的伪代码:

代码语言:javascript复制
对每个特征:
    对每个不重复的特征值:
        将数据集切分成两份
        计算误差(总方差)
        如果当前误差小于当前最小误差,就用当前最小误差替代当前误差
        如果误差下降值小于给定的最小值TolS, 则不再切分,直接返回
        如果去重的剩余特征值的数目小于TolN,则不再切分,直接返回
返回最佳切分的特征和阈值

代码实现:

代码语言:javascript复制
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1.0, 4)):
    # tolS : 容许的误差下降值
    # tolN:切分的最少样本数
    tolS, tolN = ops
    #if all the target variables are the same value: quit and return value
    if len(set(dataSet[:,-1].T.tolist()[0])) == 1: #exit cond 1
        return None, leafType(dataSet)
    m,n = shape(dataSet)
    #the choice of the best feature is driven by Reduction in RSS error from mean
    S = errType(dataSet)
    bestS = inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n-1):
        for splitVal in set(array(dataSet[:,featIndex]).flatten().tolist()): # 利用集合去重,set()参数列表不能有嵌套,须先降维
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue
            newS = errType(mat0)   errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    #if the decrease (S-bestS) is less than a threshold don't do the split
    if (S - bestS) < tolS:
        return None, leafType(dataSet) #exit cond 2
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):  #exit cond 3
        return None, leafType(dataSet)
    return bestIndex,bestValue#returns the best feature to split on
                              #and the value used for that split

调用上述函数,求出回归树:

代码语言:javascript复制
myData = loadDataSet('ex0.txt')
myMat = mat(myData)
tree = createTree(myMat)
print(tree)

上面回归树的结果不太直观,我们可以用matplotlib 画出树的结构:

下面我也给出回归树绘图的代码:

代码语言:javascript复制
from plotRegTree import createPlot
createPlot(tree,title="回归树n 以分段常数预测y")

具体的实现在写plotRegTree模块中,会多次用到递归:

代码语言:javascript复制
def getNumLeafs(regTree):
    '''返回叶子节点的数目(树的最大宽度)'''
    numLeafs = 0
    leftTree = regTree['left']
    rightTree = regTree['right']
    if type(leftTree).__name__ == "dict":#数据类型为字典(左树还有子树)
        numLeafs  = getNumLeafs(leftTree)#递归调用
    else:
        numLeafs  = 1
    if type(rightTree).__name__ == "dict":#数据类型为字典(右树还有子树)
        numLeafs  = getNumLeafs(rightTree)#递归调用
    else:
        numLeafs  = 1
       
    return numLeafs
    
def getTreeDepth(regTree):
    '''返回树的最大深度'''
    maxDepth = 0
    leftTree = regTree['left']
    rightTree = regTree['right']
    if type(leftTree).__name__ == "dict":#数据类型为字典(左树还有子树)
        thisDepth = 1   getTreeDepth(leftTree)#递归调用
    else:
        thisDepth = 1
    if thisDepth >maxDepth :
        maxDepth = thisDepth
   
    if type(rightTree).__name__ == "dict":#数据类型为字典(右树还有子树)
        thisDepth = 1   getTreeDepth(rightTree)#递归调用
    else:
        thisDepth = 1
    if thisDepth >maxDepth :
        maxDepth = thisDepth
       
    return maxDepth
    
yTop = 0.97 # 图形区域(含标题)X,Y坐标范围 均为0~1,0.97给title留空间
decisionNode = dict(boxstyle ="sawtooth", facecolor = "orange",edgecolor = "orange")
leafNode = dict(boxstyle = "round4", facecolor = "lime")
arrow_args = dict(arrowstyle = "<-", color ="r")

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    #return None
    createPlot.ax1.annotate(nodeTxt, xy =parentPt, xycoords = "axes fraction",
                            xytext = centerPt, textcoords ="axes fraction", va ="center",
                            ha = "center", bbox = nodeType, color ="black",weight ="bold",
                            arrowprops = arrow_args)
                            
def plotMidText(cntrPt, parentPt, textString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0   cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0   cntrPt[1]
    createPlot.ax1.text(xMid, yMid, textString)
    
def plotTree(regTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(regTree)
    depth = getTreeDepth(regTree)
    leftTree = regTree['left']
    rightTree = regTree['right']
   
    #firstStr = list(regTree.keys())[0]
    cntrPt = (plotTree.xOff   (1.0   numLeafs) /2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    print(regTree['spInd'])
    plotNode( "根据X%d划分" % regTree['spInd'], cntrPt, parentPt, decisionNode)
   
    plotTree.yOff  -= yTop /plotTree.totalD #到下一层
    specLimit = regTree['spVal']
    if type(leftTree).__name__ == "dict":#数据类型为字典(左树还有子树)
        plotTree(leftTree, cntrPt, ">%.6f" % specLimit )#递归调用
        #y的预测值的精度(小数点后显示6位)
    else:
            plotTree.xOff = plotTree.xOff   1.0 / plotTree.totalW
            plotNode("y预测值:%.3f" % leftTree, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, ">%.6f" % specLimit)
    if type(rightTree).__name__ == "dict":#数据类型为字典(左树还有子树)
        plotTree(rightTree, cntrPt, "<=%.6f" % specLimit)#递归调用
    else:
            plotTree.xOff = plotTree.xOff   1.0 / plotTree.totalW
            plotNode("y预测值:%.3f" % rightTree, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, "<=%.6f" % specLimit)
    plotTree.yOff   = yTop / plotTree.totalD #回到上一层

def createPlot(inTree,title ="回归树"):
    from matplotlib import pyplot as plt
    fig = plt.figure(1, facecolor = 'white')
    fig.clf()
    axprops = dict(xticks = [], yticks = []) #不显示x轴和y轴的刻度
    createPlot.ax1 = plt.subplot(111, frameon= False, ** axprops)
   
    plotTree.totalW = getNumLeafs(inTree)
    plotTree.totalD = getTreeDepth(inTree)
    plotTree.xOff = -0.5 / plotTree.totalW
   
    plotTree.yOff = yTop
    plotTree(inTree, (0.5, yTop), '')
  
    plt.title(title,fontsize =14, color ="B")
    plt.show()

0 人点赞