机器学习-Mean Shift聚类算法

2022-04-28 13:13:08 浏览数 (1)

Mean Shift算法建立在核密度估计(kernel density estimation,KDE)的基础之上,它假设数据点集是从Probability Distribution中采样获取的,Kernel Density Estimation是从数据点集估计Probability Distribution的非参数估计方法。

1.Kernel Density Estimation

给定n个数据点

,使用Radially Symmetric Kernel的Multivariate Kernel Density Estimate的形式如下:

h是BandWidth,不同BandWidth,聚类的效果也会不同。Radially Symmetric Kernel的定义如下:

是Normalization Constant。

如上图所示,红色点是待估计数据点集,Kernel Density Estimation的工作原理是在数据点集的每个Point放置一个Kernel(Kernel的实质是加权函数,Kernel的种类很多,比较常用的Gaussian Kernel),将所有的单个Kernel加起来就生成Probability Surface。所使用的Kernel BandWidth参数不用,生成的密度函数将有所不同。

surface plot using a Gaussian kernel with a kernel bandwidth of 2

contour plot of the surface using a Gaussian kernel with a kernel bandwidth of 2

2.Mean Shift

Mean Shift想法很简单:迭代的将将所有的Point汇集到Kernel Density Estimation Surface上距离最近的Peak位置,从而达到聚类的效果。

Kernel Bandwidth不同,生成的Kernel Density Estimation Surface不同,因此最终Clustering的结果也不同。使用Small Kernel BandWidth,KDE Surface的峰值会比较分散,生成的Cluster也比较多;反之,使用Large Kernel BandWidth,会生成宽而平滑的KDE Surface,所有的点就聚合到同一个Peak,形成的Cluster也比较集中。

BandWidth=2时,生成三个KDE Surface Peak,从而产生三个Cluster.

BandWidth=0.8时,生成了多个KDE Surface Peak,也就生成多个的Cluster

在不同的问题领域和应用场景下,什么是合理的Cluster没有统一的标准。Mean Shift的控制参数(Kernel Bandwidth),可以很容易地针对不同的应用进行合理的调整。

2.1 单点的Mean Shift的流程

Step 1:对于给定点

,Compute Mean Shift Vector:

其中

.

Step 2: Translate Density Estimation Window:

Step 3: Iterate Step 1 and Step 2 until convergence.

Mean Shift的执行流程如下:

Mean Shift的算法执行过程

2.2 Mean Shift的加速策略

Mean Shift的计算复杂度非常高,尤其在点集数量巨大的情况下,其耗时是令人难以忍受的,可以有一些加速的方法。

1、如下图蓝点所示,将End Point附近半径范围内的所有点都归属为与End Point相同的Cluster。

2. 将Mean Shift移动路径上的r/c范围内的所有Point都归属于与End Point相同的Cluster。

2.3 Cluster合并

当点集中的所有点都完成Mean Shift之后,可以对Cluster进行一些合并。当两个Cluster的Center距离小于阈值,则将两个Cluster进行合并。

3. Mean Shift在图像分割领域的应用

Mean Shift的一个很好的应用是图像分割,图像分割的目标是将图像分割成具有语义意义的区域,这个目标可以通过聚类图像中的像素来实现。

Step 1:将图像表示为空间中的点。一种简单的方法是使用红色、绿色和蓝色像素值将每个像素映射到三维RGB空间中的一个点(如下图所示)。

Step 2:对获取的点集执行Mean Shift。下图的动画演示了Mean Shift算法运行时点的聚合过程(使用Gaussian Kernel,BandWidth=25)。

Step 3: 对所有点聚合后的结果如下:

其它图像聚类分割效果

一种基于聚类的高级通用分割技术

4. Python代码实现

欧式距离计算:

代码语言:javascript复制
def euclidean_dist(pointA, pointB):
    if(len(pointA) != len(pointB)):
        raise Exception("expected point dimensionality to match")
    total = float(0)
    for dimension in range(0, len(pointA)):
        total  = (pointA[dimension] - pointB[dimension])**2
    return math.sqrt(total)

Gaussian Kernal函数:

代码语言:javascript复制
def gaussian_kernel(distance, bandwidth):
    euclidean_distance = np.sqrt(((distance)**2).sum(axis=1))
    val = (1/(bandwidth*math.sqrt(2*math.pi))) * np.exp(-0.5*((euclidean_distance / bandwidth))**2)
    return val

Multivariate Gaussian Kernel函数:

代码语言:javascript复制
def multivariate_gaussian_kernel(distances, bandwidths):

    # Number of dimensions of the multivariate gaussian
    dim = len(bandwidths)

    # Covariance matrix
    cov = np.multiply(np.power(bandwidths, 2), np.eye(dim))

    # Compute Multivariate gaussian (vectorized implementation)
    exponent = -0.5 * np.sum(np.multiply(np.dot(distances, np.linalg.inv(cov)), distances), axis=1)
    val = (1 / np.power((2 * math.pi), (dim/2)) * np.power(np.linalg.det(cov), 0.5)) * np.exp(exponent)

    return val

单点Mean Shift的过程:

代码语言:javascript复制
    def _shift_point(self, point, points, kernel_bandwidth):
        # from http://en.wikipedia.org/wiki/Mean-shift
        points = np.array(points)

        # numerator
        point_weights = self.kernel(point-points, kernel_bandwidth)
        tiled_weights = np.tile(point_weights, [len(point), 1])
        # denominator
        denominator = sum(point_weights)
        shifted_point = np.multiply(tiled_weights.transpose(), points).sum(axis=0) / denominator
        return shifted_point

        # ***************************************************************************
        # ** The above vectorized code is equivalent to the unrolled version below **
        # ***************************************************************************
        # shift_x = float(0)
        # shift_y = float(0)
        # scale_factor = float(0)
        # for p_temp in points:
        #     # numerator
        #     dist = ms_utils.euclidean_dist(point, p_temp)
        #     weight = self.kernel(dist, kernel_bandwidth)
        #     shift_x  = p_temp[0] * weight
        #     shift_y  = p_temp[1] * weight
        #     # denominator
        #     scale_factor  = weight
        # shift_x = shift_x / scale_factor
        # shift_y = shift_y / scale_factor
        # return [shift_x, shift_y]

Cluster的聚类过程:

代码语言:javascript复制
def cluster(self, points, kernel_bandwidth, iteration_callback=None):
        if(iteration_callback):
            iteration_callback(points, 0)
        shift_points = np.array(points)
        max_min_dist = 1
        iteration_number = 0

        still_shifting = [True] * points.shape[0]
        while max_min_dist > MIN_DISTANCE:
            # print max_min_dist
            max_min_dist = 0
            iteration_number  = 1
            for i in range(0, len(shift_points)):
                if not still_shifting[i]:
                    continue
                p_new = shift_points[i]
                p_new_start = p_new
                p_new = self._shift_point(p_new, points, kernel_bandwidth)
                dist = ms_utils.euclidean_dist(p_new, p_new_start)
                if dist > max_min_dist:
                    max_min_dist = dist
                if dist < MIN_DISTANCE:
                    still_shifting[i] = False
                shift_points[i] = p_new
            if iteration_callback:
                iteration_callback(shift_points, iteration_number)
        point_grouper = pg.PointGrouper()
        group_assignments = point_grouper.group_points(shift_points.tolist())
        return MeanShiftResult(points, shift_points, group_assignments)

Group聚类的过程:

代码语言:javascript复制
class PointGrouper(object):
    def group_points(self, points):
        group_assignment = []
        groups = []
        group_index = 0
        for point in points:
            nearest_group_index = self._determine_nearest_group(point, groups)
            if nearest_group_index is None:
                # create new group
                groups.append([point])
                group_assignment.append(group_index)
                group_index  = 1
            else:
                group_assignment.append(nearest_group_index)
                groups[nearest_group_index].append(point)
        return np.array(group_assignment)

    def _determine_nearest_group(self, point, groups):
        nearest_group_index = None
        index = 0
        for group in groups:
            distance_to_group = self._distance_to_group(point, group)
            if distance_to_group < GROUP_DISTANCE_TOLERANCE:
                nearest_group_index = index
            index  = 1
        return nearest_group_index

    def _distance_to_group(self, point, group):
        min_distance = sys.float_info.max
        for pt in group:
            dist = ms_utils.euclidean_dist(point, pt)
            if dist < min_distance:
                min_distance = dist
        return min_distance

参考链接:

http://vision.stanford.edu/teaching/cs131_fall1617/lectures/lecture13_kmeans_mean_shift_cs131_2016 https://spin.atomicobject.com/2015/05/26/mean-shift-clustering/ https://github.com/mattnedrich/MeanShift_py

0 人点赞