用Golang实现K-NN算法

2024-02-09 01:08:50 浏览数 (2)

K近邻算法用来对观察数据打标签/分类。通过和已打标样本对比 两者距离,跟哪个样本近就标注该观察数据应该归为什么标签。这通常也是机器学习的一个基础入门算法。

比如说图片这个例子, 这里有5000个28x28像素点阵的灰阶 (0-255) 手写体图像 ,画了数字0到 9。这是部分的示例。

5000个数字图像包括了训练集。并且我们会有一些新的手写体图像,这些还未进行分类打标签。这些未分类图像我们所知的只是灰阶像素点阵。算法工作是通过找到未分类图像在训练集中和哪个样本最接近。最合理的预测就是最接近的图像就是拥有那个样本的标签(所谓的物以类聚)。这也就是预测未知数据。

最接近的方法,就是我们去查找未分类图像和每个样本图片去比较。挨个像素比较,然后汇总所有像素点的比较结果。总体来看,汇总结果差异越小,那么就是越像那个分类。

标准测量差异方法叫做 Euclidean策略。假设两个vectors x⃗, y⃗,向量长度是28 × 28 = 784。包含了8-bit 非负数0…255, 然后定义距离就是

这个问题我们同时给了500个图像来分类, 并且他们是验证集。在跑完所有的验证集的500个数据, 来计算预测准确率 (这里已知标签, 假装不知道他们是如何分类的),

CSV文件包含了训练集和验证集。 每一行对应一个图像. 第一栏是标签, 后面的784栏是每个像素点的灰阶数字.

这里描述的k-近邻分类中的k = 1.

代码语言:go复制
package main

import (
	"bytes"
	"fmt"
	"io/ioutil"
	"strconv"
)

type LabelWithFeatures struct {
	Label    []byte  // 标签
	Features []float64  // 特征列表
}

func NewLabelWithFeatures(parsedLine [][]byte) LabelWithFeatures {  // 读取样本
	label := parsedLine[0]
	features := make([]float64, len(parsedLine)-1)

	for i, feature := range parsedLine {
		// skip label
		if i == 0 {
			continue
		}

		features[i-1] = byteSliceTofloat64(feature)
	}

	return LabelWithFeatures{label, features}  // 返回数据结构
}

var newline = []byte("n")
var comma = []byte(",")

func byteSliceTofloat64(b []byte) float64 {
	x, _ := strconv.ParseFloat(string(b), 32)
	return x
}

func parseCSVFile(filePath string) []LabelWithFeatures {  // 读取CSV数据,生成样本集的方法
	fileContent, _ := ioutil.ReadFile(filePath)
	lines := bytes.Split(fileContent, newline)
	numRows := len(lines)

	labelsWithFeatures := make([]LabelWithFeatures, numRows-2)

	for i, line := range lines {
		// skip headers
		if i == 0 || i == numRows-1 {
			continue
		}

		labelsWithFeatures[i-1] = NewLabelWithFeatures(bytes.Split(line, comma))
	}

	return labelsWithFeatures
}

func squareDistance(features1, features2 []float64) (d float64) {  
  for i := 0; i < len(features1); i   {  // 遍历所有特征
    d  = (features1[i] - features2[i]) * (features1[i] - features2[i]) // 特征距离平方之和
  }

  return
}

var trainingSample = parseCSVFile("trainingsample.csv")

func classify(features []float64) (label []byte) {
	label = trainingSample[0].Label
	d := squareDistance(features, trainingSample[0].Features)

	for _, row := range trainingSample {  // 在所有样本中 遍历查找
		dNew := squareDistance(features, row.Features)  // 计算未分类数据和样本间的距离
		if dNew < d {  // 找到距离最小的那个样本
			label = row.Label  // 这个样本的标签就是未分类数据的标签
			d = dNew
		}
	}

	return
}

func main() {
	validationSample := parseCSVFile("validationsample.csv")  // 导入样本

	totalCorrect := 0

	for _, test := range validationSample {  // 验证集合
		if string(test.Label) == string(classify(test.Features)) {
			totalCorrect    // 验证模型预测准确次数 1
		}
	}

	fmt.Println(float64(totalCorrect) / float64(len(validationSample))) // 打印模型预测准确率
}

0 人点赞