KNN:最容易理解的分类算法

2021-02-08 21:09:15 浏览数 (1)

欢迎关注”生信修炼手册”!

KNN是一种分类算法,其全称为k-nearest neighbors, 所以也叫作K近邻算法。该算法是一种监督学习的算法,具体可以分为以下几个步骤

1. 第一步,载入数据,因为是监督学习算法,所以要求输入数据中必须提供样本对应的分类信息

2. 第二步,指定K值,为了避免平票,K值一般是奇数

3. 第三步,对于待分类的样本点,计算该样本点与输入样本的距离矩阵,按照距离从小到大排序,选择K个最近的点

4. 第四步,根据K个点的分类频率,确定频率最高的类别为该样本点的最终分类

可以通过下图加以理解

黑色样本点为待分类点,对于图上的点而言,分成了红色和紫色两大类。指定K为3,则在最近的3个点中,2个是红点,1个是紫点,所以该黑色的点应该归为红色类。

根据这个分类逻辑,K的取值对样本的分类会有很大影响,以下图为例

K值为3时,绿色的点归类为红色,K值为5时,绿色的点归类为蓝色。由此可见,K值的选取是模型的核心因素之一。

除此之外,还有另外一个因素,就是距离的计算。距离的计算有多种方法,比如欧式距离,曼哈顿距离等等,不同距离度量方式会影响最近的K个样本点的选取,从而对结果造成影响,在实际分析中,一般都采用的是欧式距离,所以要求对输入数据进行归一化。

在scikit-learn中,使用KNN算法的代码如下

代码语言:javascript复制
>>> from sklearn.neighbors import KNeighborsClassifier
>>> X = [[0], [1], [2], [3]]
>>> y = [0, 0, 1, 1]
>>> neigh = KNeighborsClassifier(n_neighbors=3)
>>> neigh.fit(X, y)
KNeighborsClassifier(n_neighbors=3)
>>> print(neigh.predict([[1.1]]))
[0]

KNN算法原理简单,适用于样本量较小(小于10万个)同时特征较少的数据集。对于K值的合适取值,可以通过验证集上的表现或者交叉验证来识别。

·end·

0 人点赞