注:本文有助于理解 SVM 和核函数的含义,更多关于机器学习的内容,请参阅:
http://math.itdiffer.com/machinelearning.html,或点击【阅读原文】查阅。
理解 SVM 的核函数的实际作用
在 SVM 中引入核函数,用它处理非线性数据,即:将数据映射到高维空间中,使数据在其中变为线性的,然后应用一个简单的线性 SVM。听起来很复杂,在某种程度上确实如此。然而,尽管理解核函数的工作原理可能很困难,但它所要实现的目标很容易把握。
线性 SVM
先简要说明一下 SVM 的一般工作原理。我们可以将 SVM 用于分类和回归任务,但在本文中,将重点关注分类。首先考虑线性可分的二分类数据,按照下面的方式创建这个数据集:
引入所有相关的模块。
代码语言:javascript复制from sklearn.datasets import make_blobs, make_moons
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler, PolynomialFeatures
from sklearn.svm import LinearSVC, SVC
import matplotlib.pyplot as plt
%matplotlib inline
from mlxtend.plotting import plot_decision_regions
import numpy as np
创建有一个具有两个类别的数据集。
代码语言:javascript复制X, y = make_blobs(n_samples=100, centers=2, n_features=2, random_state=42)
使用 scikit-learn 拟合一个线性 SVM。注意,我们在训练模型之前对数据进行了归一化处理,因为 SVM 对特征的尺度非常敏感。
代码语言:javascript复制pipe = make_pipeline(StandardScaler(), LinearSVC(C=1, loss="hinge"))
pipe.fit(X, y)
编写一个实现数据可视化的函数。
代码语言:javascript复制def plot_svm(clf, X):
decision_function = pipe.decision_function(X)
support_vector_indices = np.where((2 * y - 1) * decision_function <= 1)[0]
support_vectors = X[support_vector_indices]
plt.figure(figsize=(8, 8))
plot_decision_regions(X, y, clf=pipe, legend=0,
colors="skyblue,xkcd:goldenrod")
plt.scatter(support_vectors[:, 0], support_vectors[:, 1], s=200,
linewidth=1, facecolors='none', edgecolors='r')
ax = plt.gca()
xlim = ax.get_xlim()
ylim = ax.get_ylim()
xx, yy = np.meshgrid(np.linspace(xlim[0], xlim[1], 50),
np.linspace(ylim[0], ylim[1], 50))
Z = pipe.decision_function(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
plt.contour(xx, yy, Z, colors='k', levels=[-1, 0, 1], alpha=0.5,
linestyles=['--', '-', '--'])
用自定义的函数 plot_svm()
对前述模型 pip
和数据集 X
绘图,并输出下图所示图像。
plot_svm(pipe, X)
输出图像:
由图可知,在两个类别的数据之间,可以无穷多条线,将二者分开。而 SVM 拟合一条很特别的线,它是图中所示的用虚线标记的无数据点分布的“走廊”的中线——这个“走廊”称之为间隔,并且两个类之间的间隔要尽可能宽。间隔中间的实线也就距离两个类别的数据尽可能远。这样训练的模型,才能很好地推广到新的数据中。
上图中用红色圆圈标记出的样本点,在间隔边界上,这个数据点所对应的向量,称为支持向量,因为它们支持或决定了间隔的位置。即使我们在间隔外增加一些新的样本点,它也不会改变此间隔的位置。
注意:这是一个硬边缘分类的例子,它意味着:不允许任何样本进入该间隔。此外,我们可以做一个软边缘分类:允许一些样本点进入该间隔,但不要太多,同时使间隔更宽。这样做可以更有效地对付异常值,并且可以通过 LinearSVC()
中的参数 C
来控制。
上面举例中的数据,其实比较理想化,真实的数据一般不是那么泾渭分明的,例如:
代码语言:javascript复制X, y = make_moons(n_samples=100, noise=0.1, random_state=42)
pipe = make_pipeline(StandardScaler(), LinearSVC(C=1, loss="hinge"))
pipe.fit(X, y)
plot_svm(pipe, X)
再用之前的方法,得到的结果就很不好了。怎么改进?
映射到更高维度
在讨论核及其作用之前,先了解一种强大思想观点:在高维空间中,数据更有可能线性可分。
代码语言:javascript复制x1 = np.array([-3, -2, -1, 0, 1, 2, 3])
x2 = x1 ** 2
aux = np.zeros(shape=x1.shape)
y = np.array([0, 0, 1, 1, 1 ,0, 0])
plt.figure(figsize = (12, 6))
plt.subplot(1, 2, 1)
plt.scatter(x1[y == 1], aux[y == 1], c=['xkcd:lightish blue'],
edgecolor="black", s=250)
plt.scatter(x1[y == 0], aux[y == 0], c=['xkcd:terra cotta'],
edgecolor="black", s=250)
plt.axis("equal")
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
plt.yticks([])
plt.xlabel("x1", fontsize=15)
plt.title("One feature: data linearly unseparable", fontsize=15)
plt.subplot(1, 2, 2)
plt.scatter(x1[y == 1], x2[y == 1], c=['xkcd:lightish blue'],
edgecolor="black", s=250)
plt.scatter(x1[y == 0], x2[y == 0], c=['xkcd:terra cotta'],
edgecolor="black", s=250)
plt.axis("equal")
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
plt.xlabel("x1", fontsize=15)
plt.ylabel("x2 = x1^2", fontsize=15)
plt.plot([-4, 4], [2.5, 2.5], linestyle='--', lw=3, color='black')
plt.title("Two features: data linearly separable", fontsize=15)
plt.tight_layout(7)
plt.show()
输出:
如上图中所示,当只有一个特征 x1 时,这些点不能用一条直线分割开。再添加另一个特征 x2(x2 等于 x1 的平方),就可以很容易地将这两类数据分开了。
核到底是什么
核是一种向数据添加更多特征的巧妙方法,目的是使数据线性可分。其巧妙之处在于:它实际上并没有添加特征(这会使模型变慢),而是使用了一些神奇的数学属性(这超出了本文的讨论范围,在参考资料 [2] 中有相关数学知识的详细阐述)。这使我们获得与实际添加这些特性完全相同的结果,而又不降低模型的速度。
下面分别介绍两种流行的核:多项式核和高斯径向基函数核(RBF)。它们(假装)添加的特征类型不同。
多项式核
增加更多特征的一种方法是在一定程度上使用原有特征的多项式组合。例如,有两个特征 A 和 B,一个 2 次的多项式将产生 6 个特征: 1(指数为 0 的任何特征),A, B, A²,B²,和 AB。我们可以使用 scikit-learn 的PolynomialFeatures()
很容易地手动添加这些特征:
X, y = make_moons(n_samples=100, noise=0.1, random_state=42)
pipe = make_pipeline(StandardScaler(),
PolynomialFeatures(degree=3),
LinearSVC(C=5))
pipe.fit(X, y)
plot_svm(pipe, X)
输出结果:
也可以用下面的方式,更简便地使用多项式核。
代码语言:javascript复制X, y = make_moons(n_samples=100, noise=0.1, random_state=42)
pipe = make_pipeline(StandardScaler(), SVC(kernel="poly", degree=3, C=5, coef0=1))
pipe.fit(X, y)
plot_svm(pipe, X)
输出结果。
以上两种方法得到的结果类似。由此可知,使用核函数的好处在于,通过指定较高的指数值(上例中 degree=3
),提高了数据在高维空间中实现线性可分的可能性,且不降低模型的训练时间。
对于上面通过 make_moos()
创建的“月牙形”数据,从散点图可以清楚地看出,3 次的多项式就足以支持分类任务了。然而,对于更复杂的数据集,可能需要使用更高的指数。这就是核技巧威力之所在。
高斯RBF核
另一种用于增加数据特征的方法就是向其中增加相似性特征。相似性特征度量了现有特征的值与一个中心的距离。
例如:有一个数据集,这个数据集只有一个特征 x1。我们想要新增两个相似特征,就选择两个“中心”,比如,从这个单一特征中选择的两个参考值作为“中心”,分别是 -1 和 1 为例(如下图的左图,图中的 landmark 即所选择的“中心”)。然后,对于 x1 的每个值,计算它距离第一个中心的 -1 的距离。所有计算结果,就构成了新的相似性特征 x2。然后进行同样的操作,将 x1 的值与第二个中心 1 进行比较,得到新的特征 x3。现在我们甚至不需要最初的特征 x1 了!这两个新的相似性特征就能够将数据分离。
每一个样本点到中心的距离,一种常用的计算法方法是使用高斯径向基函数(RBF)定义:
(1)式中的
为数据集样本(观察值),
是一个参数,此处令
。以上图中左侧图为例,根据(1)式,计算
与
的距离:
这个值作为 x2 特征的值。
同样,计算
与
的距离,得:
这个值作为 x3 特征的值。
于是将一维特征 x1 中的值
,根据高斯 RBF 核,升到二维特征 x2 和 x3,对应的数值为
,将此数据在二维坐标系中用点表示出来(如上图中右侧的图示)。
用同样方法,将一维特征 x1 中的其他各点,都变化为二维特征的数据,最终得到上图中右侧图示结果。从图中我们可以直接观察到,到维度提升之后,各个数据点能够用线性方法给予分类了。
代码语言:javascript复制def gaussian_rbf (x, landmark, gamma):
return np.exp(-gamma * (x - landmark) ** 2)
x1 = np.array([-3, -2, -1, 0, 1, 2, 3])
landmarks = [-1, 1]
x2 = np.array([gaussian_rbf(x, landmarks[0], 0.3) for x in x1])
x3 = np.array([gaussian_rbf(x, landmarks[1], 0.3) for x in x1])
aux = np.zeros(shape=x1.shape)
y = np.array([0, 0, 1, 1, 1 ,0, 0])
plt.figure(figsize = (12, 6))
plt.subplot(1, 2, 1)
plt.scatter(x1[y == 1], aux[y == 1], c=['xkcd:lightish blue'],
edgecolor="black", s=250)
plt.scatter(x1[y == 0], aux[y == 0], c=['xkcd:terra cotta'],
edgecolor="black", s=250)
plt.plot([landmarks[0], landmarks[0]], [2.5, 0.5], linestyle='--', lw=3, color='gray')
plt.plot([landmarks[1], landmarks[1]], [2.5, 0.5], linestyle='--', lw=3, color='gray')
plt.annotate("1st landmark", (landmarks[0] - 1.2, 2.8), fontsize=12, color='gray')
plt.annotate("2nd landmark", (landmarks[1] - 0.8, 2.8), fontsize=12, color='gray')
plt.axis("equal")
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
plt.yticks([])
plt.xlabel("x1", fontsize=15)
plt.title("Original single feature:ndata linearly unseparable", fontsize=15)
plt.subplot(1, 2, 2)
plt.scatter(x2[y == 1], x3[y == 1], c=['xkcd:lightish blue'],
edgecolor="black", s=250)
plt.scatter(x2[y == 0], x3[y == 0], c=['xkcd:terra cotta'],
edgecolor="black", s=250)
plt.axis("equal")
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
plt.xlabel("x2 = distance from 1st landmark", fontsize=15)
plt.ylabel("x3 = distance from 2nd landmark", fontsize=15)
plt.plot([0, 1], [1, 0], linestyle='--', lw=3, color='black')
plt.title("Two similarity features:ndata linearly separable", fontsize=15)
plt.tight_layout(7)
plt.show()
输出结果:
在上面的例子中,所选择的一维特征中的两个参考值(“中心”,(1)式中的 landmark
),其实有运气的成分。在实践中,可能需要大量的这样的“中心”,从而得到许多新的相似性特征。但是这样的操作将大大降低 SVM的速度——除非我们借助核技巧!
类似于多项式核,RBF 核看起来好像是对原始特征的每个值上都要指定一个“中心”,但实际上不需要真的这样做。让我们用“月牙形”的数据来验证一下。
代码语言:javascript复制X, y = make_moons(n_samples=100, noise=0.1, random_state=42)
pipe = make_pipeline(StandardScaler(), SVC(kernel="rbf", gamma=0.3, C=5))
pipe.fit(X, y)
plot_svm(pipe, X)
输出结果:
从图中可以看出,决策边界看起来相当不错,但是,注意一些分类错误的样本。我们可以通过调整
参数来解决问题。
参数可以充当正则项——参数越小,决策边界越平滑,但要防止过拟合。上面的情况下,实际上是欠拟合,所以,要令
。
代码语言:javascript复制X, y = make_moons(n_samples=100, noise=0.1, random_state=42)
pipe = make_pipeline(StandardScaler(), SVC(kernel="rbf", gamma=0.5, C=5))
pipe.fit(X, y)
plot_svm(pipe, X)
plt.show()
输出结果:
现在,就实现了对“月牙形”数据的正确分类。
总结
- SVM通过寻找与数据尽可能远的线性决策边界来进行分类。SVM在处理线性可分数据时很有效,但在其他方面效果极差。
- 为了使非线性数据变得线性可分(从而便于使用 SVM),我们可以向数据中添加更多的特征,因为在高维空间中,数据线性可分的概率增加了。
- 对现有特征的多项式组合,即多项式特征,以及通过样本和参考值距离所得到的相似特征,是数据集中常用的所增加的新特征。
- 如果增加了太多的特征,可能会减慢模型的速度。
- 核技巧是一种明智的策略,它利用了一些数学特性,以便得到相同的结果,貌似添加了额外的特征,但执行速度并没有减慢。
- 多项式和 RBF 核(假装)分别添加了多项式和相似性特征。
参考资料
[1] Michał Oleszak. SVM Kernels: What Do They Actually Do?[DB/OL]. https://towardsdatascience.com/svm-kernels-what-do-they-actually-do-56ce36f4f7b8 ,2022.10.17.
[2] 齐伟. 机器学习数学基础[M]. 电子工业出版社, 2022.