请用决策树_cart决策树使用什么来选择划分属性

2022-09-16 11:44:15 浏览数 (1)

CART,又名分类回归树,是在ID3的基础上进行优化的决策树,学习CART记住以下几个关键点:

(1)CART既能是分类树,又能是分类树;

(2)当CART是分类树时,采用GINI值作为节点分裂的依据;当CART是回归树时,采用样本的最小方差作为节点分裂的依据;

(3)CART是一棵二叉树。

接下来将以一个实际的例子对CART进行介绍:

            表1 原始数据表

看电视时间

婚姻情况

职业

年龄

3

未婚

学生

12

4

未婚

学生

18

2

已婚

老师

26

5

已婚

上班族

47

2.5

已婚

上班族

36

3.5

未婚

老师

29

4

已婚

学生

21

从以下的思路理解CART:

分类树?回归树?

分类树的作用是通过一个对象的特征来预测该对象所属的类别,而回归树的目的是根据一个对象的信息预测该对象的属性,并以数值表示。

CART既能是分类树,又能是决策树,如上表所示,如果我们想预测一个人是否已婚,那么构建的CART将是分类树;如果想预测一个人的年龄,那么构建的将是回归树。

分类树和回归树是怎么做决策的?假设我们构建了两棵决策树分别预测用户是否已婚和实际的年龄,如图1和图2所示:

              图1 预测婚姻情况决策树 图2 预测年龄的决策树

图1表示一棵分类树,其叶子节点的输出结果为一个实际的类别,在这个例子里是婚姻的情况(已婚或者未婚),选择叶子节点中数量占比最大的类别作为输出的类别;

图2是一棵回归树,预测用户的实际年龄,是一个具体的输出值。怎样得到这个输出值?一般情况下选择使用中值、平均值或者众数进行表示,图2使用节点年龄数据的平均值作为输出值。

CART如何选择分裂的属性?

分裂的目的是为了能够让数据变纯,使决策树输出的结果更接近真实值。那么CART是如何评价节点的纯度呢?如果是分类树,CART采用GINI值衡量节点纯度;如果是回归树,采用样本方差衡量节点纯度。节点越不纯,节点分类或者预测的效果就越差。

GINI值的计算公式:

节点越不纯,GINI值越大。以二分类为例,如果节点的所有数据只有一个类别,则

,如果两类数量相同,则

回归方差计算公式:

方差越大,表示该节点的数据越分散,预测的效果就越差。如果一个节点的所有数据都相同,那么方差就为0,此时可以很肯定得认为该节点的输出值;如果节点的数据相差很大,那么输出的值有很大的可能与实际值相差较大。

因此,无论是分类树还是回归树,CART都要选择使子节点的GINI值或者回归方差最小的属性作为分裂的方案。即最小化(分类树):

或者(回归树):

CART如何分裂成一棵二叉树?

节点的分裂分为两种情况,连续型的数据和离散型的数据。

CART对连续型属性的处理与C4.5差不多,通过最小化分裂后的GINI值或者样本方差寻找最优分割点,将节点一分为二,在这里不再叙述,详细请看C4.5

对于离散型属性,理论上有多少个离散值就应该分裂成多少个节点。但CART是一棵二叉树,每一次分裂只会产生两个节点,怎么办呢?很简单,只要将其中一个离散值独立作为一个节点,其他的离散值生成另外一个节点即可。这种分裂方案有多少个离散值就有多少种划分的方法,举一个简单的例子:如果某离散属性一个有三个离散值X,Y,Z,则该属性的分裂方法有{X}、{Y,Z},{Y}、{X,Z},{Z}、{X,Y},分别计算每种划分方法的基尼值或者样本方差确定最优的方法。

以属性“职业”为例,一共有三个离散值,“学生”、“老师”、“上班族”。该属性有三种划分的方案,分别为{“学生”}、{“老师”、“上班族”},{“老师”}、{“学生”、“上班族”},{“上班族”}、{“学生”、“老师”},分别计算三种划分方案的子节点GINI值或者样本方差,选择最优的划分方法,如下图所示:

第一种划分方法:{“学生”}、{“老师”、“上班族”}

预测是否已婚(分类):

预测年龄(回归):

第二种划分方法:{“老师”}、{“学生”、“上班族”}

预测是否已婚(分类):

预测年龄(回归):

第三种划分方法:{“上班族”}、{“学生”、“老师”}

预测是否已婚(分类):

预测年龄(回归):

综上,如果想预测是否已婚,则选择{“上班族”}、{“学生”、“老师”}的划分方法,如果想预测年龄,则选择{“老师”}、{“学生”、“上班族”}的划分方法。

如何剪枝?

CART采用CCP(代价复杂度)剪枝方法。代价复杂度选择节点表面误差率增益值最小的非叶子节点,删除该非叶子节点的左右子节点,若有多个非叶子节点的表面误差率增益值相同小,则选择非叶子节点中子节点数最多的非叶子节点进行剪枝。

可描述如下:

令决策树的非叶子节点为

a)计算所有非叶子节点的表面误差率增益值

b)选择表面误差率增益值

最小的非叶子节点

(若多个非叶子节点具有相同小的表面误差率增益值,选择节点数最多的非叶子节点)。

c)对

进行剪枝

表面误差率增益值的计算公式:

其中:

表示叶子节点的误差代价,

为节点的错误率,

为节点数据量的占比;

表示子树的误差代价,

为子节点i的错误率,

表示节点i的数据节点占比;

表示子树节点个数。

算例:

下图是其中一颗子树,设决策树的总数据量为40。

该子树的表面误差率增益值可以计算如下:

求出该子树的表面错误覆盖率为 ,只要求出其他子树的表面误差率增益值就可以对决策树进行剪枝。

程序实际以及源代码

流程图:

(1)数据处理

对原始的数据进行数字化处理,并以二维数据的形式存储,每一行表示一条记录,前n-1列表示属性,最后一列表示分类的标签。

如表1的数据可以转化为表2:

表2 初始化后的数据

看电视时间

婚姻情况

职业

年龄

3

未婚

学生

12

4

未婚

学生

18

2

已婚

老师

26

5

已婚

上班族

47

2.5

已婚

上班族

36

3.5

未婚

老师

29

4

已婚

学生

21

其中,对于“婚姻情况”属性,数字{1,2}分别表示{未婚,已婚 };对于“职业”属性{1,2,3, }分别表示{学生、老师、上班族};

代码如下所示:

static double[][] allData; //存储进行训练的数据

    static List<String>[] featureValues; //离散属性对应的离散值

featureValues是链表数组,数组的长度为属性的个数,数组的每个元素为该属性的离散值链表。

(2)两个类:节点类和分裂信息

a)节点类Node

该类表示一个节点,属性包括节点选择的分裂属性、节点的输出类、孩子节点、深度等。注意,与ID3中相比,新增了两个属性:leafWrong和leafNode_Count分别表示叶子节点的总分类误差和叶子节点的个数,主要是为了方便剪枝。

代码语言:javascript复制
 1 class Node  2 {  3 /// <summary>  4 /// 每一个节点的分裂值  5 /// </summary>  6 public List<String> features { get; set; }  7 /// <summary>  8 /// 分裂属性的类型{离散、连续}  9 /// </summary> 10 public String feature_Type { get; set; } 11 /// <summary> 12 /// 分裂属性的下标 13 /// </summary> 14 public String SplitFeature { get; set; } 15 //List<int> nums = new List<int>(); //行序号 16 /// <summary> 17 /// 每一个类对应的数目 18 /// </summary> 19 public double[] ClassCount { get; set; } 20 //int[] isUsed = new int[0]; //属性的使用情况 1:已用 2:未用 21 /// <summary> 22 /// 孩子节点 23 /// </summary> 24 public List<Node> childNodes { get; set; } 25 Node Parent = null; 26 /// <summary> 27 /// 该节点占比最大的类别 28 /// </summary> 29 public String finalResult { get; set; } 30 /// <summary> 31 /// 树的深度 32 /// </summary> 33 public int deep { get; set; } 34 /// <summary> 35 /// 最大的类下标 36 /// </summary> 37 public int result { get; set; } 38 /// <summary> 39 /// 子节点误差 40 /// </summary> 41 public int leafWrong { get; set; } 42 /// <summary> 43 /// 子节点数目 44 /// </summary> 45 public int leafNode_Count { get; set; } 46 /// <summary> 47 /// 数据量 48 /// </summary> 49 public int rowCount { get; set; } 50 51 public void setClassCount(double[] count) 52  { 53 this.ClassCount = count; 54 double max = ClassCount[0]; 55 int result = 0; 56 for (int i = 1; i < ClassCount.Length; i  ) 57  { 58 if (max < ClassCount[i]) 59  { 60 max = ClassCount[i]; 61 result = i; 62  } 63  } 64 this.result = result; 65  } 66 public double getErrorCount() 67  { 68 return rowCount - ClassCount[result]; 69  } 70 }

树的节点

b)分裂信息类,该类存储节点进行分裂的信息,包括各个子节点的行坐标、子节点各个类的数目、该节点分裂的属性、属性的类型等。

代码语言:javascript复制
 1 class SplitInfo  2  {  3 /// <summary>  4 /// 分裂的属性下标  5 /// </summary>  6 public int splitIndex { get; set; }  7 /// <summary>  8 /// 数据类型  9 /// </summary> 10 public int type { get; set; } 11 /// <summary> 12 /// 分裂属性的取值 13 /// </summary> 14 public List<String> features { get; set; } 15 /// <summary> 16 /// 各个节点的行坐标链表 17 /// </summary> 18 public List<int>[] temp { get; set; } 19 /// <summary> 20 /// 每个节点各类的数目 21 /// </summary> 22 public double[][] class_Count { get; set; } 23 }

分裂信息

主方法findBestSplit(Node node,List<int> nums,int[] isUsed),该方法对节点进行分裂

其中:

node表示即将进行分裂的节点;

nums表示节点数据对一个的行坐标列表;

isUsed表示到该节点位置所有属性的使用情况;

findBestSplit的这个方法主要有以下几个组成部分:

1)节点分裂停止的判定

节点分裂条件如上文所述,源代码如下:

代码语言:javascript复制
 1 public static bool ifEnd(Node node, double shang,int[] isUsed)  2  {  3 try  4  {  5 double[] count = node.ClassCount;  6 int rowCount = node.rowCount;  7 int maxResult = 0;  8 double maxRate = 0;  9 #region 数达到某一深度  10 int deep = node.deep;  11 if (deep >= 10)  12  {  13 maxResult = node.result   1;  14 node.feature_Type="result";  15 node.features=new List<String>() { maxResult   ""  16  17 };  18 node.leafWrong=rowCount - Convert.ToInt32(count[maxResult-1]);  19 node.leafNode_Count=1;  20 return true;  21  }  22 #endregion  23 #region 纯度(其实跟后面的有点重了,记得要修改)  24 //maxResult = 1;  25 //for (int i = 1; i < count.Length; i  )  26 //{  27 // if (count[i] / rowCount >= 0.95)  28 // {  29 // node.feature_Type="result";  30 // node.features=new List<String> { ""   (i     31  32 1) };  33 // node.leafNode_Count=1;  34 // node.leafWrong=rowCount - Convert.ToInt32  35  36 (count[i]);  37 // return true;  38 // }  39 //}  40 #endregion  41 #region 熵为0  42 if (shang == 0)  43  {  44 maxRate = count[0] / rowCount;  45 maxResult = 1;  46 for (int i = 1; i < count.Length; i  )  47  {  48 if (count[i] / rowCount >= maxRate)  49  {  50 maxRate = count[i] / rowCount;  51 maxResult = i   1;  52  }  53  }  54 node.feature_Type="result";  55 node.features=new List<String> { maxResult   ""  56  57 };  58 node.leafWrong=rowCount - Convert.ToInt32(count  59  60 [maxResult - 1]);  61 node.leafNode_Count=1;  62 return true;  63  }  64 #endregion  65 #region 属性已经分完  66 //int[] isUsed = node.getUsed();  67 bool flag = true;  68 for (int i = 0; i < isUsed.Length - 1; i  )  69  {  70 if (isUsed[i] == 0)  71  {  72 flag = false;  73 break;  74  }  75  }  76 if (flag)  77  {  78 maxRate = count[0] / rowCount;  79 maxResult = 1;  80 for (int i = 1; i < count.Length; i  )  81  {  82 if (count[i] / rowCount >= maxRate)  83  {  84 maxRate = count[i] / rowCount;  85 maxResult = i   1;  86  }  87  }  88 node.feature_Type=("result");  89 node.features=(new List<String> { ""    90  91 (maxResult) });  92 node.leafWrong=(rowCount - Convert.ToInt32(count  93  94 [maxResult - 1]));  95 node.leafNode_Count=(1);  96 return true;  97  }  98 #endregion  99 #region 几点数少于100 100 if (rowCount < Limit_Node) 101  { 102 maxRate = count[0] / rowCount; 103 maxResult = 1; 104 for (int i = 1; i < count.Length; i  ) 105  { 106 if (count[i] / rowCount >= maxRate) 107  { 108 maxRate = count[i] / rowCount; 109 maxResult = i   1; 110  } 111  } 112 node.feature_Type="result"; 113 node.features=new List<String> { ""   (maxResult) 114 115 }; 116 node.leafWrong=rowCount - Convert.ToInt32(count 117 118 [maxResult - 1]); 119 node.leafNode_Count=1; 120 return true; 121  } 122 #endregion 123 return false; 124  } 125 catch (Exception e) 126  { 127 return false; 128  } 129 }

停止分裂的条件

2)寻找最优的分裂属性

寻找最优的分裂属性需要计算每一个分裂属性分裂后的GINI值或者样本方差,计算公式上文已给出,其中GINI值的计算代码如下:

代码语言:javascript复制
1 public static double getGini(double[] counts, int countAll) 2  { 3 double Gini = 1; 4 for (int i = 0; i < counts.Length; i  ) 5  { 6 Gini = Gini - Math.Pow(counts[i] / countAll, 2); 7  } 8 return Gini; 9 }

GINI值计算

3)进行分裂,同时对子节点进行迭代处理

其实就是递归的过程,对每一个子节点执行findBestSplit方法进行分裂。

findBestSplit源代码:

代码语言:javascript复制
/*
* 提示:该行代码过长,系统自动注释不进行高亮。一键复制会移除系统注释 
* 1 public static Node findBestSplit(Node node,List<int> nums,int[] isUsed)  2  {  3 try  4  {  5 //判断是否继续分裂  6 double totalShang = getGini(node.ClassCount, node.rowCount);  7 if (ifEnd(node, totalShang, isUsed))  8  {  9 return node;  10  }  11 #region 变量声明  12 SplitInfo info = new SplitInfo();  13  info.initial();  14 int RowCount = nums.Count; //样本总数  15 double jubuMax = 1; //局部最大熵  16 int splitPoint = 0; //分裂的点  17 double splitValue = 0; //分裂的值  18 #endregion  19 for (int i = 0; i < isUsed.Length - 1; i  )  20  {  21 if (isUsed[i] == 1)  22  {  23 continue;  24  }  25 #region 离散变量  26 if (type[i] == 0)  27  {  28 double[][] allCount = new double[allNum[i]][];  29 for (int j = 0; j < allCount.Length; j  )  30  {  31 allCount[j] = new double[classCount];  32  }  33 int[] countAllFeature = new int[allNum[i]];  34 List<int>[] temp = new List<int>[allNum[i]];  35 double[] allClassCount = node.ClassCount; //所有类别的数量  36 for (int j = 0; j < temp.Length; j  )  37  {  38 temp[j] = new List<int>();  39  }  40 for (int j = 0; j < nums.Count; j  )  41  {  42 int index = Convert.ToInt32(allData[nums[j]][i]);  43 temp[index - 1].Add(nums[j]);  44 countAllFeature[index - 1]  ;  45 allCount[index - 1][Convert.ToInt32(allData[nums[j]][lieshu - 1]) - 1]  ;  46  }  47 double allShang = 1;  48 int choose = 0;  49  50 double[][] jubuCount = new double[2][];  51 for (int k = 0; k < allCount.Length; k  )  52  {  53 if (temp[k].Count == 0)  54 continue;  55 double JubuShang = 0;  56 double[][] tempCount = new double[2][];  57 tempCount[0] = allCount[k];  58 tempCount[1] = new double[allCount[0].Length];  59 for (int j = 0; j < tempCount[1].Length; j  )  60  {  61 tempCount[1][j] = allClassCount[j] - allCount[k][j];  62  }  63 JubuShang = JubuShang   getGini(tempCount[0], countAllFeature[k]) * countAllFeature[k] / RowCount;  64 int nodecount = RowCount - countAllFeature[k];  65 JubuShang = JubuShang   getGini(tempCount[1], nodecount) * nodecount / RowCount;  66 if (JubuShang < allShang)  67  {  68 allShang = JubuShang;  69 jubuCount = tempCount;  70 choose = k;  71  }  72  }  73 if (allShang < jubuMax)  74  {  75 info.type = 0;  76 jubuMax = allShang;  77 info.class_Count = jubuCount;  78 info.temp[0] = temp[choose];  79 info.temp[1] = new List<int>();  80 info.features = new List<string>();  81 info.features.Add((choose   1)   "");  82 info.features.Add("");  83 for (int j = 0; j < temp.Length; j  )  84  {  85 if (j == choose)  86 continue;  87 for (int k = 0; k < temp[j].Count; k  )  88  {  89 info.temp[1].Add(temp[j][k]);  90  }  91 if (temp[j].Count != 0)  92  {  93 info.features[1] = info.features[1]   (j   1)   ",";  94  }  95  }  96 info.splitIndex = i;  97  }  98  }  99 #endregion 100 #region 连续变量 101 else 102  { 103 double[] leftCunt = new double[classCount]; 104 105 //做节点各个类别的数量 106 double[] rightCount = new double[classCount]; 107 108 //右节点各个类别的数量 109 double[] count1 = new double[classCount]; 110 111 //子集1的统计量 112 double[] count2 = new double 113 114 [node.ClassCount.Length]; //子集2的统计量 115 for (int j = 0; j < node.ClassCount.Length; 116 117 j  ) 118  { 119 count2[j] = node.ClassCount[j]; 120  } 121 int all1 = 0; 122 123 //子集1的样本量 124 int all2 = nums.Count; 125 126 //子集2的样本量 127 double lastValue = 0; 128 129 //上一个记录的类别 130 double currentValue = 0; 131 132 //当前类别 133 double lastPoint = 0; 134 135 //上一个点的值 136 double currentPoint = 0; 137 138 //当前点的值 139 double[] values = new double[nums.Count]; 140 for (int j = 0; j < values.Length; j  ) 141  { 142 values[j] = allData[nums[j]][i]; 143  } 144 QSort(values, nums, 0, nums.Count - 1); 145 double lianxuMax = 1; 146 147 //连续型属性的最大熵 148 #region 寻找最佳的分割点 149 for (int j = 0; j < nums.Count - 1; j  ) 150  { 151 currentValue = allData[nums[j]][lieshu - 152 153 1]; 154 currentPoint = (allData[nums[j]][i]); 155 if (j == 0) 156  { 157 lastValue = currentValue; 158 lastPoint = currentPoint; 159  } 160 if (currentValue != lastValue && 161 162 currentPoint != lastPoint) 163  { 164 double shang1 = getGini(count1, 165 166 all1); 167 double shang2 = getGini(count2, 168 169 all2); 170 double allShang = shang1 * all1 / 171 172 (all1   all2)   shang2 * all2 / (all1   all2); 173 //allShang = (totalShang - allShang); 174 if (lianxuMax > allShang) 175  { 176 lianxuMax = allShang; 177 for (int k = 0; k < 178 179 count1.Length; k  ) 180  { 181 leftCunt[k] = count1[k]; 182 rightCount[k] = count2[k]; 183  } 184 splitPoint = j; 185 splitValue = (currentPoint   186 187 lastPoint) / 2; 188  } 189  } 190 all1  ; 191 count1[Convert.ToInt32(currentValue) - 192 193 1]  ; 194 count2[Convert.ToInt32(currentValue) - 195 196 1]--; 197 all2--; 198 lastValue = currentValue; 199 lastPoint = currentPoint; 200  } 201 #endregion 202 #region 如果超过了局部值,重设 203 if (lianxuMax < jubuMax) 204  { 205 info.type = 1; 206 info.splitIndex = i; 207 info.features=new List<string>() 208 209 {splitValue ""}; 210 //finalPoint = splitPoint; 211 jubuMax = lianxuMax; 212 info.temp[0] = new List<int>(); 213 info.temp[1] = new List<int>(); 214 for (int k = 0; k < splitPoint; k  ) 215  { 216 info.temp[0].Add(nums[k]); 217  } 218 for (int k = splitPoint; k < nums.Count; 219 220 k  ) 221  { 222 info.temp[1].Add(nums[k]); 223  } 224 info.class_Count[0] = new double 225 226 [leftCunt.Length]; 227 info.class_Count[1] = new double 228 229 [leftCunt.Length]; 230 for (int k = 0; k < leftCunt.Length; k  ) 231  { 232 info.class_Count[0][k] = leftCunt[k]; 233 info.class_Count[1][k] = rightCount 234 235 [k]; 236  } 237  } 238 #endregion 239  } 240 #endregion 241  } 242 #region 没有寻找到最佳的分裂点,则设置为叶节点 243 if (info.splitIndex == -1) 244  { 245 double[] finalCount = node.ClassCount; 246 double max = finalCount[0]; 247 int result = 1; 248 for (int i = 1; i < finalCount.Length; i  ) 249  { 250 if (finalCount[i] > max) 251  { 252 max = finalCount[i]; 253 result = (i   1); 254  } 255  } 256 node.feature_Type="result"; 257 node.features=new List<String> { ""   result }; 258 return node; 259  } 260 #endregion 261 #region 分裂 262 int deep = node.deep; 263 node.SplitFeature = (""   info.splitIndex); 264 List<Node> childNode = new List<Node>(); 265 int[][] used = new int[2][]; 266 used[0] = new int[isUsed.Length]; 267 used[1] = new int[isUsed.Length]; 268 for (int i = 0; i < isUsed.Length; i  ) 269  { 270 used[0][i] = isUsed[i]; 271 used[1][i] = isUsed[i]; 272  } 273 if (info.type == 0) 274  { 275 used[0][info.splitIndex] = 1; 276 node.feature_Type = ("离散"); 277  } 278 else 279  { 280 //used[info.splitIndex] = 0; 281 node.feature_Type = ("连续"); 282  } 283 List<int>[] rowIndex = info.temp; 284 List<String> features = info.features; 285 Node node1 = new Node(); 286 Node node2 = new Node(); 287 node1.setClassCount(info.class_Count[0]); 288 node2.setClassCount(info.class_Count[1]); 289 node1.rowCount = info.temp[0].Count; 290 node2.rowCount = info.temp[1].Count; 291 node1.deep = deep   1; 292 node2.deep = deep   1; 293 node1 = findBestSplit(node1, info.temp[0],used[0]); 294 node2 = findBestSplit(node2, info.temp[1], used[1]); 295 node.leafNode_Count = (node1.leafNode_Count 296 297  node2.leafNode_Count); 298 node.leafWrong = (node1.leafWrong node2.leafWrong); 299 node.features = (features); 300  childNode.Add(node1); 301  childNode.Add(node2); 302 node.childNodes = childNode; 303 #endregion 304 return node; 305  } 306 catch (Exception e) 307  { 308  Console.WriteLine(e.StackTrace); 309 return node; 310  } 311 }
*/

节点选择属性和分裂

(4)剪枝

代价复杂度剪枝方法(CCP):

代码语言:javascript复制
 1 public static void getSeries(Node node)  2  {  3 Stack<Node> nodeStack = new Stack<Node>();  4 if (node != null)  5  {  6  nodeStack.Push(node);  7  }  8 if (node.feature_Type == "result")  9 return; 10 List<Node> childs = node.childNodes; 11 for (int i = 0; i < childs.Count; i  ) 12  { 13  getSeries(node); 14  } 15 }

CCP代价复杂度剪枝

CART全部核心代码:

代码语言:javascript复制
/*
* 提示:该行代码过长,系统自动注释不进行高亮。一键复制会移除系统注释 
* 1 /// <summary>  2 /// 判断是否还需要分裂  3 /// </summary>  4 /// <param name="node"></param>  5 /// <returns></returns>  6 public static bool ifEnd(Node node, double shang,int[] isUsed)  7  {  8 try  9  {  10 double[] count = node.ClassCount;  11 int rowCount = node.rowCount;  12 int maxResult = 0;  13 double maxRate = 0;  14 #region 数达到某一深度  15 int deep = node.deep;  16 if (deep >= 10)  17  {  18 maxResult = node.result   1;  19 node.feature_Type="result";  20 node.features=new List<String>() { maxResult   ""  21  22 };  23 node.leafWrong=rowCount - Convert.ToInt32(count[maxResult-1]);  24 node.leafNode_Count=1;  25 return true;  26  }  27 #endregion  28 #region 纯度(其实跟后面的有点重了,记得要修改)  29 //maxResult = 1;  30 //for (int i = 1; i < count.Length; i  )  31 //{  32 // if (count[i] / rowCount >= 0.95)  33 // {  34 // node.feature_Type="result";  35 // node.features=new List<String> { ""   (i     36  37 1) };  38 // node.leafNode_Count=1;  39 // node.leafWrong=rowCount - Convert.ToInt32  40  41 (count[i]);  42 // return true;  43 // }  44 //}  45 #endregion  46 #region 熵为0  47 if (shang == 0)  48  {  49 maxRate = count[0] / rowCount;  50 maxResult = 1;  51 for (int i = 1; i < count.Length; i  )  52  {  53 if (count[i] / rowCount >= maxRate)  54  {  55 maxRate = count[i] / rowCount;  56 maxResult = i   1;  57  }  58  }  59 node.feature_Type="result";  60 node.features=new List<String> { maxResult   ""  61  62 };  63 node.leafWrong=rowCount - Convert.ToInt32(count  64  65 [maxResult - 1]);  66 node.leafNode_Count=1;  67 return true;  68  }  69 #endregion  70 #region 属性已经分完  71 //int[] isUsed = node.getUsed();  72 bool flag = true;  73 for (int i = 0; i < isUsed.Length - 1; i  )  74  {  75 if (isUsed[i] == 0)  76  {  77 flag = false;  78 break;  79  }  80  }  81 if (flag)  82  {  83 maxRate = count[0] / rowCount;  84 maxResult = 1;  85 for (int i = 1; i < count.Length; i  )  86  {  87 if (count[i] / rowCount >= maxRate)  88  {  89 maxRate = count[i] / rowCount;  90 maxResult = i   1;  91  }  92  }  93 node.feature_Type=("result");  94 node.features=(new List<String> { ""    95  96 (maxResult) });  97 node.leafWrong=(rowCount - Convert.ToInt32(count  98  99 [maxResult - 1])); 100 node.leafNode_Count=(1); 101 return true; 102  } 103 #endregion 104 #region 几点数少于100 105 if (rowCount < Limit_Node) 106  { 107 maxRate = count[0] / rowCount; 108 maxResult = 1; 109 for (int i = 1; i < count.Length; i  ) 110  { 111 if (count[i] / rowCount >= maxRate) 112  { 113 maxRate = count[i] / rowCount; 114 maxResult = i   1; 115  } 116  } 117 node.feature_Type="result"; 118 node.features=new List<String> { ""   (maxResult) 119 120 }; 121 node.leafWrong=rowCount - Convert.ToInt32(count 122 123 [maxResult - 1]); 124 node.leafNode_Count=1; 125 return true; 126  } 127 #endregion 128 return false; 129  } 130 catch (Exception e) 131  { 132 return false; 133  } 134  } 135 #region 排序算法 136 public static void InsertSort(double[] values, List<int> arr, 137 138 int StartIndex, int endIndex) 139  { 140 for (int i = StartIndex   1; i <= endIndex; i  ) 141  { 142 int key = arr[i]; 143 double init = values[i]; 144 int j = i - 1; 145 while (j >= StartIndex && values[j] > init) 146  { 147 arr[j   1] = arr[j]; 148 values[j   1] = values[j]; 149 j--; 150  } 151 arr[j   1] = key; 152 values[j   1] = init; 153  } 154  } 155 static int SelectPivotMedianOfThree(double[] values, List<int> arr, int low, int high) 156  { 157 int mid = low   ((high - low) >> 1);//计算数组中间的元素的下标 158 159 //使用三数取中法选择枢轴  160 if (values[mid] > values[high])//目标: arr[mid] <= arr[high]  161  { 162  swap(values, arr, mid, high); 163  } 164 if (values[low] > values[high])//目标: arr[low] <= arr[high]  165  { 166  swap(values, arr, low, high); 167  } 168 if (values[mid] > values[low]) //目标: arr[low] >= arr[mid]  169  { 170  swap(values, arr, mid, low); 171  } 172 //此时,arr[mid] <= arr[low] <= arr[high]  173 return low; 174 //low的位置上保存这三个位置中间的值 175 //分割时可以直接使用low位置的元素作为枢轴,而不用改变分割函数了  176  } 177 static void swap(double[] values, List<int> arr, int t1, int t2) 178  { 179 double temp = values[t1]; 180 values[t1] = values[t2]; 181 values[t2] = temp; 182 int key = arr[t1]; 183 arr[t1] = arr[t2]; 184 arr[t2] = key; 185  } 186 static void QSort(double[] values, List<int> arr, int low, int high) 187  { 188 int first = low; 189 int last = high; 190 191 int left = low; 192 int right = high; 193 194 int leftLen = 0; 195 int rightLen = 0; 196 197 if (high - low   1 < 10) 198  { 199  InsertSort(values, arr, low, high); 200 return; 201  } 202 203 //一次分割  204 int key = SelectPivotMedianOfThree(values, arr, low, 205 206 high);//使用三数取中法选择枢轴  207 double inti = values[key]; 208 int currentKey = arr[key]; 209 210 while (low < high) 211  { 212 while (high > low && values[high] >= inti) 213  { 214 if (values[high] == inti)//处理相等元素  215  { 216  swap(values, arr, right, high); 217 right--; 218 rightLen  ; 219  } 220 high--; 221  } 222 arr[low] = arr[high]; 223 values[low] = values[high]; 224 while (high > low && values[low] <= inti) 225  { 226 if (values[low] == inti) 227  { 228  swap(values, arr, left, low); 229 left  ; 230 leftLen  ; 231  } 232 low  ; 233  } 234 arr[high] = arr[low]; 235 values[high] = values[low]; 236  } 237 arr[low] = currentKey; 238 values[low] = values[key]; 239 //一次快排结束 240 //把与枢轴key相同的元素移到枢轴最终位置周围  241 int i = low - 1; 242 int j = first; 243 while (j < left && values[i] != inti) 244  { 245  swap(values, arr, i, j); 246 i--; 247 j  ; 248  } 249 i = low   1; 250 j = last; 251 while (j > right && values[i] != inti) 252  { 253  swap(values, arr, i, j); 254 i  ; 255 j--; 256  } 257 QSort(values, arr, first, low - 1 - leftLen); 258 QSort(values, arr, low   1   rightLen, last); 259  } 260 #endregion 261 /// <summary> 262 /// 寻找最佳的分裂点 263 /// </summary> 264 /// <param name="num"></param> 265 /// <param name="node"></param> 266 public static Node findBestSplit(Node node,List<int> nums,int[] isUsed) 267  { 268 try 269  { 270 //判断是否继续分裂 271 double totalShang = getGini(node.ClassCount, node.rowCount); 272 if (ifEnd(node, totalShang, isUsed)) 273  { 274 return node; 275  } 276 #region 变量声明 277 SplitInfo info = new SplitInfo(); 278  info.initial(); 279 int RowCount = nums.Count; //样本总数 280 double jubuMax = 1; //局部最大熵 281 int splitPoint = 0; //分裂的点 282 double splitValue = 0; //分裂的值 283 #endregion 284 for (int i = 0; i < isUsed.Length - 1; i  ) 285  { 286 if (isUsed[i] == 1) 287  { 288 continue; 289  } 290 #region 离散变量 291 if (type[i] == 0) 292  { 293 double[][] allCount = new double[allNum[i]][]; 294 for (int j = 0; j < allCount.Length; j  ) 295  { 296 allCount[j] = new double[classCount]; 297  } 298 int[] countAllFeature = new int[allNum[i]]; 299 List<int>[] temp = new List<int>[allNum[i]]; 300 double[] allClassCount = node.ClassCount; //所有类别的数量 301 for (int j = 0; j < temp.Length; j  ) 302  { 303 temp[j] = new List<int>(); 304  } 305 for (int j = 0; j < nums.Count; j  ) 306  { 307 int index = Convert.ToInt32(allData[nums[j]][i]); 308 temp[index - 1].Add(nums[j]); 309 countAllFeature[index - 1]  ; 310 allCount[index - 1][Convert.ToInt32(allData[nums[j]][lieshu - 1]) - 1]  ; 311  } 312 double allShang = 1; 313 int choose = 0; 314 315 double[][] jubuCount = new double[2][]; 316 for (int k = 0; k < allCount.Length; k  ) 317  { 318 if (temp[k].Count == 0) 319 continue; 320 double JubuShang = 0; 321 double[][] tempCount = new double[2][]; 322 tempCount[0] = allCount[k]; 323 tempCount[1] = new double[allCount[0].Length]; 324 for (int j = 0; j < tempCount[1].Length; j  ) 325  { 326 tempCount[1][j] = allClassCount[j] - allCount[k][j]; 327  } 328 JubuShang = JubuShang   getGini(tempCount[0], countAllFeature[k]) * countAllFeature[k] / RowCount; 329 int nodecount = RowCount - countAllFeature[k]; 330 JubuShang = JubuShang   getGini(tempCount[1], nodecount) * nodecount / RowCount; 331 if (JubuShang < allShang) 332  { 333 allShang = JubuShang; 334 jubuCount = tempCount; 335 choose = k; 336  } 337  } 338 if (allShang < jubuMax) 339  { 340 info.type = 0; 341 jubuMax = allShang; 342 info.class_Count = jubuCount; 343 info.temp[0] = temp[choose]; 344 info.temp[1] = new List<int>(); 345 info.features = new List<string>(); 346 info.features.Add((choose   1)   ""); 347 info.features.Add(""); 348 for (int j = 0; j < temp.Length; j  ) 349  { 350 if (j == choose) 351 continue; 352 for (int k = 0; k < temp[j].Count; k  ) 353  { 354 info.temp[1].Add(temp[j][k]); 355  } 356 if (temp[j].Count != 0) 357  { 358 info.features[1] = info.features[1]   (j   1)   ","; 359  } 360  } 361 info.splitIndex = i; 362  } 363  } 364 #endregion 365 #region 连续变量 366 else 367  { 368 double[] leftCunt = new double[classCount]; 369 370 //做节点各个类别的数量 371 double[] rightCount = new double[classCount]; 372 373 //右节点各个类别的数量 374 double[] count1 = new double[classCount]; 375 376 //子集1的统计量 377 double[] count2 = new double 378 379 [node.ClassCount.Length]; //子集2的统计量 380 for (int j = 0; j < node.ClassCount.Length; 381 382 j  ) 383  { 384 count2[j] = node.ClassCount[j]; 385  } 386 int all1 = 0; 387 388 //子集1的样本量 389 int all2 = nums.Count; 390 391 //子集2的样本量 392 double lastValue = 0; 393 394 //上一个记录的类别 395 double currentValue = 0; 396 397 //当前类别 398 double lastPoint = 0; 399 400 //上一个点的值 401 double currentPoint = 0; 402 403 //当前点的值 404 double[] values = new double[nums.Count]; 405 for (int j = 0; j < values.Length; j  ) 406  { 407 values[j] = allData[nums[j]][i]; 408  } 409 QSort(values, nums, 0, nums.Count - 1); 410 double lianxuMax = 1; 411 412 //连续型属性的最大熵 413 #region 寻找最佳的分割点 414 for (int j = 0; j < nums.Count - 1; j  ) 415  { 416 currentValue = allData[nums[j]][lieshu - 417 418 1]; 419 currentPoint = (allData[nums[j]][i]); 420 if (j == 0) 421  { 422 lastValue = currentValue; 423 lastPoint = currentPoint; 424  } 425 if (currentValue != lastValue && 426 427 currentPoint != lastPoint) 428  { 429 double shang1 = getGini(count1, 430 431 all1); 432 double shang2 = getGini(count2, 433 434 all2); 435 double allShang = shang1 * all1 / 436 437 (all1   all2)   shang2 * all2 / (all1   all2); 438 //allShang = (totalShang - allShang); 439 if (lianxuMax > allShang) 440  { 441 lianxuMax = allShang; 442 for (int k = 0; k < 443 444 count1.Length; k  ) 445  { 446 leftCunt[k] = count1[k]; 447 rightCount[k] = count2[k]; 448  } 449 splitPoint = j; 450 splitValue = (currentPoint   451 452 lastPoint) / 2; 453  } 454  } 455 all1  ; 456 count1[Convert.ToInt32(currentValue) - 457 458 1]  ; 459 count2[Convert.ToInt32(currentValue) - 460 461 1]--; 462 all2--; 463 lastValue = currentValue; 464 lastPoint = currentPoint; 465  } 466 #endregion 467 #region 如果超过了局部值,重设 468 if (lianxuMax < jubuMax) 469  { 470 info.type = 1; 471 info.splitIndex = i; 472 info.features=new List<string>() 473 474 {splitValue ""}; 475 //finalPoint = splitPoint; 476 jubuMax = lianxuMax; 477 info.temp[0] = new List<int>(); 478 info.temp[1] = new List<int>(); 479 for (int k = 0; k < splitPoint; k  ) 480  { 481 info.temp[0].Add(nums[k]); 482  } 483 for (int k = splitPoint; k < nums.Count; 484 485 k  ) 486  { 487 info.temp[1].Add(nums[k]); 488  } 489 info.class_Count[0] = new double 490 491 [leftCunt.Length]; 492 info.class_Count[1] = new double 493 494 [leftCunt.Length]; 495 for (int k = 0; k < leftCunt.Length; k  ) 496  { 497 info.class_Count[0][k] = leftCunt[k]; 498 info.class_Count[1][k] = rightCount 499 500 [k]; 501  } 502  } 503 #endregion 504  } 505 #endregion 506  } 507 #region 没有寻找到最佳的分裂点,则设置为叶节点 508 if (info.splitIndex == -1) 509  { 510 double[] finalCount = node.ClassCount; 511 double max = finalCount[0]; 512 int result = 1; 513 for (int i = 1; i < finalCount.Length; i  ) 514  { 515 if (finalCount[i] > max) 516  { 517 max = finalCount[i]; 518 result = (i   1); 519  } 520  } 521 node.feature_Type="result"; 522 node.features=new List<String> { ""   result }; 523 return node; 524  } 525 #endregion 526 #region 分裂 527 int deep = node.deep; 528 node.SplitFeature = (""   info.splitIndex); 529 List<Node> childNode = new List<Node>(); 530 int[][] used = new int[2][]; 531 used[0] = new int[isUsed.Length]; 532 used[1] = new int[isUsed.Length]; 533 for (int i = 0; i < isUsed.Length; i  ) 534  { 535 used[0][i] = isUsed[i]; 536 used[1][i] = isUsed[i]; 537  } 538 if (info.type == 0) 539  { 540 used[0][info.splitIndex] = 1; 541 node.feature_Type = ("离散"); 542  } 543 else 544  { 545 //used[info.splitIndex] = 0; 546 node.feature_Type = ("连续"); 547  } 548 List<int>[] rowIndex = info.temp; 549 List<String> features = info.features; 550 Node node1 = new Node(); 551 Node node2 = new Node(); 552 node1.setClassCount(info.class_Count[0]); 553 node2.setClassCount(info.class_Count[1]); 554 node1.rowCount = info.temp[0].Count; 555 node2.rowCount = info.temp[1].Count; 556 node1.deep = deep   1; 557 node2.deep = deep   1; 558 node1 = findBestSplit(node1, info.temp[0],used[0]); 559 node2 = findBestSplit(node2, info.temp[1], used[1]); 560 node.leafNode_Count = (node1.leafNode_Count 561 562  node2.leafNode_Count); 563 node.leafWrong = (node1.leafWrong node2.leafWrong); 564 node.features = (features); 565  childNode.Add(node1); 566  childNode.Add(node2); 567 node.childNodes = childNode; 568 #endregion 569 return node; 570  } 571 catch (Exception e) 572  { 573  Console.WriteLine(e.StackTrace); 574 return node; 575  } 576  } 577 /// <summary> 578 /// GINI值 579 /// </summary> 580 /// <param name="counts"></param> 581 /// <param name="countAll"></param> 582 /// <returns></returns> 583 public static double getGini(double[] counts, int countAll) 584  { 585 double Gini = 1; 586 for (int i = 0; i < counts.Length; i  ) 587  { 588 Gini = Gini - Math.Pow(counts[i] / countAll, 2); 589  } 590 return Gini; 591  } 592 #region CCP剪枝 593 public static void getSeries(Node node) 594  { 595 Stack<Node> nodeStack = new Stack<Node>(); 596 if (node != null) 597  { 598  nodeStack.Push(node); 599  } 600 if (node.feature_Type == "result") 601 return; 602 List<Node> childs = node.childNodes; 603 for (int i = 0; i < childs.Count; i  ) 604  { 605  getSeries(node); 606  } 607  } 608 /// <summary> 609 /// 遍历剪枝 610 /// </summary> 611 /// <param name="node"></param> 612 public static Node getNode1(Node node, Node nodeCut) 613  { 614 615 //List<Node> childNodes = node.getChild(); 616 //double min = 100000; 617 ////Node nodeCut = new Node(); 618 //double temp = 0; 619 //for (int i = 0; i < childNodes.Count; i  ) 620 //{ 621 // if (childNodes[i].getType() != "result") 622 // { 623 // //if (!cutTree(childNodes[i])) 624 // temp = min; 625 // min = cutTree(childNodes[i], min); 626 // if (min < temp) 627 // nodeCut = childNodes[i]; 628 // getNode1(childNodes[i], nodeCut); 629 // } 630 //} 631 //node.setChildNode(childNodes); 632 return null; 633  } 634 /// <summary> 635 /// 对每一个节点剪枝 636 /// </summary> 637 public static double cutTree(Node node, double minA) 638  { 639 int rowCount = node.rowCount; 640 double leaf = node.getErrorCount(); 641 double[] values = getError1(node, 0, 0); 642 double treeWrong = values[0]; 643 double son = values[1]; 644 double rate = (leaf - treeWrong) / (son - 1); 645 if (minA > rate) 646 minA = rate; 647 //double var = Math.Sqrt(treeWrong * (1 - treeWrong /  648 649 rowCount)); 650 //double panbie = treeWrong   var - leaf; 651 //if (panbie > 0) 652 //{ 653 // node.setFeatureType("result"); 654 // node.setChildNode(null); 655 // int result = (node.getResult()   1); 656 // node.setFeatures(new List<String>() { ""   result  657 658 }); 659 // //return true; 660 //} 661 return minA; 662  } 663 /// <summary> 664 /// 获得子树的错误个数 665 /// </summary> 666 /// <param name="node"></param> 667 /// <returns></returns> 668 public static double[] getError1(Node node, double treeError, 669 670 double son) 671  { 672 if (node.feature_Type == "result") 673  { 674 675 double error = node.getErrorCount(); 676 son  ; 677 return new double[] { treeError   error, son }; 678  } 679 List<Node> childNode = node.childNodes; 680 for (int i = 0; i < childNode.Count; i  ) 681  { 682 double[] values = getError1(childNode[i], treeError, 683 684 son); 685 treeError = values[0]; 686 son = values[1]; 687  } 688 return new double[] { treeError, son }; 689  } 690 #endregion
*/

CART核心代码

总结:

(1)CART是一棵二叉树,每一次分裂会产生两个子节点,对于连续性的数据,直接采用与C4.5相似的处理方法,对于离散型数据,选择最优的两种离散值组合方法。

(2)CART既能是分类数,又能是二叉树。如果是分类树,将选择能够最小化分裂后节点GINI值的分裂属性;如果是回归树,选择能够最小化两个节点样本方差的分裂属性。

(3)CART跟C4.5一样,需要进行剪枝,采用CCP(代价复杂度的剪枝方法)。

发布者:全栈程序员栈长,转载请注明出处:https://javaforall.cn/166831.html原文链接:https://javaforall.cn

0 人点赞