聚类算法比较

2022-05-29 10:20:26 浏览数 (1)

代码语言:python
import time
import warnings
import numpy as np
import matplotlib.pyplot as plt
from sklearn import cluster, datasets, mixture
from sklearn.neighbors import kneighbors_graph
from sklearn.preprocessing import StandardScaler
from itertools import cycle, islice
# Build the toy datasets. The global NumPy RNG is seeded once; the
# generators called without an explicit random_state (circles, moons) and
# np.random.rand draw from it, so the order of those calls is kept fixed.
np.random.seed(0)
n_samples = 1000

noisy_circles = datasets.make_circles(
    n_samples=n_samples,
    factor=.3,
    noise=.03,
)
noisy_moons = datasets.make_moons(n_samples=n_samples, noise=.03)
blobs = datasets.make_blobs(n_samples=n_samples, random_state=8)
# Uniform random points; there are no labels, hence None in the label slot.
no_structure = (np.random.rand(n_samples, 2), None)

# Anisotropically distributed blobs: generate blobs, then apply a fixed
# 2x2 linear transformation to the points.
random_state = 160
X, y = datasets.make_blobs(n_samples=n_samples, random_state=random_state)
transformation = [
    [0.6, -0.6],
    [-0.4, 0.8],
]
X_aniso = X.dot(transformation)
aniso = (X_aniso, y)

# Blobs with a different standard deviation per cluster.
varied = datasets.make_blobs(
    n_samples=n_samples,
    cluster_std=[1.0, 2.5, 0.5],
    random_state=random_state,
)
# Fit nine clustering algorithms on each of the six toy datasets and draw
# one scatter plot per (dataset, algorithm) pair in a single figure, with
# the fit time printed in the corner of each subplot.

# Figure layout and subplot counter.
plt.figure(figsize=(9 * 2 + 3, 12.5))
plt.subplots_adjust(left=.02, right=.98, bottom=.001, top=.96,
                    wspace=.03, hspace=.01)
plot_num = 1

# Default algorithm parameters; each dataset may override some of them.
default_base = {
    'quantile': .3,
    'eps': .3,
    'damping': .9,
    'preference': -200,
    'n_neighbors': 10,
    'n_clusters': 3,
}

# (dataset, parameter overrides) pairs. Named dataset_configs rather than
# `datasets` so the imported sklearn `datasets` module is not shadowed.
dataset_configs = [
    (noisy_circles, {'damping': .77, 'preference': -240,
                     'quantile': .2, 'n_clusters': 2}),
    (noisy_moons, {'damping': .75, 'preference': -220, 'n_clusters': 2}),
    (varied, {'eps': .18, 'n_neighbors': 2}),
    (aniso, {'eps': .15, 'n_neighbors': 2}),
    (blobs, {}),
    (no_structure, {}),
]

# Colour cycle used for cluster labels (loop-invariant, so defined once).
palette = ['#377eb8', '#ff7f00', '#4daf4a', '#f781bf', '#a65628',
           '#984ea3', '#999999', '#e41a1c', '#dede00']

for i_dataset, (dataset, algo_params) in enumerate(dataset_configs):
    # Merge the per-dataset overrides into a copy of the defaults.
    params = default_base.copy()
    params.update(algo_params)

    X, y = dataset

    # Normalize the dataset so the same parameters are usable across sets.
    X = StandardScaler().fit_transform(X)

    # Estimate the bandwidth for mean shift.
    bandwidth = cluster.estimate_bandwidth(X, quantile=params['quantile'])

    # k-nearest-neighbour connectivity graph for the structured algorithms.
    connectivity = kneighbors_graph(
        X, n_neighbors=params['n_neighbors'], include_self=False)
    # Make the connectivity symmetric (a kNN graph is directed in general).
    connectivity = 0.5 * (connectivity + connectivity.T)

    # Create the clustering estimators.
    ms = cluster.MeanShift(bandwidth=bandwidth, bin_seeding=True)
    two_means = cluster.MiniBatchKMeans(n_clusters=params['n_clusters'])
    ward = cluster.AgglomerativeClustering(
        n_clusters=params['n_clusters'], linkage='ward',
        connectivity=connectivity)
    spectral = cluster.SpectralClustering(
        n_clusters=params['n_clusters'], eigen_solver='arpack',
        affinity="nearest_neighbors")
    dbscan = cluster.DBSCAN(eps=params['eps'])
    affinity_propagation = cluster.AffinityPropagation(
        damping=params['damping'], preference=params['preference'])
    try:
        # Newer scikit-learn names the distance parameter `metric`.
        average_linkage = cluster.AgglomerativeClustering(
            linkage="average", metric="cityblock",
            n_clusters=params['n_clusters'], connectivity=connectivity)
    except TypeError:
        # Older scikit-learn only accepts the former name `affinity`.
        average_linkage = cluster.AgglomerativeClustering(
            linkage="average", affinity="cityblock",
            n_clusters=params['n_clusters'], connectivity=connectivity)
    birch = cluster.Birch(n_clusters=params['n_clusters'])
    gmm = mixture.GaussianMixture(
        n_components=params['n_clusters'], covariance_type='full')

    clustering_algorithms = (
        ('MiniBatchKMeans', two_means),
        ('AffinityPropagation', affinity_propagation),
        ('MeanShift', ms),
        ('SpectralClustering', spectral),
        ('Ward', ward),
        ('AgglomerativeClustering', average_linkage),
        ('DBSCAN', dbscan),
        ('Birch', birch),
        ('GaussianMixture', gmm),
    )

    for name, algorithm in clustering_algorithms:
        t0 = time.time()

        # Silence two expected warnings while fitting. The message is a
        # regex matched against the start of the warning text, so the
        # spacing must match the text scikit-learn emits.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="the number of connected components of the "
                        "connectivity matrix is [0-9]{1,2}"
                        " > 1. Completing it to avoid stopping the tree early.",
                category=UserWarning)
            warnings.filterwarnings(
                "ignore",
                message="Graph is not fully connected, spectral embedding"
                        " may not work as expected.",
                category=UserWarning)
            algorithm.fit(X)

        t1 = time.time()

        # Estimators without a labels_ attribute are asked to predict.
        if hasattr(algorithm, 'labels_'):
            # np.int was removed from NumPy; the builtin int is equivalent.
            y_pred = algorithm.labels_.astype(int)
        else:
            y_pred = algorithm.predict(X)

        plt.subplot(len(dataset_configs), len(clustering_algorithms), plot_num)
        if i_dataset == 0:
            # Only the first row carries the algorithm names as titles.
            plt.title(name, size=18)

        # One colour per label 0..max(y_pred), cycling through the palette.
        colors = np.array(
            list(islice(cycle(palette), int(max(y_pred) + 1))))
        # Append black last so that label -1 (outliers) indexes it.
        colors = np.append(colors, ["#000000"])
        plt.scatter(X[:, 0], X[:, 1], s=10, color=colors[y_pred])

        plt.xlim(-2.5, 2.5)
        plt.ylim(-2.5, 2.5)
        plt.xticks(())
        plt.yticks(())
        # Fit time in the lower-right corner, without the leading zero.
        plt.text(.99, .01, ('%.2fs' % (t1 - t0)).lstrip('0'),
                 transform=plt.gca().transAxes, size=15,
                 horizontalalignment='right')
        plot_num += 1

plt.show()

算法:聚类算法比较是指在各算法的参数针对每个数据集调优之后,对MiniBatchKMeans、AP聚类(AffinityPropagation)、MeanShift、谱聚类、Ward聚类、层次聚类(平均链接)、DBSCAN聚类、Birch聚类和高斯混合模型这九种聚类算法在不同数据集上的聚类结果进行比较。

0 人点赞