Mean Shift is built on kernel density estimation (KDE). It assumes that the data points were sampled from an underlying probability distribution. KDE is a nonparametric method for estimating that distribution from the data points.
1. Kernel Density Estimation
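Given n data points $x_i$, $i = 1, \dots, n$, in $\mathbb{R}^d$, the multivariate kernel density estimate using a radially symmetric kernel $K(x)$ is:

$$\hat{f}_K(x) = \frac{1}{n h^d} \sum_{i=1}^{n} K\left(\frac{x - x_i}{h}\right)$$

Here $h$ is the bandwidth; different bandwidths lead to different clustering results. A radially symmetric kernel is defined as:

$$K(x) = c_k \, k\left(\|x\|^2\right)$$

where $c_k$ is a normalization constant and $k$ is the kernel's profile.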
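As shown in the figure above, the red points are the data points whose density we want to estimate. KDE works in two steps. First, it places a kernel on every point of the data set. A kernel is essentially a weighting function; there are many kinds, and the Gaussian kernel is a common choice. Second, it sums all the individual kernels to produce a probability surface. Different kernel bandwidths produce different density functions.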
surface plot using a Gaussian kernel with a kernel bandwidth of 2
contour plot of the surface using a Gaussian kernel with a kernel bandwidth of 2
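The procedure just described places one kernel on every data point and then sums them. The NumPy sketch below shows it under two assumptions: a Gaussian kernel and a single scalar bandwidth `h`. The names `kde_gaussian` and `_as_points` are our own illustrative helpers, not part of the original code.

```python
import numpy as np

def _as_points(a):
    # Treat a 1-D array as n scalar points, i.e. shape (n, 1)
    a = np.asarray(a, dtype=float)
    return a[:, None] if a.ndim == 1 else a

def kde_gaussian(query, data, h):
    """Evaluate f(x) = 1/(n h^d) * sum_i K((x - x_i) / h) at every query point,
    assuming the standard Gaussian kernel K(u) = (2*pi)^(-d/2) * exp(-||u||^2 / 2)."""
    query, data = _as_points(query), _as_points(data)
    n, d = data.shape
    # Scaled differences between every query point and every data point: (q, n, d)
    u = (query[:, None, :] - data[None, :, :]) / h
    kernel_vals = (2 * np.pi) ** (-d / 2) * np.exp(-0.5 * np.sum(u ** 2, axis=2))
    # One kernel per data point, summed and normalized
    return kernel_vals.sum(axis=1) / (n * h ** d)

# Usage: density of three 1-D points on a grid, for two bandwidths
points = np.array([-1.0, 0.5, 2.0])
grid = np.linspace(-5, 5, 11)
print(kde_gaussian(grid, points, h=0.5))  # narrow kernels: bumpy surface
print(kde_gaussian(grid, points, h=2.0))  # wide kernels: smooth surface
```

Because every kernel integrates to 1, the result integrates to 1 for any bandwidth. Only its shape changes.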
2. Mean Shift
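The idea behind Mean Shift is simple. Each point is iteratively moved uphill on the KDE surface until it reaches a peak. Points that end at the same peak form one cluster.
The kernel bandwidth determines the shape of the KDE surface and therefore the final clustering. A small bandwidth produces a surface with many scattered peaks, and thus many clusters. A large bandwidth produces a wide, smooth surface on which the points converge to fewer peaks (in the extreme, a single one), giving fewer and larger clusters.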
With bandwidth = 2, the KDE surface has three peaks, producing three clusters.
With bandwidth = 0.8, the KDE surface has many peaks, producing many clusters.
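The bandwidth effect can also be checked numerically. This small 1-D sketch counts the local maxima of a Gaussian KDE sampled on a grid. The data are made up for illustration: three tight groups around 0, 10 and 20. The helpers `kde_1d` and `count_peaks` are hypothetical, not from the original code. Counting maxima on a grid only approximates the number of modes, and peaks narrower than the grid spacing could be missed.

```python
import numpy as np

def kde_1d(grid, data, h):
    # 1-D Gaussian KDE evaluated at every grid point
    diffs = (grid[:, None] - data[None, :]) / h
    return np.exp(-0.5 * diffs ** 2).sum(axis=1) / (len(data) * h * np.sqrt(2 * np.pi))

def count_peaks(values):
    # Number of strict interior local maxima of a sampled curve
    inner = values[1:-1]
    return int(np.sum((inner > values[:-2]) & (inner > values[2:])))

# Illustrative data: three tight groups around 0, 10 and 20
data = np.array([-0.5, 0.0, 0.5, 9.5, 10.0, 10.5, 19.5, 20.0, 20.5])
grid = np.linspace(-10.0, 30.0, 401)
for h in (1.0, 20.0):
    print(h, count_peaks(kde_1d(grid, data, h)))
# expected: "1.0 3" (three peaks) and "20.0 1" (one peak)
```

With a small bandwidth, each group keeps its own peak. With a very large one, the surface collapses to a single peak, which is the extreme case described above.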
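What counts as a reasonable cluster depends on the problem domain and the application; there is no universal standard. Mean Shift has a single control parameter, the kernel bandwidth, which can easily be tuned to suit different applications.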
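2.1 Mean Shift for a Single Point
Step 1: For a given point $x$, compute the mean shift vector:

$$m(x) = \frac{\sum_{i=1}^{n} x_i \, g\left(\left\|\frac{x - x_i}{h}\right\|^2\right)}{\sum_{i=1}^{n} g\left(\left\|\frac{x - x_i}{h}\right\|^2\right)} - x$$

where $g(x) = -k'(x)$.

Step 2: Translate the density estimation window:

$$x^{t+1} = x^t + m(x^t)$$

Step 3: Iterate Steps 1 and 2 until convergence.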
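The three steps above can be sketched for one point as follows. This assumes a Gaussian kernel. Its profile $k(u) = e^{-u/2}$ gives $g(u) = \tfrac{1}{2}e^{-u/2}$, and the $\tfrac{1}{2}$ cancels in the ratio. `mean_shift_point`, `tol` and `max_iter` are illustrative names and defaults of ours, not from the original text.

```python
import numpy as np

def mean_shift_point(x, data, h, tol=1e-6, max_iter=500):
    """Single-point mean shift with a Gaussian kernel (a sketch).

    For the Gaussian profile k(u) = exp(-u / 2), g(u) = -k'(u) = exp(-u / 2) / 2;
    the constant 1/2 cancels in the weighted mean, so each data point x_i is
    weighted by exp(-||(x - x_i) / h||^2 / 2).
    """
    data = np.asarray(data, dtype=float)
    x = np.asarray(x, dtype=float)
    for _ in range(max_iter):
        weights = np.exp(-0.5 * np.sum(((x - data) / h) ** 2, axis=1))
        total = weights.sum()
        if total == 0.0:
            raise ValueError("all weights underflowed; start closer to the data or increase h")
        # Step 1: mean shift vector m(x) = weighted mean of the data - x
        m = (weights[:, None] * data).sum(axis=0) / total - x
        # Step 2: translate the window to the weighted mean
        x = x + m
        # Step 3: stop once the shift is negligible
        if np.linalg.norm(m) < tol:
            break
    return x

data = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1],
                 [5.0, 5.0], [5.1, 5.0], [5.0, 5.1]])
print(mean_shift_point([0.2, 0.2], data, h=1.0))  # ends near (0.033, 0.033)
print(mean_shift_point([4.8, 4.9], data, h=1.0))  # ends near (5.033, 5.033)
```

Note that $x + m(x)$ is simply the kernel-weighted mean of the data. The update therefore moves the window straight to that mean rather than taking a step of tunable size.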
The execution of Mean Shift is illustrated below:
Execution of the Mean Shift algorithm
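2.2 Speeding Up Mean Shift
Mean Shift is computationally expensive. Every shift of every point sums over all n points, so one iteration over the whole set costs O(n²) kernel evaluations. On large point sets this becomes painfully slow. Two speedups help:
1. As shown by the blue points in the figure below, assign every point within radius r of an end point to the same cluster as that end point.
2. Assign every point within distance r/c of the Mean Shift path, where c is a constant, to the same cluster as the end point.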
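The two speedups can be combined into a clustering loop, sketched below. The loop only starts a new mean shift run from a point that no earlier run has labeled yet. Several things here are our own assumptions: the Gaussian kernel, the default `c=3.0`, and the `mode_tol` rule for reusing an existing mode's label. The name `mean_shift_with_speedups` is hypothetical.

```python
import numpy as np

def mean_shift_with_speedups(data, h, r, c=3.0, mode_tol=None, tol=1e-5, max_iter=500):
    """Mean shift clustering (Gaussian kernel) using the two speedups:
    1. unlabeled points within r of a run's end point take that end point's label;
    2. unlabeled points within r / c of any position on the run's path do too.
    Only still-unlabeled points start new runs. End points closer than mode_tol
    (default: r) to a known mode reuse its label. Returns (labels, modes)."""
    data = np.asarray(data, dtype=float)
    if mode_tol is None:
        mode_tol = r
    labels = np.full(len(data), -1)
    modes = []
    for start in range(len(data)):
        if labels[start] != -1:
            continue  # already claimed by an earlier run: skip its mean shift
        x = data[start].copy()
        path = [x]
        for _ in range(max_iter):
            w = np.exp(-0.5 * np.sum(((x - data) / h) ** 2, axis=1))
            new_x = (w[:, None] * data).sum(axis=0) / w.sum()
            path.append(new_x)
            moved = np.linalg.norm(new_x - x)
            x = new_x
            if moved < tol:
                break
        # Reuse the label of a known mode if this end point is close to it
        label = None
        for k, mode in enumerate(modes):
            if np.linalg.norm(x - mode) < mode_tol:
                label = k
                break
        if label is None:
            modes.append(x)
            label = len(modes) - 1
        # Speedup 1: points near the end point
        near_end = np.linalg.norm(data - x, axis=1) <= r
        # Speedup 2: points near any position on the path (includes the start point)
        path_arr = np.array(path)
        dist_path = np.linalg.norm(data[:, None, :] - path_arr[None, :, :], axis=2).min(axis=1)
        near_path = dist_path <= r / c
        labels[(labels == -1) & (near_end | near_path)] = label
    return labels, np.array(modes)

cluster_a = np.array([[0.0, 0.0], [0.2, 0.0], [0.0, 0.2], [0.2, 0.2], [0.1, 0.1]])
data = np.vstack([cluster_a, cluster_a + 10.0])
labels, modes = mean_shift_with_speedups(data, h=1.0, r=0.5)
print(labels)  # [0 0 0 0 0 1 1 1 1 1]
```

Here only two mean shift runs are executed for ten points. The other eight are labeled by the radius rules. The trade-off is that a poorly chosen r can pull a point into the wrong basin.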
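2.3 Merging Clusters
After every point has completed its mean shift, clusters can be merged: if the distance between two cluster centers is below a threshold, the two clusters are merged into one.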
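The sketch below implements this merge rule with a small union-find, so merging is transitive: if A is close to B and B is close to C, all three end up together. Both choices are assumptions on our part: transitivity, and placing the merged center at the plain mean of the old centers. `merge_clusters` is a hypothetical helper name.

```python
import numpy as np

def merge_clusters(centers, labels, threshold):
    """Merge clusters whose centers lie closer than `threshold`.
    Returns (merged_centers, new_labels) with labels renumbered 0..m-1."""
    centers = np.asarray(centers, dtype=float)
    labels = np.asarray(labels)
    k = len(centers)
    parent = list(range(k))

    def find(i):
        # Follow parent links to the group root (with path halving)
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    for i in range(k):
        for j in range(i + 1, k):
            if np.linalg.norm(centers[i] - centers[j]) < threshold:
                parent[find(i)] = find(j)

    roots = [find(i) for i in range(k)]
    # Renumber roots 0..m-1 in order of first appearance
    root_to_new = {}
    for root in roots:
        if root not in root_to_new:
            root_to_new[root] = len(root_to_new)
    old_to_new = np.array([root_to_new[root] for root in roots])
    new_labels = old_to_new[labels]
    # Merged center = mean of the old centers in the group (weighting by
    # cluster size would be another reasonable choice)
    merged = np.array([centers[old_to_new == m].mean(axis=0) for m in range(len(root_to_new))])
    return merged, new_labels

centers = np.array([[0.0, 0.0], [0.05, 0.0], [5.0, 5.0]])
labels = np.array([0, 1, 2, 2, 1, 0])
merged, new_labels = merge_clusters(centers, labels, threshold=0.1)
print(new_labels)  # [0 0 1 1 0 0]
```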
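3. Applying Mean Shift to Image Segmentation
Image segmentation is a good application of Mean Shift. Its goal is to partition an image into semantically meaningful regions, which can be achieved by clustering the image's pixels.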
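Step 1: Represent the image as points in a feature space. A simple approach uses each pixel's red, green and blue values to map it to a point in 3-D RGB space (as shown in the figure below).
Step 2: Run Mean Shift on the resulting point set. The animation below shows the points converging as the algorithm runs (Gaussian kernel, bandwidth = 25).
Step 3: The result after all points have converged is shown below: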
Segmentation results on other images
An advanced, general-purpose segmentation technique based on clustering
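The three segmentation steps above can be sketched end to end as follows. This is a brute-force, self-contained sketch that assumes a Gaussian kernel and shifts all pixels at once. `segment_image_rgb` and its `group_tol` parameter (default: bandwidth / 2) are our own assumptions. Because the sketch builds an N×N weight matrix for N pixels, it is only practical for small or downsampled images.

```python
import numpy as np

def segment_image_rgb(image, bandwidth, tol=1e-3, max_iter=100, group_tol=None):
    """Mean shift segmentation in RGB space with a Gaussian kernel.
    image: (H, W, 3) array. Returns (labels of shape (H, W), segmented image)."""
    h, w, _ = image.shape
    points = image.reshape(-1, 3).astype(float)  # Step 1: each pixel -> RGB point
    shifted = points.copy()
    for _ in range(max_iter):  # Step 2: shift every point toward its weighted mean
        sq = (((shifted[:, None, :] - points[None, :, :]) / bandwidth) ** 2).sum(axis=2)
        wts = np.exp(-0.5 * sq)
        new = wts @ points / wts.sum(axis=1, keepdims=True)
        moved = np.linalg.norm(new - shifted, axis=1).max()
        shifted = new
        if moved < tol:
            break
    # Step 3: pixels whose converged points lie within group_tol share a label
    if group_tol is None:
        group_tol = bandwidth / 2
    modes, labels = [], np.empty(len(points), dtype=int)
    for i, p in enumerate(shifted):
        for k, m in enumerate(modes):
            if np.linalg.norm(p - m) < group_tol:
                labels[i] = k
                break
        else:
            modes.append(p)
            labels[i] = len(modes) - 1
    # Paint every pixel with its cluster's mode color
    segmented = np.array(modes)[labels].reshape(h, w, 3)
    return labels.reshape(h, w), segmented

img = np.zeros((4, 4, 3))
img[:, :2] = [250, 10, 10]    # left half: red
img[::2, :2] = [240, 20, 10]  # rows 0 and 2 of the left half: a slightly different red
img[:, 2:] = [10, 10, 250]    # right half: blue
labels, seg = segment_image_rgb(img, bandwidth=25)
print(labels)  # left two columns -> 0, right two columns -> 1
```

The two shades of red lie about 14 apart in RGB space, which is well within the bandwidth. They therefore converge to one mode, roughly (245, 15, 10), and are painted with that color.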
4. Python Implementation
Euclidean distance:
```python
import math

def euclidean_dist(pointA, pointB):
    # Both points must have the same number of dimensions
    if len(pointA) != len(pointB):
        raise ValueError("expected point dimensionality to match")
    total = float(0)
    for dimension in range(0, len(pointA)):
        # Accumulate the squared difference along each dimension
        total += (pointA[dimension] - pointB[dimension]) ** 2
    return math.sqrt(total)
```
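Gaussian kernel function:
```python
import math
import numpy as np

def gaussian_kernel(distance, bandwidth):
    # `distance` holds difference vectors (point - points), shape (n, d);
    # reduce each row to its Euclidean length
    euclidean_distance = np.sqrt((distance ** 2).sum(axis=1))
    # 1-D Gaussian density at each length; the normalization constant cancels
    # in the mean shift update, so only the relative weights matter
    val = (1 / (bandwidth * math.sqrt(2 * math.pi))) * np.exp(-0.5 * (euclidean_distance / bandwidth) ** 2)
    return val
```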
Multivariate Gaussian kernel function:
```python
import math
import numpy as np

def multivariate_gaussian_kernel(distances, bandwidths):
    # Number of dimensions of the multivariate Gaussian
    dim = len(bandwidths)
    # Diagonal covariance matrix with one bandwidth per dimension
    cov = np.multiply(np.power(bandwidths, 2), np.eye(dim))
    # Exponent -0.5 * d^T Sigma^{-1} d for every row d of `distances` (vectorized)
    exponent = -0.5 * np.sum(np.multiply(np.dot(distances, np.linalg.inv(cov)), distances), axis=1)
    # Normalization 1 / ((2*pi)^(d/2) * |Sigma|^(1/2)): the whole product is the denominator
    val = (1 / (np.power(2 * math.pi, dim / 2) * np.power(np.linalg.det(cov), 0.5))) * np.exp(exponent)
    return val
```
Shifting a single point:
```python
import numpy as np

# Method of the mean shift class; self.kernel is one of the kernel functions above
def _shift_point(self, point, points, kernel_bandwidth):
    # from http://en.wikipedia.org/wiki/Mean-shift
    points = np.array(points)
    # numerator: kernel weight of every data point relative to `point`
    point_weights = self.kernel(point - points, kernel_bandwidth)
    tiled_weights = np.tile(point_weights, [len(point), 1])
    # denominator: sum of all weights
    denominator = sum(point_weights)
    # the weighted mean of the data points is the new position of `point`
    shifted_point = np.multiply(tiled_weights.transpose(), points).sum(axis=0) / denominator
    return shifted_point
    # ***************************************************************************
    # ** The above vectorized code is equivalent to the unrolled version below **
    # ** (2-D case, with a kernel that accepts a scalar distance)              **
    # ***************************************************************************
    # shift_x = float(0)
    # shift_y = float(0)
    # scale_factor = float(0)
    # for p_temp in points:
    #     # numerator
    #     dist = ms_utils.euclidean_dist(point, p_temp)
    #     weight = self.kernel(dist, kernel_bandwidth)
    #     shift_x += p_temp[0] * weight
    #     shift_y += p_temp[1] * weight
    #     # denominator
    #     scale_factor += weight
    # shift_x = shift_x / scale_factor
    # shift_y = shift_y / scale_factor
    # return [shift_x, shift_y]
```
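The clustering loop:
```python
import numpy as np

# Method of the mean shift class. ms_utils (euclidean_dist), pg (PointGrouper),
# MeanShiftResult and the convergence threshold MIN_DISTANCE are defined
# elsewhere in the module (see the MeanShift_py repository).
def cluster(self, points, kernel_bandwidth, iteration_callback=None):
    points = np.asarray(points)
    if iteration_callback:
        iteration_callback(points, 0)
    # Float working copy that gets shifted (avoids truncation for integer input)
    shift_points = np.array(points, dtype=float)
    max_min_dist = 1
    iteration_number = 0
    still_shifting = [True] * points.shape[0]
    # Keep iterating until no point moves more than MIN_DISTANCE
    while max_min_dist > MIN_DISTANCE:
        max_min_dist = 0
        iteration_number += 1
        for i in range(0, len(shift_points)):
            if not still_shifting[i]:
                continue  # this point has already converged
            p_new = shift_points[i]
            p_new_start = p_new
            # Shift relative to the ORIGINAL points, not the shifted ones
            p_new = self._shift_point(p_new, points, kernel_bandwidth)
            dist = ms_utils.euclidean_dist(p_new, p_new_start)
            if dist > max_min_dist:
                max_min_dist = dist
            if dist < MIN_DISTANCE:
                still_shifting[i] = False
            shift_points[i] = p_new
        if iteration_callback:
            iteration_callback(shift_points, iteration_number)
    # Group points whose converged positions coincide
    point_grouper = pg.PointGrouper()
    group_assignments = point_grouper.group_points(shift_points.tolist())
    return MeanShiftResult(points, shift_points, group_assignments)
```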
Grouping the converged points:
```python
import sys
import numpy as np

# ms_utils.euclidean_dist is the Euclidean distance function shown earlier;
# GROUP_DISTANCE_TOLERANCE (the distance under which a point joins an existing
# group) is a module-level constant defined elsewhere in the module.

class PointGrouper(object):
    def group_points(self, points):
        group_assignment = []
        groups = []
        group_index = 0
        for point in points:
            nearest_group_index = self._determine_nearest_group(point, groups)
            if nearest_group_index is None:
                # No group is close enough: create a new group
                groups.append([point])
                group_assignment.append(group_index)
                group_index += 1
            else:
                group_assignment.append(nearest_group_index)
                groups[nearest_group_index].append(point)
        return np.array(group_assignment)

    def _determine_nearest_group(self, point, groups):
        # Index of a group within GROUP_DISTANCE_TOLERANCE (the last such group
        # if several qualify), or None if there is none
        nearest_group_index = None
        index = 0
        for group in groups:
            distance_to_group = self._distance_to_group(point, group)
            if distance_to_group < GROUP_DISTANCE_TOLERANCE:
                nearest_group_index = index
            index += 1
        return nearest_group_index

    def _distance_to_group(self, point, group):
        # Distance from `point` to the closest member of `group`
        min_distance = sys.float_info.max
        for pt in group:
            dist = ms_utils.euclidean_dist(point, pt)
            if dist < min_distance:
                min_distance = dist
        return min_distance
```
References:
http://vision.stanford.edu/teaching/cs131_fall1617/lectures/lecture13_kmeans_mean_shift_cs131_2016
https://spin.atomicobject.com/2015/05/26/mean-shift-clustering/
https://github.com/mattnedrich/MeanShift_py