数学建模学习笔记(十九)K-means聚类的matlab和python实现

2022-06-14 09:50:30 浏览数 (1)

在本专栏前面几篇中曾记录过K-means的MATLAB代码,但这次使用时发现并不好用,因此又整理了其他的K-means实现,经实测可行。

matlab:

代码语言:MATLAB
%% K-means clustering in MATLAB
% Clusters the 2-D points in x into k groups. The first k samples are the
% initial centers; assignment/update steps repeat until no center changes.
%% Data preparation and initialization
clc
clear
x=[62,627;112,511;186,531;198,411;190,379;234,399;227,598;329,454;349,596;424,600;611,565;811,736;776,537;666,437;944,449;943,318;743,216;1076,252;899,178;995,91;1074,101;943,17;275,341];
k=2;                        % number of clusters (must not exceed size(x,1))
[n,dim]=size(x);            % number of samples and number of coordinates
z=x(1:k,:);                 % initial centers: the first k samples
%% Find the cluster centers
while 1
    count=zeros(k,1);       % number of samples assigned to each center
    allsum=zeros(k,dim);    % running coordinate sum of each cluster
    for i=1:n               % assign every sample i to its nearest center
        % Squared Euclidean distance from sample i to every center; the
        % square root is monotonic, so it is not needed to find the minimum.
        d=sum((z-repmat(x(i,:),k,1)).^2,2);
        [~,c]=min(d);       % ties go to the lower-numbered center
        count(c)=count(c)+1;
        allsum(c,:)=allsum(c,:)+x(i,:);
    end
    % New centers are the cluster means; an empty cluster keeps its previous
    % center instead of becoming NaN (0/0).
    z1=z;
    nonempty=count>0;
    z1(nonempty,:)=allsum(nonempty,:)./repmat(count(nonempty),1,dim);
    if isequal(z,z1)        % converged: no center moved
        break;
    end
    z=z1;
end
%% Show the results
disp(z1);                   % print the cluster centers
plot(x(:,1),x(:,2),'k*',...
    'LineWidth',2,...
    'MarkerSize',10,...
    'MarkerEdgeColor','k',...
    'MarkerFaceColor',[0.5,0.5,0.5])
hold on
plot(z1(:,1),z1(:,2),'ko',...
    'LineWidth',2,...
    'MarkerSize',10,...
    'MarkerEdgeColor','k',...
    'MarkerFaceColor',[0.5,0.5,0.5])
set(gca,'linewidth',2);
xlabel('x','fontsize',12);
ylabel('y','fontsize',12);

代码效果:(原文此处的运行效果图已缺失)

不过这段代码只能聚成2类:聚类数(2)和样本数(23)都是直接写死在代码里的。

python代码:

代码语言:Python
# -*- coding:utf-8 -*-
import numpy as np
from matplotlib import pyplot


class K_Means(object):
    """Minimal K-means clustering on numeric feature vectors.

    The first ``k`` rows of the training data are used as the initial
    centers, so ``k`` must not exceed the number of samples.

    Args:
        k: Number of clusters.
        tolerance: Convergence threshold. Iteration stops once, for every
            center, the sum over coordinates of the absolute relative change
            (in percent) is at most this value.
        max_iter: Maximum number of assignment/update iterations.

    After ``fit`` the instance exposes:
        centers_: dict mapping cluster index -> center coordinates.
        clf_: dict mapping cluster index -> list of training samples
            assigned to that cluster in the final iteration.
    """

    def __init__(self, k=2, tolerance=0.0001, max_iter=300):
        self.k_ = k
        self.tolerance_ = tolerance
        self.max_iter_ = max_iter

    def _center_matrix(self):
        """Stack the current centers into a (k, n_features) float array, ordered by index."""
        return np.array([self.centers_[c] for c in range(self.k_)], dtype=float)

    @staticmethod
    def _nearest(point, centers):
        """Return the row index of ``centers`` closest (Euclidean) to ``point``; ties pick the lowest index."""
        return int(np.argmin(np.linalg.norm(centers - point, axis=1)))

    @staticmethod
    def _relative_change(old, new):
        """Return the sum over coordinates of |new - old| as a percentage of |old|.

        Coordinates where ``old`` is zero fall back to the absolute change
        (times 100) so the check never divides by zero.
        """
        old = np.asarray(old, dtype=float)
        new = np.asarray(new, dtype=float)
        scale = np.where(old != 0, np.abs(old), 1.0)
        return float(np.sum(np.abs(new - old) / scale * 100.0))

    def fit(self, data):
        """Cluster ``data`` (one sample per row), populating ``centers_`` and ``clf_``."""
        data = np.asarray(data)
        # Initial centers: the first k samples (IndexError if k > len(data)).
        self.centers_ = {c: data[c] for c in range(self.k_)}

        for _ in range(self.max_iter_):
            # Assignment step: put every sample into its nearest center's group.
            self.clf_ = {c: [] for c in range(self.k_)}
            centers = self._center_matrix()
            for feature in data:
                self.clf_[self._nearest(feature, centers)].append(feature)

            # Update step: move each center to the mean of its group. An empty
            # group keeps its previous center (averaging nothing would give NaN).
            prev_centers = dict(self.centers_)
            for c, members in self.clf_.items():
                if members:
                    self.centers_[c] = np.average(members, axis=0)

            # Converged once no center moved by more than the tolerance. The
            # absolute value matters: signed changes in opposite directions
            # (or a move toward smaller coordinates) must not look "converged".
            if all(
                self._relative_change(prev_centers[c], self.centers_[c]) <= self.tolerance_
                for c in self.centers_
            ):
                break

    def predict(self, p_data):
        """Return the index of the fitted center closest to the single sample ``p_data``."""
        return self._nearest(p_data, self._center_matrix())


if __name__ == '__main__':
    x = np.array([[149, 663], [404, 707], [743, 754], [170, 511], [520, 490], [912, 500], [303, 287], [810, 246], [1011, 298], [653, 33]])
    k_means = K_Means(k=2)
    k_means.fit(x)
    print(k_means.centers_)

    # One color per cluster, cycled if k exceeds the palette; for k=2 this is
    # red for cluster 0 and blue for cluster 1.
    colors = ['r', 'b', 'g', 'c', 'm', 'y', 'k']

    # Mark the cluster centers with stars.
    for center in k_means.centers_:
        pyplot.scatter(k_means.centers_[center][0], k_means.centers_[center][1], marker='*', s=150)

    # Plot the training samples colored by their assigned cluster.
    for cat in k_means.clf_:
        for point in k_means.clf_[cat]:
            pyplot.scatter(point[0], point[1], c=colors[cat % len(colors)])

    # Classify new samples one at a time (predict expects a single point).
    predict = [[2, 1], [6, 9]]
    for feature in predict:
        cat = k_means.predict(feature)
        print(feature, '->', cat)
        pyplot.scatter(feature[0], feature[1], c=colors[cat % len(colors)], marker='x', s=100)

    pyplot.show()

修改k值即可改变聚类数目(k不能超过样本数)。不过示例中的绘图代码只为两类分配了颜色(第0类为红色,其余均为蓝色),因此k>2时图中无法区分第1类及之后的各类。

更多类的聚类有待后续挖掘…

0 人点赞