CART,又名分类回归树,是在ID3的基础上进行优化的决策树,学习CART记住以下几个关键点:
(1)CART既能是分类树,又能是分类树;
(2)当CART是分类树时,采用GINI值作为节点分裂的依据;当CART是回归树时,采用样本的最小方差作为节点分裂的依据;
(3)CART是一棵二叉树。
接下来将以一个实际的例子对CART进行介绍:
表1 原始数据表
看电视时间 | 婚姻情况 | 职业 | 年龄 |
---|---|---|---|
3 | 未婚 | 学生 | 12 |
4 | 未婚 | 学生 | 18 |
2 | 已婚 | 老师 | 26 |
5 | 已婚 | 上班族 | 47 |
2.5 | 已婚 | 上班族 | 36 |
3.5 | 未婚 | 老师 | 29 |
4 | 已婚 | 学生 | 21 |
从以下的思路理解CART:
分类树?回归树?
分类树的作用是通过一个对象的特征来预测该对象所属的类别,而回归树的目的是根据一个对象的信息预测该对象的属性,并以数值表示。
CART既能是分类树,又能是决策树,如上表所示,如果我们想预测一个人是否已婚,那么构建的CART将是分类树;如果想预测一个人的年龄,那么构建的将是回归树。
分类树和回归树是怎么做决策的?假设我们构建了两棵决策树分别预测用户是否已婚和实际的年龄,如图1和图2所示:
图1 预测婚姻情况决策树 图2 预测年龄的决策树
图1表示一棵分类树,其叶子节点的输出结果为一个实际的类别,在这个例子里是婚姻的情况(已婚或者未婚),选择叶子节点中数量占比最大的类别作为输出的类别;
图2是一棵回归树,预测用户的实际年龄,是一个具体的输出值。怎样得到这个输出值?一般情况下选择使用中值、平均值或者众数进行表示,图2使用节点年龄数据的平均值作为输出值。
CART如何选择分裂的属性?
分裂的目的是为了能够让数据变纯,使决策树输出的结果更接近真实值。那么CART是如何评价节点的纯度呢?如果是分类树,CART采用GINI值衡量节点纯度;如果是回归树,采用样本方差衡量节点纯度。节点越不纯,节点分类或者预测的效果就越差。
GINI值的计算公式:
节点越不纯,GINI值越大。以二分类为例,如果节点的所有数据只有一个类别,则
,如果两类数量相同,则
。
回归方差计算公式:
方差越大,表示该节点的数据越分散,预测的效果就越差。如果一个节点的所有数据都相同,那么方差就为0,此时可以很肯定得认为该节点的输出值;如果节点的数据相差很大,那么输出的值有很大的可能与实际值相差较大。
因此,无论是分类树还是回归树,CART都要选择使子节点的GINI值或者回归方差最小的属性作为分裂的方案。即最小化(分类树):
或者(回归树):
CART如何分裂成一棵二叉树?
节点的分裂分为两种情况,连续型的数据和离散型的数据。
CART对连续型属性的处理与C4.5差不多,通过最小化分裂后的GINI值或者样本方差寻找最优分割点,将节点一分为二,在这里不再叙述,详细请看C4.5。
对于离散型属性,理论上有多少个离散值就应该分裂成多少个节点。但CART是一棵二叉树,每一次分裂只会产生两个节点,怎么办呢?很简单,只要将其中一个离散值独立作为一个节点,其他的离散值生成另外一个节点即可。这种分裂方案有多少个离散值就有多少种划分的方法,举一个简单的例子:如果某离散属性一个有三个离散值X,Y,Z,则该属性的分裂方法有{X}、{Y,Z},{Y}、{X,Z},{Z}、{X,Y},分别计算每种划分方法的基尼值或者样本方差确定最优的方法。
以属性“职业”为例,一共有三个离散值,“学生”、“老师”、“上班族”。该属性有三种划分的方案,分别为{“学生”}、{“老师”、“上班族”},{“老师”}、{“学生”、“上班族”},{“上班族”}、{“学生”、“老师”},分别计算三种划分方案的子节点GINI值或者样本方差,选择最优的划分方法,如下图所示:
第一种划分方法:{“学生”}、{“老师”、“上班族”}
预测是否已婚(分类):
预测年龄(回归):
第二种划分方法:{“老师”}、{“学生”、“上班族”}
预测是否已婚(分类):
预测年龄(回归):
第三种划分方法:{“上班族”}、{“学生”、“老师”}
预测是否已婚(分类):
预测年龄(回归):
综上,如果想预测是否已婚,则选择{“上班族”}、{“学生”、“老师”}的划分方法,如果想预测年龄,则选择{“老师”}、{“学生”、“上班族”}的划分方法。
如何剪枝?
CART采用CCP(代价复杂度)剪枝方法。代价复杂度选择节点表面误差率增益值最小的非叶子节点,删除该非叶子节点的左右子节点,若有多个非叶子节点的表面误差率增益值相同小,则选择非叶子节点中子节点数最多的非叶子节点进行剪枝。
可描述如下:
令决策树的非叶子节点为
。
a)计算所有非叶子节点的表面误差率增益值
b)选择表面误差率增益值
最小的非叶子节点
(若多个非叶子节点具有相同小的表面误差率增益值,选择节点数最多的非叶子节点)。
c)对
进行剪枝
表面误差率增益值的计算公式:
其中:
表示叶子节点的误差代价,
,
为节点的错误率,
为节点数据量的占比;
表示子树的误差代价,
,
为子节点i的错误率,
表示节点i的数据节点占比;
表示子树节点个数。
算例:
下图是其中一颗子树,设决策树的总数据量为40。
该子树的表面误差率增益值可以计算如下:
求出该子树的表面错误覆盖率为 ,只要求出其他子树的表面误差率增益值就可以对决策树进行剪枝。
程序实际以及源代码
流程图:
(1)数据处理
对原始的数据进行数字化处理,并以二维数据的形式存储,每一行表示一条记录,前n-1列表示属性,最后一列表示分类的标签。
如表1的数据可以转化为表2:
表2 初始化后的数据
看电视时间 | 婚姻情况 | 职业 | 年龄 |
---|---|---|---|
3 | 未婚 | 学生 | 12 |
4 | 未婚 | 学生 | 18 |
2 | 已婚 | 老师 | 26 |
5 | 已婚 | 上班族 | 47 |
2.5 | 已婚 | 上班族 | 36 |
3.5 | 未婚 | 老师 | 29 |
4 | 已婚 | 学生 | 21 |
其中,对于“婚姻情况”属性,数字{1,2}分别表示{未婚,已婚 };对于“职业”属性{1,2,3, }分别表示{学生、老师、上班族};
代码如下所示:
static double[][] allData; //存储进行训练的数据
static List<String>[] featureValues; //离散属性对应的离散值
featureValues是链表数组,数组的长度为属性的个数,数组的每个元素为该属性的离散值链表。
(2)两个类:节点类和分裂信息
a)节点类Node
该类表示一个节点,属性包括节点选择的分裂属性、节点的输出类、孩子节点、深度等。注意,与ID3中相比,新增了两个属性:leafWrong和leafNode_Count分别表示叶子节点的总分类误差和叶子节点的个数,主要是为了方便剪枝。
代码语言:javascript复制 1 class Node 2 { 3 /// <summary> 4 /// 每一个节点的分裂值 5 /// </summary> 6 public List<String> features { get; set; } 7 /// <summary> 8 /// 分裂属性的类型{离散、连续} 9 /// </summary> 10 public String feature_Type { get; set; } 11 /// <summary> 12 /// 分裂属性的下标 13 /// </summary> 14 public String SplitFeature { get; set; } 15 //List<int> nums = new List<int>(); //行序号 16 /// <summary> 17 /// 每一个类对应的数目 18 /// </summary> 19 public double[] ClassCount { get; set; } 20 //int[] isUsed = new int[0]; //属性的使用情况 1:已用 2:未用 21 /// <summary> 22 /// 孩子节点 23 /// </summary> 24 public List<Node> childNodes { get; set; } 25 Node Parent = null; 26 /// <summary> 27 /// 该节点占比最大的类别 28 /// </summary> 29 public String finalResult { get; set; } 30 /// <summary> 31 /// 树的深度 32 /// </summary> 33 public int deep { get; set; } 34 /// <summary> 35 /// 最大的类下标 36 /// </summary> 37 public int result { get; set; } 38 /// <summary> 39 /// 子节点误差 40 /// </summary> 41 public int leafWrong { get; set; } 42 /// <summary> 43 /// 子节点数目 44 /// </summary> 45 public int leafNode_Count { get; set; } 46 /// <summary> 47 /// 数据量 48 /// </summary> 49 public int rowCount { get; set; } 50 51 public void setClassCount(double[] count) 52 { 53 this.ClassCount = count; 54 double max = ClassCount[0]; 55 int result = 0; 56 for (int i = 1; i < ClassCount.Length; i ) 57 { 58 if (max < ClassCount[i]) 59 { 60 max = ClassCount[i]; 61 result = i; 62 } 63 } 64 this.result = result; 65 } 66 public double getErrorCount() 67 { 68 return rowCount - ClassCount[result]; 69 } 70 }
树的节点
b)分裂信息类,该类存储节点进行分裂的信息,包括各个子节点的行坐标、子节点各个类的数目、该节点分裂的属性、属性的类型等。
代码语言:javascript复制 1 class SplitInfo 2 { 3 /// <summary> 4 /// 分裂的属性下标 5 /// </summary> 6 public int splitIndex { get; set; } 7 /// <summary> 8 /// 数据类型 9 /// </summary> 10 public int type { get; set; } 11 /// <summary> 12 /// 分裂属性的取值 13 /// </summary> 14 public List<String> features { get; set; } 15 /// <summary> 16 /// 各个节点的行坐标链表 17 /// </summary> 18 public List<int>[] temp { get; set; } 19 /// <summary> 20 /// 每个节点各类的数目 21 /// </summary> 22 public double[][] class_Count { get; set; } 23 }
分裂信息
主方法findBestSplit(Node node,List<int> nums,int[] isUsed),该方法对节点进行分裂
其中:
node表示即将进行分裂的节点;
nums表示节点数据对一个的行坐标列表;
isUsed表示到该节点位置所有属性的使用情况;
findBestSplit的这个方法主要有以下几个组成部分:
1)节点分裂停止的判定
节点分裂条件如上文所述,源代码如下:
代码语言:javascript复制 1 public static bool ifEnd(Node node, double shang,int[] isUsed) 2 { 3 try 4 { 5 double[] count = node.ClassCount; 6 int rowCount = node.rowCount; 7 int maxResult = 0; 8 double maxRate = 0; 9 #region 数达到某一深度 10 int deep = node.deep; 11 if (deep >= 10) 12 { 13 maxResult = node.result 1; 14 node.feature_Type="result"; 15 node.features=new List<String>() { maxResult "" 16 17 }; 18 node.leafWrong=rowCount - Convert.ToInt32(count[maxResult-1]); 19 node.leafNode_Count=1; 20 return true; 21 } 22 #endregion 23 #region 纯度(其实跟后面的有点重了,记得要修改) 24 //maxResult = 1; 25 //for (int i = 1; i < count.Length; i ) 26 //{ 27 // if (count[i] / rowCount >= 0.95) 28 // { 29 // node.feature_Type="result"; 30 // node.features=new List<String> { "" (i 31 32 1) }; 33 // node.leafNode_Count=1; 34 // node.leafWrong=rowCount - Convert.ToInt32 35 36 (count[i]); 37 // return true; 38 // } 39 //} 40 #endregion 41 #region 熵为0 42 if (shang == 0) 43 { 44 maxRate = count[0] / rowCount; 45 maxResult = 1; 46 for (int i = 1; i < count.Length; i ) 47 { 48 if (count[i] / rowCount >= maxRate) 49 { 50 maxRate = count[i] / rowCount; 51 maxResult = i 1; 52 } 53 } 54 node.feature_Type="result"; 55 node.features=new List<String> { maxResult "" 56 57 }; 58 node.leafWrong=rowCount - Convert.ToInt32(count 59 60 [maxResult - 1]); 61 node.leafNode_Count=1; 62 return true; 63 } 64 #endregion 65 #region 属性已经分完 66 //int[] isUsed = node.getUsed(); 67 bool flag = true; 68 for (int i = 0; i < isUsed.Length - 1; i ) 69 { 70 if (isUsed[i] == 0) 71 { 72 flag = false; 73 break; 74 } 75 } 76 if (flag) 77 { 78 maxRate = count[0] / rowCount; 79 maxResult = 1; 80 for (int i = 1; i < count.Length; i ) 81 { 82 if (count[i] / rowCount >= maxRate) 83 { 84 maxRate = count[i] / rowCount; 85 maxResult = i 1; 86 } 87 } 88 node.feature_Type=("result"); 89 node.features=(new List<String> { "" 90 91 (maxResult) }); 92 node.leafWrong=(rowCount - Convert.ToInt32(count 93 94 [maxResult - 1])); 95 node.leafNode_Count=(1); 96 return true; 97 } 98 #endregion 99 #region 几点数少于100 100 if (rowCount < Limit_Node) 101 { 102 maxRate = count[0] / rowCount; 103 maxResult = 1; 104 for (int i = 1; i < count.Length; i ) 105 { 106 if (count[i] / rowCount >= maxRate) 107 { 108 maxRate = count[i] / rowCount; 109 maxResult = i 1; 110 } 111 } 112 node.feature_Type="result"; 113 node.features=new List<String> { "" (maxResult) 114 115 }; 116 node.leafWrong=rowCount - Convert.ToInt32(count 117 118 [maxResult - 1]); 119 node.leafNode_Count=1; 120 return true; 121 } 122 #endregion 123 return false; 124 } 125 catch (Exception e) 126 { 127 return false; 128 } 129 }
停止分裂的条件
2)寻找最优的分裂属性
寻找最优的分裂属性需要计算每一个分裂属性分裂后的GINI值或者样本方差,计算公式上文已给出,其中GINI值的计算代码如下:
代码语言:javascript复制1 public static double getGini(double[] counts, int countAll) 2 { 3 double Gini = 1; 4 for (int i = 0; i < counts.Length; i ) 5 { 6 Gini = Gini - Math.Pow(counts[i] / countAll, 2); 7 } 8 return Gini; 9 }
GINI值计算
3)进行分裂,同时对子节点进行迭代处理
其实就是递归的过程,对每一个子节点执行findBestSplit方法进行分裂。
findBestSplit源代码:
代码语言:javascript复制/*
* 提示:该行代码过长,系统自动注释不进行高亮。一键复制会移除系统注释
* 1 public static Node findBestSplit(Node node,List<int> nums,int[] isUsed) 2 { 3 try 4 { 5 //判断是否继续分裂 6 double totalShang = getGini(node.ClassCount, node.rowCount); 7 if (ifEnd(node, totalShang, isUsed)) 8 { 9 return node; 10 } 11 #region 变量声明 12 SplitInfo info = new SplitInfo(); 13 info.initial(); 14 int RowCount = nums.Count; //样本总数 15 double jubuMax = 1; //局部最大熵 16 int splitPoint = 0; //分裂的点 17 double splitValue = 0; //分裂的值 18 #endregion 19 for (int i = 0; i < isUsed.Length - 1; i ) 20 { 21 if (isUsed[i] == 1) 22 { 23 continue; 24 } 25 #region 离散变量 26 if (type[i] == 0) 27 { 28 double[][] allCount = new double[allNum[i]][]; 29 for (int j = 0; j < allCount.Length; j ) 30 { 31 allCount[j] = new double[classCount]; 32 } 33 int[] countAllFeature = new int[allNum[i]]; 34 List<int>[] temp = new List<int>[allNum[i]]; 35 double[] allClassCount = node.ClassCount; //所有类别的数量 36 for (int j = 0; j < temp.Length; j ) 37 { 38 temp[j] = new List<int>(); 39 } 40 for (int j = 0; j < nums.Count; j ) 41 { 42 int index = Convert.ToInt32(allData[nums[j]][i]); 43 temp[index - 1].Add(nums[j]); 44 countAllFeature[index - 1] ; 45 allCount[index - 1][Convert.ToInt32(allData[nums[j]][lieshu - 1]) - 1] ; 46 } 47 double allShang = 1; 48 int choose = 0; 49 50 double[][] jubuCount = new double[2][]; 51 for (int k = 0; k < allCount.Length; k ) 52 { 53 if (temp[k].Count == 0) 54 continue; 55 double JubuShang = 0; 56 double[][] tempCount = new double[2][]; 57 tempCount[0] = allCount[k]; 58 tempCount[1] = new double[allCount[0].Length]; 59 for (int j = 0; j < tempCount[1].Length; j ) 60 { 61 tempCount[1][j] = allClassCount[j] - allCount[k][j]; 62 } 63 JubuShang = JubuShang getGini(tempCount[0], countAllFeature[k]) * countAllFeature[k] / RowCount; 64 int nodecount = RowCount - countAllFeature[k]; 65 JubuShang = JubuShang getGini(tempCount[1], nodecount) * nodecount / RowCount; 66 if (JubuShang < allShang) 67 { 68 allShang = JubuShang; 69 jubuCount = tempCount; 70 choose = k; 71 } 72 } 73 if (allShang < jubuMax) 74 { 75 info.type = 0; 76 jubuMax = allShang; 77 info.class_Count = jubuCount; 78 info.temp[0] = temp[choose]; 79 info.temp[1] = new List<int>(); 80 info.features = new List<string>(); 81 info.features.Add((choose 1) ""); 82 info.features.Add(""); 83 for (int j = 0; j < temp.Length; j ) 84 { 85 if (j == choose) 86 continue; 87 for (int k = 0; k < temp[j].Count; k ) 88 { 89 info.temp[1].Add(temp[j][k]); 90 } 91 if (temp[j].Count != 0) 92 { 93 info.features[1] = info.features[1] (j 1) ","; 94 } 95 } 96 info.splitIndex = i; 97 } 98 } 99 #endregion 100 #region 连续变量 101 else 102 { 103 double[] leftCunt = new double[classCount]; 104 105 //做节点各个类别的数量 106 double[] rightCount = new double[classCount]; 107 108 //右节点各个类别的数量 109 double[] count1 = new double[classCount]; 110 111 //子集1的统计量 112 double[] count2 = new double 113 114 [node.ClassCount.Length]; //子集2的统计量 115 for (int j = 0; j < node.ClassCount.Length; 116 117 j ) 118 { 119 count2[j] = node.ClassCount[j]; 120 } 121 int all1 = 0; 122 123 //子集1的样本量 124 int all2 = nums.Count; 125 126 //子集2的样本量 127 double lastValue = 0; 128 129 //上一个记录的类别 130 double currentValue = 0; 131 132 //当前类别 133 double lastPoint = 0; 134 135 //上一个点的值 136 double currentPoint = 0; 137 138 //当前点的值 139 double[] values = new double[nums.Count]; 140 for (int j = 0; j < values.Length; j ) 141 { 142 values[j] = allData[nums[j]][i]; 143 } 144 QSort(values, nums, 0, nums.Count - 1); 145 double lianxuMax = 1; 146 147 //连续型属性的最大熵 148 #region 寻找最佳的分割点 149 for (int j = 0; j < nums.Count - 1; j ) 150 { 151 currentValue = allData[nums[j]][lieshu - 152 153 1]; 154 currentPoint = (allData[nums[j]][i]); 155 if (j == 0) 156 { 157 lastValue = currentValue; 158 lastPoint = currentPoint; 159 } 160 if (currentValue != lastValue && 161 162 currentPoint != lastPoint) 163 { 164 double shang1 = getGini(count1, 165 166 all1); 167 double shang2 = getGini(count2, 168 169 all2); 170 double allShang = shang1 * all1 / 171 172 (all1 all2) shang2 * all2 / (all1 all2); 173 //allShang = (totalShang - allShang); 174 if (lianxuMax > allShang) 175 { 176 lianxuMax = allShang; 177 for (int k = 0; k < 178 179 count1.Length; k ) 180 { 181 leftCunt[k] = count1[k]; 182 rightCount[k] = count2[k]; 183 } 184 splitPoint = j; 185 splitValue = (currentPoint 186 187 lastPoint) / 2; 188 } 189 } 190 all1 ; 191 count1[Convert.ToInt32(currentValue) - 192 193 1] ; 194 count2[Convert.ToInt32(currentValue) - 195 196 1]--; 197 all2--; 198 lastValue = currentValue; 199 lastPoint = currentPoint; 200 } 201 #endregion 202 #region 如果超过了局部值,重设 203 if (lianxuMax < jubuMax) 204 { 205 info.type = 1; 206 info.splitIndex = i; 207 info.features=new List<string>() 208 209 {splitValue ""}; 210 //finalPoint = splitPoint; 211 jubuMax = lianxuMax; 212 info.temp[0] = new List<int>(); 213 info.temp[1] = new List<int>(); 214 for (int k = 0; k < splitPoint; k ) 215 { 216 info.temp[0].Add(nums[k]); 217 } 218 for (int k = splitPoint; k < nums.Count; 219 220 k ) 221 { 222 info.temp[1].Add(nums[k]); 223 } 224 info.class_Count[0] = new double 225 226 [leftCunt.Length]; 227 info.class_Count[1] = new double 228 229 [leftCunt.Length]; 230 for (int k = 0; k < leftCunt.Length; k ) 231 { 232 info.class_Count[0][k] = leftCunt[k]; 233 info.class_Count[1][k] = rightCount 234 235 [k]; 236 } 237 } 238 #endregion 239 } 240 #endregion 241 } 242 #region 没有寻找到最佳的分裂点,则设置为叶节点 243 if (info.splitIndex == -1) 244 { 245 double[] finalCount = node.ClassCount; 246 double max = finalCount[0]; 247 int result = 1; 248 for (int i = 1; i < finalCount.Length; i ) 249 { 250 if (finalCount[i] > max) 251 { 252 max = finalCount[i]; 253 result = (i 1); 254 } 255 } 256 node.feature_Type="result"; 257 node.features=new List<String> { "" result }; 258 return node; 259 } 260 #endregion 261 #region 分裂 262 int deep = node.deep; 263 node.SplitFeature = ("" info.splitIndex); 264 List<Node> childNode = new List<Node>(); 265 int[][] used = new int[2][]; 266 used[0] = new int[isUsed.Length]; 267 used[1] = new int[isUsed.Length]; 268 for (int i = 0; i < isUsed.Length; i ) 269 { 270 used[0][i] = isUsed[i]; 271 used[1][i] = isUsed[i]; 272 } 273 if (info.type == 0) 274 { 275 used[0][info.splitIndex] = 1; 276 node.feature_Type = ("离散"); 277 } 278 else 279 { 280 //used[info.splitIndex] = 0; 281 node.feature_Type = ("连续"); 282 } 283 List<int>[] rowIndex = info.temp; 284 List<String> features = info.features; 285 Node node1 = new Node(); 286 Node node2 = new Node(); 287 node1.setClassCount(info.class_Count[0]); 288 node2.setClassCount(info.class_Count[1]); 289 node1.rowCount = info.temp[0].Count; 290 node2.rowCount = info.temp[1].Count; 291 node1.deep = deep 1; 292 node2.deep = deep 1; 293 node1 = findBestSplit(node1, info.temp[0],used[0]); 294 node2 = findBestSplit(node2, info.temp[1], used[1]); 295 node.leafNode_Count = (node1.leafNode_Count 296 297 node2.leafNode_Count); 298 node.leafWrong = (node1.leafWrong node2.leafWrong); 299 node.features = (features); 300 childNode.Add(node1); 301 childNode.Add(node2); 302 node.childNodes = childNode; 303 #endregion 304 return node; 305 } 306 catch (Exception e) 307 { 308 Console.WriteLine(e.StackTrace); 309 return node; 310 } 311 }
*/
节点选择属性和分裂
(4)剪枝
代价复杂度剪枝方法(CCP):
代码语言:javascript复制 1 public static void getSeries(Node node) 2 { 3 Stack<Node> nodeStack = new Stack<Node>(); 4 if (node != null) 5 { 6 nodeStack.Push(node); 7 } 8 if (node.feature_Type == "result") 9 return; 10 List<Node> childs = node.childNodes; 11 for (int i = 0; i < childs.Count; i ) 12 { 13 getSeries(node); 14 } 15 }
CCP代价复杂度剪枝
CART全部核心代码:
代码语言:javascript复制/*
* 提示:该行代码过长,系统自动注释不进行高亮。一键复制会移除系统注释
* 1 /// <summary> 2 /// 判断是否还需要分裂 3 /// </summary> 4 /// <param name="node"></param> 5 /// <returns></returns> 6 public static bool ifEnd(Node node, double shang,int[] isUsed) 7 { 8 try 9 { 10 double[] count = node.ClassCount; 11 int rowCount = node.rowCount; 12 int maxResult = 0; 13 double maxRate = 0; 14 #region 数达到某一深度 15 int deep = node.deep; 16 if (deep >= 10) 17 { 18 maxResult = node.result 1; 19 node.feature_Type="result"; 20 node.features=new List<String>() { maxResult "" 21 22 }; 23 node.leafWrong=rowCount - Convert.ToInt32(count[maxResult-1]); 24 node.leafNode_Count=1; 25 return true; 26 } 27 #endregion 28 #region 纯度(其实跟后面的有点重了,记得要修改) 29 //maxResult = 1; 30 //for (int i = 1; i < count.Length; i ) 31 //{ 32 // if (count[i] / rowCount >= 0.95) 33 // { 34 // node.feature_Type="result"; 35 // node.features=new List<String> { "" (i 36 37 1) }; 38 // node.leafNode_Count=1; 39 // node.leafWrong=rowCount - Convert.ToInt32 40 41 (count[i]); 42 // return true; 43 // } 44 //} 45 #endregion 46 #region 熵为0 47 if (shang == 0) 48 { 49 maxRate = count[0] / rowCount; 50 maxResult = 1; 51 for (int i = 1; i < count.Length; i ) 52 { 53 if (count[i] / rowCount >= maxRate) 54 { 55 maxRate = count[i] / rowCount; 56 maxResult = i 1; 57 } 58 } 59 node.feature_Type="result"; 60 node.features=new List<String> { maxResult "" 61 62 }; 63 node.leafWrong=rowCount - Convert.ToInt32(count 64 65 [maxResult - 1]); 66 node.leafNode_Count=1; 67 return true; 68 } 69 #endregion 70 #region 属性已经分完 71 //int[] isUsed = node.getUsed(); 72 bool flag = true; 73 for (int i = 0; i < isUsed.Length - 1; i ) 74 { 75 if (isUsed[i] == 0) 76 { 77 flag = false; 78 break; 79 } 80 } 81 if (flag) 82 { 83 maxRate = count[0] / rowCount; 84 maxResult = 1; 85 for (int i = 1; i < count.Length; i ) 86 { 87 if (count[i] / rowCount >= maxRate) 88 { 89 maxRate = count[i] / rowCount; 90 maxResult = i 1; 91 } 92 } 93 node.feature_Type=("result"); 94 node.features=(new List<String> { "" 95 96 (maxResult) }); 97 node.leafWrong=(rowCount - Convert.ToInt32(count 98 99 [maxResult - 1])); 100 node.leafNode_Count=(1); 101 return true; 102 } 103 #endregion 104 #region 几点数少于100 105 if (rowCount < Limit_Node) 106 { 107 maxRate = count[0] / rowCount; 108 maxResult = 1; 109 for (int i = 1; i < count.Length; i ) 110 { 111 if (count[i] / rowCount >= maxRate) 112 { 113 maxRate = count[i] / rowCount; 114 maxResult = i 1; 115 } 116 } 117 node.feature_Type="result"; 118 node.features=new List<String> { "" (maxResult) 119 120 }; 121 node.leafWrong=rowCount - Convert.ToInt32(count 122 123 [maxResult - 1]); 124 node.leafNode_Count=1; 125 return true; 126 } 127 #endregion 128 return false; 129 } 130 catch (Exception e) 131 { 132 return false; 133 } 134 } 135 #region 排序算法 136 public static void InsertSort(double[] values, List<int> arr, 137 138 int StartIndex, int endIndex) 139 { 140 for (int i = StartIndex 1; i <= endIndex; i ) 141 { 142 int key = arr[i]; 143 double init = values[i]; 144 int j = i - 1; 145 while (j >= StartIndex && values[j] > init) 146 { 147 arr[j 1] = arr[j]; 148 values[j 1] = values[j]; 149 j--; 150 } 151 arr[j 1] = key; 152 values[j 1] = init; 153 } 154 } 155 static int SelectPivotMedianOfThree(double[] values, List<int> arr, int low, int high) 156 { 157 int mid = low ((high - low) >> 1);//计算数组中间的元素的下标 158 159 //使用三数取中法选择枢轴 160 if (values[mid] > values[high])//目标: arr[mid] <= arr[high] 161 { 162 swap(values, arr, mid, high); 163 } 164 if (values[low] > values[high])//目标: arr[low] <= arr[high] 165 { 166 swap(values, arr, low, high); 167 } 168 if (values[mid] > values[low]) //目标: arr[low] >= arr[mid] 169 { 170 swap(values, arr, mid, low); 171 } 172 //此时,arr[mid] <= arr[low] <= arr[high] 173 return low; 174 //low的位置上保存这三个位置中间的值 175 //分割时可以直接使用low位置的元素作为枢轴,而不用改变分割函数了 176 } 177 static void swap(double[] values, List<int> arr, int t1, int t2) 178 { 179 double temp = values[t1]; 180 values[t1] = values[t2]; 181 values[t2] = temp; 182 int key = arr[t1]; 183 arr[t1] = arr[t2]; 184 arr[t2] = key; 185 } 186 static void QSort(double[] values, List<int> arr, int low, int high) 187 { 188 int first = low; 189 int last = high; 190 191 int left = low; 192 int right = high; 193 194 int leftLen = 0; 195 int rightLen = 0; 196 197 if (high - low 1 < 10) 198 { 199 InsertSort(values, arr, low, high); 200 return; 201 } 202 203 //一次分割 204 int key = SelectPivotMedianOfThree(values, arr, low, 205 206 high);//使用三数取中法选择枢轴 207 double inti = values[key]; 208 int currentKey = arr[key]; 209 210 while (low < high) 211 { 212 while (high > low && values[high] >= inti) 213 { 214 if (values[high] == inti)//处理相等元素 215 { 216 swap(values, arr, right, high); 217 right--; 218 rightLen ; 219 } 220 high--; 221 } 222 arr[low] = arr[high]; 223 values[low] = values[high]; 224 while (high > low && values[low] <= inti) 225 { 226 if (values[low] == inti) 227 { 228 swap(values, arr, left, low); 229 left ; 230 leftLen ; 231 } 232 low ; 233 } 234 arr[high] = arr[low]; 235 values[high] = values[low]; 236 } 237 arr[low] = currentKey; 238 values[low] = values[key]; 239 //一次快排结束 240 //把与枢轴key相同的元素移到枢轴最终位置周围 241 int i = low - 1; 242 int j = first; 243 while (j < left && values[i] != inti) 244 { 245 swap(values, arr, i, j); 246 i--; 247 j ; 248 } 249 i = low 1; 250 j = last; 251 while (j > right && values[i] != inti) 252 { 253 swap(values, arr, i, j); 254 i ; 255 j--; 256 } 257 QSort(values, arr, first, low - 1 - leftLen); 258 QSort(values, arr, low 1 rightLen, last); 259 } 260 #endregion 261 /// <summary> 262 /// 寻找最佳的分裂点 263 /// </summary> 264 /// <param name="num"></param> 265 /// <param name="node"></param> 266 public static Node findBestSplit(Node node,List<int> nums,int[] isUsed) 267 { 268 try 269 { 270 //判断是否继续分裂 271 double totalShang = getGini(node.ClassCount, node.rowCount); 272 if (ifEnd(node, totalShang, isUsed)) 273 { 274 return node; 275 } 276 #region 变量声明 277 SplitInfo info = new SplitInfo(); 278 info.initial(); 279 int RowCount = nums.Count; //样本总数 280 double jubuMax = 1; //局部最大熵 281 int splitPoint = 0; //分裂的点 282 double splitValue = 0; //分裂的值 283 #endregion 284 for (int i = 0; i < isUsed.Length - 1; i ) 285 { 286 if (isUsed[i] == 1) 287 { 288 continue; 289 } 290 #region 离散变量 291 if (type[i] == 0) 292 { 293 double[][] allCount = new double[allNum[i]][]; 294 for (int j = 0; j < allCount.Length; j ) 295 { 296 allCount[j] = new double[classCount]; 297 } 298 int[] countAllFeature = new int[allNum[i]]; 299 List<int>[] temp = new List<int>[allNum[i]]; 300 double[] allClassCount = node.ClassCount; //所有类别的数量 301 for (int j = 0; j < temp.Length; j ) 302 { 303 temp[j] = new List<int>(); 304 } 305 for (int j = 0; j < nums.Count; j ) 306 { 307 int index = Convert.ToInt32(allData[nums[j]][i]); 308 temp[index - 1].Add(nums[j]); 309 countAllFeature[index - 1] ; 310 allCount[index - 1][Convert.ToInt32(allData[nums[j]][lieshu - 1]) - 1] ; 311 } 312 double allShang = 1; 313 int choose = 0; 314 315 double[][] jubuCount = new double[2][]; 316 for (int k = 0; k < allCount.Length; k ) 317 { 318 if (temp[k].Count == 0) 319 continue; 320 double JubuShang = 0; 321 double[][] tempCount = new double[2][]; 322 tempCount[0] = allCount[k]; 323 tempCount[1] = new double[allCount[0].Length]; 324 for (int j = 0; j < tempCount[1].Length; j ) 325 { 326 tempCount[1][j] = allClassCount[j] - allCount[k][j]; 327 } 328 JubuShang = JubuShang getGini(tempCount[0], countAllFeature[k]) * countAllFeature[k] / RowCount; 329 int nodecount = RowCount - countAllFeature[k]; 330 JubuShang = JubuShang getGini(tempCount[1], nodecount) * nodecount / RowCount; 331 if (JubuShang < allShang) 332 { 333 allShang = JubuShang; 334 jubuCount = tempCount; 335 choose = k; 336 } 337 } 338 if (allShang < jubuMax) 339 { 340 info.type = 0; 341 jubuMax = allShang; 342 info.class_Count = jubuCount; 343 info.temp[0] = temp[choose]; 344 info.temp[1] = new List<int>(); 345 info.features = new List<string>(); 346 info.features.Add((choose 1) ""); 347 info.features.Add(""); 348 for (int j = 0; j < temp.Length; j ) 349 { 350 if (j == choose) 351 continue; 352 for (int k = 0; k < temp[j].Count; k ) 353 { 354 info.temp[1].Add(temp[j][k]); 355 } 356 if (temp[j].Count != 0) 357 { 358 info.features[1] = info.features[1] (j 1) ","; 359 } 360 } 361 info.splitIndex = i; 362 } 363 } 364 #endregion 365 #region 连续变量 366 else 367 { 368 double[] leftCunt = new double[classCount]; 369 370 //做节点各个类别的数量 371 double[] rightCount = new double[classCount]; 372 373 //右节点各个类别的数量 374 double[] count1 = new double[classCount]; 375 376 //子集1的统计量 377 double[] count2 = new double 378 379 [node.ClassCount.Length]; //子集2的统计量 380 for (int j = 0; j < node.ClassCount.Length; 381 382 j ) 383 { 384 count2[j] = node.ClassCount[j]; 385 } 386 int all1 = 0; 387 388 //子集1的样本量 389 int all2 = nums.Count; 390 391 //子集2的样本量 392 double lastValue = 0; 393 394 //上一个记录的类别 395 double currentValue = 0; 396 397 //当前类别 398 double lastPoint = 0; 399 400 //上一个点的值 401 double currentPoint = 0; 402 403 //当前点的值 404 double[] values = new double[nums.Count]; 405 for (int j = 0; j < values.Length; j ) 406 { 407 values[j] = allData[nums[j]][i]; 408 } 409 QSort(values, nums, 0, nums.Count - 1); 410 double lianxuMax = 1; 411 412 //连续型属性的最大熵 413 #region 寻找最佳的分割点 414 for (int j = 0; j < nums.Count - 1; j ) 415 { 416 currentValue = allData[nums[j]][lieshu - 417 418 1]; 419 currentPoint = (allData[nums[j]][i]); 420 if (j == 0) 421 { 422 lastValue = currentValue; 423 lastPoint = currentPoint; 424 } 425 if (currentValue != lastValue && 426 427 currentPoint != lastPoint) 428 { 429 double shang1 = getGini(count1, 430 431 all1); 432 double shang2 = getGini(count2, 433 434 all2); 435 double allShang = shang1 * all1 / 436 437 (all1 all2) shang2 * all2 / (all1 all2); 438 //allShang = (totalShang - allShang); 439 if (lianxuMax > allShang) 440 { 441 lianxuMax = allShang; 442 for (int k = 0; k < 443 444 count1.Length; k ) 445 { 446 leftCunt[k] = count1[k]; 447 rightCount[k] = count2[k]; 448 } 449 splitPoint = j; 450 splitValue = (currentPoint 451 452 lastPoint) / 2; 453 } 454 } 455 all1 ; 456 count1[Convert.ToInt32(currentValue) - 457 458 1] ; 459 count2[Convert.ToInt32(currentValue) - 460 461 1]--; 462 all2--; 463 lastValue = currentValue; 464 lastPoint = currentPoint; 465 } 466 #endregion 467 #region 如果超过了局部值,重设 468 if (lianxuMax < jubuMax) 469 { 470 info.type = 1; 471 info.splitIndex = i; 472 info.features=new List<string>() 473 474 {splitValue ""}; 475 //finalPoint = splitPoint; 476 jubuMax = lianxuMax; 477 info.temp[0] = new List<int>(); 478 info.temp[1] = new List<int>(); 479 for (int k = 0; k < splitPoint; k ) 480 { 481 info.temp[0].Add(nums[k]); 482 } 483 for (int k = splitPoint; k < nums.Count; 484 485 k ) 486 { 487 info.temp[1].Add(nums[k]); 488 } 489 info.class_Count[0] = new double 490 491 [leftCunt.Length]; 492 info.class_Count[1] = new double 493 494 [leftCunt.Length]; 495 for (int k = 0; k < leftCunt.Length; k ) 496 { 497 info.class_Count[0][k] = leftCunt[k]; 498 info.class_Count[1][k] = rightCount 499 500 [k]; 501 } 502 } 503 #endregion 504 } 505 #endregion 506 } 507 #region 没有寻找到最佳的分裂点,则设置为叶节点 508 if (info.splitIndex == -1) 509 { 510 double[] finalCount = node.ClassCount; 511 double max = finalCount[0]; 512 int result = 1; 513 for (int i = 1; i < finalCount.Length; i ) 514 { 515 if (finalCount[i] > max) 516 { 517 max = finalCount[i]; 518 result = (i 1); 519 } 520 } 521 node.feature_Type="result"; 522 node.features=new List<String> { "" result }; 523 return node; 524 } 525 #endregion 526 #region 分裂 527 int deep = node.deep; 528 node.SplitFeature = ("" info.splitIndex); 529 List<Node> childNode = new List<Node>(); 530 int[][] used = new int[2][]; 531 used[0] = new int[isUsed.Length]; 532 used[1] = new int[isUsed.Length]; 533 for (int i = 0; i < isUsed.Length; i ) 534 { 535 used[0][i] = isUsed[i]; 536 used[1][i] = isUsed[i]; 537 } 538 if (info.type == 0) 539 { 540 used[0][info.splitIndex] = 1; 541 node.feature_Type = ("离散"); 542 } 543 else 544 { 545 //used[info.splitIndex] = 0; 546 node.feature_Type = ("连续"); 547 } 548 List<int>[] rowIndex = info.temp; 549 List<String> features = info.features; 550 Node node1 = new Node(); 551 Node node2 = new Node(); 552 node1.setClassCount(info.class_Count[0]); 553 node2.setClassCount(info.class_Count[1]); 554 node1.rowCount = info.temp[0].Count; 555 node2.rowCount = info.temp[1].Count; 556 node1.deep = deep 1; 557 node2.deep = deep 1; 558 node1 = findBestSplit(node1, info.temp[0],used[0]); 559 node2 = findBestSplit(node2, info.temp[1], used[1]); 560 node.leafNode_Count = (node1.leafNode_Count 561 562 node2.leafNode_Count); 563 node.leafWrong = (node1.leafWrong node2.leafWrong); 564 node.features = (features); 565 childNode.Add(node1); 566 childNode.Add(node2); 567 node.childNodes = childNode; 568 #endregion 569 return node; 570 } 571 catch (Exception e) 572 { 573 Console.WriteLine(e.StackTrace); 574 return node; 575 } 576 } 577 /// <summary> 578 /// GINI值 579 /// </summary> 580 /// <param name="counts"></param> 581 /// <param name="countAll"></param> 582 /// <returns></returns> 583 public static double getGini(double[] counts, int countAll) 584 { 585 double Gini = 1; 586 for (int i = 0; i < counts.Length; i ) 587 { 588 Gini = Gini - Math.Pow(counts[i] / countAll, 2); 589 } 590 return Gini; 591 } 592 #region CCP剪枝 593 public static void getSeries(Node node) 594 { 595 Stack<Node> nodeStack = new Stack<Node>(); 596 if (node != null) 597 { 598 nodeStack.Push(node); 599 } 600 if (node.feature_Type == "result") 601 return; 602 List<Node> childs = node.childNodes; 603 for (int i = 0; i < childs.Count; i ) 604 { 605 getSeries(node); 606 } 607 } 608 /// <summary> 609 /// 遍历剪枝 610 /// </summary> 611 /// <param name="node"></param> 612 public static Node getNode1(Node node, Node nodeCut) 613 { 614 615 //List<Node> childNodes = node.getChild(); 616 //double min = 100000; 617 ////Node nodeCut = new Node(); 618 //double temp = 0; 619 //for (int i = 0; i < childNodes.Count; i ) 620 //{ 621 // if (childNodes[i].getType() != "result") 622 // { 623 // //if (!cutTree(childNodes[i])) 624 // temp = min; 625 // min = cutTree(childNodes[i], min); 626 // if (min < temp) 627 // nodeCut = childNodes[i]; 628 // getNode1(childNodes[i], nodeCut); 629 // } 630 //} 631 //node.setChildNode(childNodes); 632 return null; 633 } 634 /// <summary> 635 /// 对每一个节点剪枝 636 /// </summary> 637 public static double cutTree(Node node, double minA) 638 { 639 int rowCount = node.rowCount; 640 double leaf = node.getErrorCount(); 641 double[] values = getError1(node, 0, 0); 642 double treeWrong = values[0]; 643 double son = values[1]; 644 double rate = (leaf - treeWrong) / (son - 1); 645 if (minA > rate) 646 minA = rate; 647 //double var = Math.Sqrt(treeWrong * (1 - treeWrong / 648 649 rowCount)); 650 //double panbie = treeWrong var - leaf; 651 //if (panbie > 0) 652 //{ 653 // node.setFeatureType("result"); 654 // node.setChildNode(null); 655 // int result = (node.getResult() 1); 656 // node.setFeatures(new List<String>() { "" result 657 658 }); 659 // //return true; 660 //} 661 return minA; 662 } 663 /// <summary> 664 /// 获得子树的错误个数 665 /// </summary> 666 /// <param name="node"></param> 667 /// <returns></returns> 668 public static double[] getError1(Node node, double treeError, 669 670 double son) 671 { 672 if (node.feature_Type == "result") 673 { 674 675 double error = node.getErrorCount(); 676 son ; 677 return new double[] { treeError error, son }; 678 } 679 List<Node> childNode = node.childNodes; 680 for (int i = 0; i < childNode.Count; i ) 681 { 682 double[] values = getError1(childNode[i], treeError, 683 684 son); 685 treeError = values[0]; 686 son = values[1]; 687 } 688 return new double[] { treeError, son }; 689 } 690 #endregion
*/
CART核心代码
总结:
(1)CART是一棵二叉树,每一次分裂会产生两个子节点,对于连续性的数据,直接采用与C4.5相似的处理方法,对于离散型数据,选择最优的两种离散值组合方法。
(2)CART既能是分类数,又能是二叉树。如果是分类树,将选择能够最小化分裂后节点GINI值的分裂属性;如果是回归树,选择能够最小化两个节点样本方差的分裂属性。
(3)CART跟C4.5一样,需要进行剪枝,采用CCP(代价复杂度的剪枝方法)。
发布者:全栈程序员栈长,转载请注明出处:https://javaforall.cn/166831.html原文链接:https://javaforall.cn