机器学习系列--KNN分类算法例子

2023-06-29 16:08:23 浏览数 (2)

url:机器学习系列--KNN分类算法

用的是spark2.0.2,scala2.11

import org.apache.spark.{SparkConf, SparkContext}

object knntest {

  /**     * 欧式距离     * 计算两点间的距离     * @param rs as r1,r2, ..., rd     * @param ss as s1,s2, ..., sd     * @param d 维数     */   def euclideanDistance(rs: String, ss: String, d: Int): Double = {     val r = rs.split(",").map(_.toDouble)     val s = ss.split(",").map(_.toDouble)

    if (r.length != d || s.length != d) Double.NaN else {       //zip匹配key/value 分区数一样,ri-si的平方的求和再开方,欧式距离       math.sqrt((r, s).zipped.take(d).map {         case (ri, si) => math.pow(ri - si, 2)       }.sum)     }   }

  def main(args: Array[String]): Unit = {     val sparkConf=new SparkConf().setAppName("knntest").setMaster("local[4]")     val sc=new SparkContext(sparkConf)

    //生成矩阵,每行代表一个样本 10为索引,A,B为类别,其它为属性1,2..     val groupes=sc.parallelize(List("10;A;1.0,0.9", "11;A;1.0,1.0", "12;B;0.1,0.2", "13;B;0.0,0.1"))     //100为索引,其它为属性1,2..     val testxs = sc.parallelize(List("100;1.2,1.0","101;0.1,0.3"))     //近邻数     val k = sc.broadcast(3)     //向量维度     val d = sc.broadcast(2)     //笛卡尔     //ArrayBuffer((100;1.2,1.0,10;A;1.0,0.9),     // (100;1.2,1.0,11;A;1.0,1.0),     // (100;1.2,1.0,12;B;0.1,0.2),     // (100;1.2,1.0,13;B;0.0,0.1),     // (101;0.1,0.3,10;A;1.0,0.9),     // (101;0.1,0.3,11;A;1.0,1.0),     // (101;0.1,0.3,12;B;0.1,0.2),     // (101;0.1,0.3,13;B;0.0,0.1))     val cart=testxs.cartesian(groupes)

    val knns=cart.map(p=>{       val testx=p._1//例 100;1.2,1.0       val group2=p._2//例 10;A;1.0,0.9       val testx_index=testx.split(";")(0)       val testx_rs=testx.split(";")(1)

      //类型       val group2_type=group2.split(";")(1)       val group2_ss=group2.split(";")(2)       //欧式距离       val distance =euclideanDistance(testx_rs, group2_ss, d.value)       //ArrayBuffer((100,(0.2236067977499789,A)), (100,(0.19999999999999996,A)),       // (100,(1.3601470508735443,B)), (100,(1.5,B)),       // (101,(1.0816653826391969,A)), (101,(1.140175425099138,A)),       // (101,(0.09999999999999998,B)), (101,(0.22360679774997896,B)))       (testx_index,(distance,group2_type))     })

    val knnGrouped = knns.groupByKey()

    val knnOutput = knnGrouped.mapValues(itr => {       //(100,List((0.19999999999999996,A), (0.2236067977499789,A), (1.3601470508735443,B)))       //(101,List((0.09999999999999998,B), (0.22360679774997896,B), (1.0816653826391969,A)))       val nearestK = itr.toList.sortBy(_._1).take(k.value)       //(101,List((B,1), (B,1), (A,1)))       //(100,List((A,1), (A,1), (B,1)))       //(100,Map(A -> 2, B -> 1))       //(101,Map(A -> 1, B -> 2))       val majority = nearestK.map(f => (f._2, 1)).groupBy(_._1).mapValues(list => {         val (stringList, intlist) = list.unzip         intlist.sum       })       //(100,A)       //(101,B)       majority.maxBy(_._2)._1     })

    knnOutput.foreach(println)     sc.stop()   } }

0 人点赞