/**
 * Euclidean distance between two points given as comma-separated coordinate strings.
 *
 * Example: euclideanDistance("1.2,1.0", "1.0,0.9", 2) == sqrt(0.2^2 + 0.1^2) ≈ 0.2236
 *
 * @param rs first point, formatted as "r1,r2,...,rd"
 * @param ss second point, formatted as "s1,s2,...,sd"
 * @param d  number of coordinates each point is expected to have
 * @return sqrt(sum over i of (ri - si)^2), or Double.NaN when either point does not have
 *         exactly `d` coordinates
 * @throws NumberFormatException if a coordinate cannot be parsed as a Double
 */
def euclideanDistance(rs: String, ss: String, d: Int): Double = {
  val r = rs.split(",").map(_.toDouble)
  val s = ss.split(",").map(_.toDouble)
  if (r.length != d || s.length != d) Double.NaN
  else {
    // Both arrays hold exactly d coordinates here, so a plain element-wise zip pairs every
    // coordinate; no truncation is needed. (`zip` also avoids `(r, s).zipped`, deprecated in 2.13.)
    val squaredDiffs = r.zip(s).map { case (ri, si) =>
      val diff = ri - si
      diff * diff
    }
    math.sqrt(squaredDiffs.sum)
  }
}
def main(args: Array[String]): Unit = { val sparkConf=new SparkConf().setAppName("knntest").setMaster("local[4]") val sc=new SparkContext(sparkConf)
// Training set, one sample per row: "<id>;<label>;<attr1>,<attr2>,...".
// Here the ids are 10-13 and the labels are A/B.
val groupes = sc.parallelize(Seq(
  "10;A;1.0,0.9",
  "11;A;1.0,1.0",
  "12;B;0.1,0.2",
  "13;B;0.0,0.1"
))
// Test set, one sample per row: "<id>;<attr1>,<attr2>,..." (ids 100 and 101).
val testxs = sc.parallelize(Seq(
  "100;1.2,1.0",
  "101;0.1,0.3"
))
// Number of nearest neighbours that take part in the vote.
val k = sc.broadcast(3)
// Number of coordinates (vector dimension) per sample.
val d = sc.broadcast(2)
// Cartesian product: every (test row, training row) pair, 2 x 4 = 8 pairs, e.g.
//   ("100;1.2,1.0", "10;A;1.0,0.9"), ("100;1.2,1.0", "11;A;1.0,1.0"), ...,
//   ("101;0.1,0.3", "13;B;0.0,0.1")
val cart = testxs.cartesian(groupes)
// For every (test row, training row) pair, emit (testId, (distance, trainingLabel)).
// Input pair example:  ("100;1.2,1.0", "10;A;1.0,0.9")
// Output examples:
//   (100,(0.2236067977499789,A)), (100,(0.19999999999999996,A)),
//   (100,(1.3601470508735443,B)), (100,(1.5,B)),
//   (101,(1.0816653826391969,A)), (101,(1.140175425099138,A)),
//   (101,(0.09999999999999998,B)), (101,(0.22360679774997896,B))
val knns = cart.map { case (testRow, trainRow) =>
  // Test row layout: "<id>;<coords>".
  val testFields = testRow.split(";")
  val (testId, testCoords) = (testFields(0), testFields(1))
  // Training row layout: "<id>;<label>;<coords>".
  val trainFields = trainRow.split(";")
  val (label, trainCoords) = (trainFields(1), trainFields(2))
  (testId, (euclideanDistance(testCoords, trainCoords, d.value), label))
}
// Collect every (distance, label) pair belonging to the same test-sample id.
val knnGrouped = knns.groupByKey()

// For each test sample, keep the k nearest training samples and let their labels vote.
// Examples:
//   100 -> nearest (0.19999999999999996,A), (0.2236067977499789,A), (1.3601470508735443,B)
//          -> votes A:2, B:1 -> "A"
//   101 -> nearest (0.09999999999999998,B), (0.22360679774997896,B), (1.0816653826391969,A)
//          -> votes B:2, A:1 -> "B"
val knnOutput = knnGrouped.mapValues { neighbours =>
  val nearestK = neighbours.toList.sortBy(_._1).take(k.value)
  // Labels ordered from nearest to farthest neighbour.
  val labels = nearestK.map(_._2)
  val votes = labels.groupBy(identity).map { case (label, hits) => (label, hits.size) }
  // `distinct` keeps nearest-first order and `maxBy` returns the first maximum it encounters,
  // so a tie in vote count is won by the label whose closest neighbour is nearer, rather than
  // by whatever order a Map happens to iterate in.
  labels.distinct.maxBy(label => votes(label))
}