简单说说K均值聚类

2022-07-29 19:30:59 浏览数 (1)

聚类是一个将数据集中在某些方面相似的数据成员进行分类组织的过程,聚类就是一种发现这种内在结构的技术,聚类技术经常被称为无监督学习。k均值聚类是最著名的划分聚类算法,由于简洁和效率使得他成为所有聚类算法中最广泛使用的。给定一个数据点集合和需要的聚类数目k,k由用户指定,k均值算法根据某个距离函数反复把数据分入k个聚类中。

假设对基本的二维平面上的点进行K均值聚类,其实现基本步骤是:

1.事先选定好K个聚类中心(假设要分为K类)。2.算出每一个点到这K个聚类中心的距离,然后把该点分配给距离它最近的一个聚类中心。3.更新聚类中心。算出每一个类别里面所有点的平均值,作为新的聚类中心。4.给定迭代此次数,不断重复步骤2和步骤3,达到该迭代次数后自动停止。

思想很简单,实现起来也很简单,附上代码:

代码语言:javascript复制
import numpy as np
import matplotlib.pyplot as plt

#np.random.seed(300)
x=np.random.rand(200)*15    #产生要聚类的数据点,(0,15)之间
y=np.random.rand(200)*15

center_x=[]    #存放聚类中心坐标
center_y=[]
result_x=[]    #存放每次迭代后每一小类的坐标
result_y=[]

number_cluster=4   #簇数
time=50   #迭代次数

color=['red','blue','black','orange']

for i in range(number_cluster):  # 随机生成中心
    result_x.append([])      #顺便初始化存放聚类结果的列表
    result_y.append([])
    x1 = np.random.choice(x)  #为了避免出现聚类后有的簇一个点也没有,
    y1 = np.random.choice(y)  #干脆就以某一个数据点为中心
    if x1 not in center_x and y1 not in center_y:
        center_x.append(x1)
        center_y.append(y1)

plt.scatter(x,y)  #画出数据图
plt.title('init plot')
plt.show()

def K_means():
    for t in range(time):
        for i in range(len(x)):
            distance = []   #存放每个点到各中心的距离
            for j in range(len(center_x)):
                k = (center_x[j] - x[i]) ** 2   (center_y[j] - y[i]) ** 2  #距离
                distance.append([k])
            result_x[distance.index(min(distance))].append(x[i])  #聚类
            result_y[distance.index(min(distance))].append(y[i])
        plt.title('iterations:' str(t 1))
        for i in range(number_cluster):
            plt.scatter(result_x[i], result_y[i], c=color[i])
        plt.show()

        # 更新位置
        center_x.clear()
        center_y.clear()
        for i in range(number_cluster):
            ave_x = np.mean(result_x[i])
            ave_y = np.mean(result_y[i])
            center_x.append(ave_x)
            center_y.append(ave_y)


if __name__=='__main__':
    K_means()

结果展示:

1.初始化:

2.第一次迭代:

3.第二次迭代:

4.第九次迭代(收敛):

0 人点赞