A linear regression model fits all of the sample points (locally weighted linear regression being the exception). When the data has many features with complex relationships among them, building one global model is unrealistic. A feasible alternative is to partition the dataset into many subsets that are easy to model, and then apply linear regression within each subset. If a subset is still hard to fit with a linear model, split it again. For this kind of partitioning, recursion and tree structures are a natural fit.
This post introduces an algorithm called CART (Classification And Regression Trees). We start with a simple regression tree that predicts with the mean of y at each leaf node.
First, load a 200x3 dataset:
def loadDataSet(fileName):
    """Parse a tab-delimited file of floats; the last column is the target value."""
    dataMat = []
    fr = open(fileName)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        fltLine = list(map(float, curLine))  # map every element to float (Python 3)
        dataMat.append(fltLine)
    return dataMat
The dataset is 200x3: the first two columns are x0 (always 1) and x1, and the last column is y. A 2-D plot of x1 against y is shown below:
A regression tree handles continuous variables with binary splits. Concretely: if the feature value is greater than a given threshold, the sample goes to the left subtree; otherwise it goes to the right subtree.
def binSplitDataSet(dataSet, feature, value):
    # mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:][0]  # error in the book
    matLeft = dataSet[nonzero(dataSet[:, feature] > value)[0], :]
    # mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:][0]  # error in the book
    matRight = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
    return matLeft, matRight
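As a quick sanity check, here is binSplitDataSet applied to a toy matrix (the function is restated so the snippet runs on its own; the test values are made up). The row whose column-1 value exceeds the threshold lands in the left part:

```python
from numpy import mat, nonzero

def binSplitDataSet(dataSet, feature, value):
    matLeft = dataSet[nonzero(dataSet[:, feature] > value)[0], :]
    matRight = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
    return matLeft, matRight

testMat = mat([[1.0, 0.1], [1.0, 0.4], [1.0, 0.8]])
left, right = binSplitDataSet(testMat, 1, 0.5)
print(left)   # the single row whose column 1 exceeds 0.5
print(right)  # the two rows whose column 1 is <= 0.5
```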
Build the regression tree recursively:
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    # assume dataSet is a NumPy matrix so we can use array filtering
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)  # choose the best feature
    if feat is None:
        return val  # the split hit a stop condition; return the leaf value
    retTree = {}
    retTree['spInd'] = feat  # which feature to split on
    retTree['spVal'] = val   # the threshold the feature is compared against
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree
The tree is implemented as nested dictionaries, each with four keys:
"spInd": index of the split feature
"spVal": threshold for that feature
"left": left subtree; if it is a leaf node, the mean of y over that group of samples
"right": right subtree; if it is a leaf node, the mean of y over that group of samples
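For illustration, a tree with a leaf on the left and a second split on the right would look like this as nested dicts (the numbers here are invented, not from the book's dataset):

```python
# Hypothetical tree for illustration; values are made up.
toyTree = {
    'spInd': 1,      # split on feature x1
    'spVal': 0.5,    # threshold
    'left': 1.02,    # leaf: mean y of the samples with x1 > 0.5
    'right': {       # the right child is itself a subtree
        'spInd': 1,
        'spVal': 0.2,
        'left': 0.45,
        'right': -0.04,
    },
}
print(type(toyTree['right']).__name__)  # → dict
```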
Use the mean of the y values at a leaf as its prediction:
def regLeaf(dataSet):
    # returns the value used for each leaf
    return mean(dataSet[:, -1])
The total squared error serves as the error function:
def regErr(dataSet):
    # variance times the number of samples = total squared error
    return var(dataSet[:, -1]) * shape(dataSet)[0]
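Since var is the mean squared deviation, multiplying it by the number of samples gives the sum of squared deviations from the mean. A small check of that identity on toy data (not from the book's dataset):

```python
from numpy import mat, var, shape, mean, power

def regErr(dataSet):
    return var(dataSet[:, -1]) * shape(dataSet)[0]

data = mat([[1.0, 1.0], [1.0, 2.0], [1.0, 4.0]])
y = data[:, -1]
sse = float(power(y - mean(y), 2).sum())  # explicit sum of squared deviations
print(regErr(data), sse)  # the two values agree
```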
Pseudocode for finding the best split feature:
For every feature:
    For every distinct value of that feature:
        Split the dataset in two
        Compute the error (total variance)
        If the current error is smaller than the best error so far, record this split as the best
If the error reduction is less than the given minimum tolS, stop splitting and return
If either split would contain fewer than tolN samples, stop splitting and return
Return the feature and threshold of the best split
Implementation:
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1.0, 4)):
    # tolS: minimum error reduction required to make a split
    # tolN: minimum number of samples in a split
    tolS, tolN = ops
    # exit condition 1: all target values are identical
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)
    m, n = shape(dataSet)
    # the best feature is the one giving the largest reduction in RSS from the mean
    S = errType(dataSet)
    bestS = inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n - 1):
        # set() removes duplicates; the matrix column must be flattened first,
        # since set() cannot take a nested list
        for splitVal in set(array(dataSet[:, featIndex]).flatten().tolist()):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    # exit condition 2: the error reduction (S - bestS) is below the threshold
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    # exit condition 3: either split is smaller than tolN
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
        return None, leafType(dataSet)
    return bestIndex, bestValue  # the best feature to split on and its threshold
Call the functions above to build the regression tree:
myData = loadDataSet('ex0.txt')
myMat = mat(myData)
tree = createTree(myMat)
print(tree)
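The book introduces forecasting with the finished tree later; as a minimal sketch, one can walk the dict with the same ">-goes-left" rule used when building it (treePredict is my own helper name, not from the source):

```python
def treePredict(tree, x):
    """Descend until reaching a leaf: left when x[spInd] > spVal, else right."""
    while isinstance(tree, dict):
        tree = tree['left'] if x[tree['spInd']] > tree['spVal'] else tree['right']
    return tree  # a leaf stores the mean y of its samples

# hand-built single-split tree, values invented for illustration
toy = {'spInd': 1, 'spVal': 0.5, 'left': 1.0, 'right': 0.0}
print(treePredict(toy, [1.0, 0.7]))  # → 1.0
print(treePredict(toy, [1.0, 0.3]))  # → 0.0
```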
The printed tree is not very readable, so we can draw its structure with matplotlib. Here is the plotting call:
from plotRegTree import createPlot
createPlot(tree, title="Regression tree\npiecewise-constant prediction of y")
The implementation lives in the plotRegTree module and makes repeated use of recursion:
def getNumLeafs(regTree):
    '''Return the number of leaf nodes (the tree's maximum width).'''
    numLeafs = 0
    leftTree = regTree['left']
    rightTree = regTree['right']
    if type(leftTree).__name__ == "dict":  # a dict means the left child is a subtree
        numLeafs += getNumLeafs(leftTree)  # recurse
    else:
        numLeafs += 1
    if type(rightTree).__name__ == "dict":  # a dict means the right child is a subtree
        numLeafs += getNumLeafs(rightTree)  # recurse
    else:
        numLeafs += 1
    return numLeafs
def getTreeDepth(regTree):
    '''Return the maximum depth of the tree.'''
    maxDepth = 0
    leftTree = regTree['left']
    rightTree = regTree['right']
    if type(leftTree).__name__ == "dict":  # the left child is a subtree
        thisDepth = 1 + getTreeDepth(leftTree)  # recurse
    else:
        thisDepth = 1
    if thisDepth > maxDepth:
        maxDepth = thisDepth
    if type(rightTree).__name__ == "dict":  # the right child is a subtree
        thisDepth = 1 + getTreeDepth(rightTree)  # recurse
    else:
        thisDepth = 1
    if thisDepth > maxDepth:
        maxDepth = thisDepth
    return maxDepth
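A quick check of the two helpers on a hand-built tree; the functions are restated here in compact but equivalent form so the snippet is self-contained:

```python
def getNumLeafs(regTree):
    # count leaves: a dict child is a subtree, anything else is a leaf
    return sum(getNumLeafs(regTree[s]) if isinstance(regTree[s], dict) else 1
               for s in ('left', 'right'))

def getTreeDepth(regTree):
    # depth: 1 + the deepest subtree on either side
    return max(1 + getTreeDepth(regTree[s]) if isinstance(regTree[s], dict) else 1
               for s in ('left', 'right'))

# hand-built tree, values invented for illustration
toy = {'spInd': 1, 'spVal': 0.5, 'left': 1.0,
       'right': {'spInd': 1, 'spVal': 0.2, 'left': 0.5, 'right': 0.0}}
print(getNumLeafs(toy))   # → 3
print(getTreeDepth(toy))  # → 2
```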
yTop = 0.97  # the axes span 0~1 in both X and Y; 0.97 leaves room for the title
decisionNode = dict(boxstyle="sawtooth", facecolor="orange", edgecolor="orange")
leafNode = dict(boxstyle="round4", facecolor="lime")
arrow_args = dict(arrowstyle="<-", color="r")
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords="axes fraction",
                            xytext=centerPt, textcoords="axes fraction", va="center",
                            ha="center", bbox=nodeType, color="black", weight="bold",
                            arrowprops=arrow_args)
def plotMidText(cntrPt, parentPt, textString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, textString)
def plotTree(regTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(regTree)
    leftTree = regTree['left']
    rightTree = regTree['right']
    cntrPt = (plotTree.xOff + (1.0 + numLeafs) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode("split on X%d" % regTree['spInd'], cntrPt, parentPt, decisionNode)
    plotTree.yOff -= yTop / plotTree.totalD  # descend one level
    specLimit = regTree['spVal']
    if type(leftTree).__name__ == "dict":  # the left child is a subtree
        plotTree(leftTree, cntrPt, ">%.6f" % specLimit)  # recurse
    else:
        plotTree.xOff += 1.0 / plotTree.totalW
        # the predicted y is shown to 3 decimal places, the threshold to 6
        plotNode("predicted y: %.3f" % leftTree, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
        plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, ">%.6f" % specLimit)
    if type(rightTree).__name__ == "dict":  # the right child is a subtree
        plotTree(rightTree, cntrPt, "<=%.6f" % specLimit)  # recurse
    else:
        plotTree.xOff += 1.0 / plotTree.totalW
        plotNode("predicted y: %.3f" % rightTree, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
        plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, "<=%.6f" % specLimit)
    plotTree.yOff += yTop / plotTree.totalD  # climb back up one level
def createPlot(inTree, title="Regression tree"):
    from matplotlib import pyplot as plt
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])  # hide the x- and y-axis ticks
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = getNumLeafs(inTree)
    plotTree.totalD = getTreeDepth(inTree)
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = yTop
    plotTree(inTree, (0.5, yTop), '')
    plt.title(title, fontsize=14, color="b")
    plt.show()