K最近邻算法(KNN)介绍及实现

2020-11-12 11:46:03 浏览数 (1)

KNN,即K nearest neighbor,K近邻算法。KNN的思想非常简单,所需的数学知识较少。比如下图,星星是一个新的样本,要判断星星是属于蓝色的还是黄色的样本分类,就要看它周围的邻居是什么分类。假设K=3,就是看周围三个点的分类,如图,周围有两个红点,一个黄点,应该归类为红色类别。

导入数据:

代码语言:javascript复制
from sklearn.datasets import load_breast_cancer

cancer = load_breast_cancer()
data = cancer.data
target = cancer.target

import numpy as np

X = np.array([20, 30])
plt.scatter(data[target==1,0], data[target==1,1], alpha = 0.5, color = 'green', label = '1')
plt.scatter(data[target==0,0], data[target==0,1], alpha = 0.5, color = 'orange', label = '0')
plt.plot(X[0], X[1], color = 'purple', marker = '*', ms = 20)
plt.legend(loc = 'upper right')
plt.show()

KNN的计算中,先计算距离,比较常见的是欧拉距离:

也就是两个点(或者多个点)对应的横纵坐标差的平方和,然后开平方。

根据欧拉距离写一个KNN的实现:

代码语言:javascript复制
def KNN_test(X_train, y_train, test, K):
    distance = []

    for t in X_train:
        d = sqrt(np.sum((t - test)**2))
        distance.append(d)  
    
    ind = np.argsort(distance)
    topK_target = y_train[ind[:K]]
    c = Counter(topK_target)
    return c.most_common(1)[0][0]

判断刚才的点是属于哪一类:

代码语言:javascript复制
KNN_test(data[:, :2], target, X, 6)

sklearn中的实现:

代码语言:javascript复制
from sklearn.neighbors import KNeighborsClassifier

knn = KNeighborsClassifier(n_neighbors=6)
knn.fit(data[:, :2], target)
knn.predict(X.reshape(1, -1))
knn

0 人点赞