对于监督学习而言,回归和分类是两类基本应用场景;对于非监督学习而言,则是聚类和降维。K-means属于聚类算法的一种,通过迭代将样本分为K个互不重叠的子集。
对于K-means聚类而言,首先要确定的第一个参数就是聚类个数K。具体的方法有以下两种,第一种是目的导向,根据先验知识或者研究目的,直接给定一个具体的K值,比如根据实验设计的分组数目定K值,根据样本的不同来源定K值等;第二种方法称之为Elbow, 适合没有任何先验的数据,通过比较多个K值的聚类结果,选取拐点值,图示如下
横坐标为不同的K值,纵坐标为样本点到聚类中心的距离总和。
K-means是一种启发式的聚类算法,通过迭代的方式来求解,在初次迭代时,随机选择两个样本点作为聚类的中心点,这样的中心点也叫做质心centroids,然后不断循环重复如下两个过程
1. cluster assignment,计算样本与聚类中心点的距离,选择距离近的中心点作为该样本的分类
2. move centroid, 移动聚类中心点,样本分类完毕之后,重新计算各个cluster的中心点
经过多次迭代,直到中心点的位置不在发生变化即可。下面用一系列示例图来展示其迭代过程,输入数据如下
根据先验知识,确定样本划分为两类,首先随机选择聚类的中心点
计算样本与中心点的距离,将样本划分为不同的cluster
根据划分好的结果,重新计算聚类中心点
重复迭代,直到中心点的位置不再变动,得到最终的聚类结果
在kmeans算法中,初始聚类中心点的选取对算法收敛的速度和结果都有很大影响。在传统kemans的基础上,又提出了kmeans 算法,该算法的不同之处在于初始聚类中心点的选取策略,其他步骤和传统的kmeans相同。
kmeans 的初始聚类中心选择策略如下
1. 随机选取一个样本作为聚类中心
2. 计算每个样本点与该聚类中心的距离,选择距离最大的点作为聚类中心点
3. 重复上述步骤,直到选取K个中心点
在scikit-learn中,使用kmeans聚类的代码如下
代码语言:javascript复制>>> import matplotlib.pyplot as plt
>>> import numpy as np
>>> from sklearn.cluster import KMeans
>>> from sklearn.datasets import make_blobs
>>> from sklearn.metrics.pairwise import pairwise_distances_argmin
>>> centers = [[1, 1], [-1, -1], [1, -1]]
>>> n_clusters = len(centers)
>>> X, labels_true = make_blobs(n_samples=3000, centers=centers, cluster_std=0.7)
>>> k_means = KMeans(init='k-means ', n_clusters=3, n_init=10)
>>> k_means.fit(X)
KMeans(n_clusters=3)
对于聚类结果,可以用以下代码进行可视化
代码语言:javascript复制>>> k_means_cluster_centers = k_means.cluster_centers_
>>> k_means_labels = pairwise_distances_argmin(X, k_means_cluster_centers)
>>> colors = ['#4EACC5', '#FF9C34', '#4E9A06']
>>> fig, ax = plt.subplots()
>>> for k, col in zip(range(n_clusters), colors):
... my_members = k_means_labels == k
... cluster_center = k_means_cluster_centers[k]
... ax.plot(X[my_members, 0], X[my_members, 1], 'w', markerfacecolor=col, marker='.')
... ax.plot(cluster_center[0], cluster_center[1], 'o', markerfacecolor=col, markeredgecolor='k', markersize=6)
...
[<matplotlib.lines.Line2D object at 0x11322880>]
[<matplotlib.lines.Line2D object at 0x11322A48>]
[<matplotlib.lines.Line2D object at 0x11322BF8>]
[<matplotlib.lines.Line2D object at 0x11322DA8>]
[<matplotlib.lines.Line2D object at 0x11322F58>]
[<matplotlib.lines.Line2D object at 0x11330130>]
>>> ax.set_title('KMeans')
Text(0.5, 1.0, 'KMeans')
>>> ax.set_xticks(())
[]
>>> ax.set_yticks(())
[]
>>> plt.show()
输出结果如下
kmeans算法原理简单,运算速度快,适用于大样本的数据,但是注意由于采用了欧氏距离,需要在数据预处理阶段进行归一化处理。
·end·
—如果喜欢,快分享给你的朋友们吧—
原创不易,欢迎收藏,点赞,转发!生信知识浩瀚如海,在生信学习的道路上,让我们一起并肩作战!
本公众号深耕耘生信领域多年,具有丰富的数据分析经验,致力于提供真正有价值的数据分析服务,擅长个性化分析,欢迎有需要的老师和同学前来咨询。