作者 | 梁唐
大家好,我是梁唐。
今天我们继续来聊线性回归,我们上一期聊了线性回归的背景,今天来聊聊它的求解。
原创技术文章,恳求大家多多支持,给我一点继续更新的动力。
线性回归怎么用?
上一期我们虽然聊了线性回归的背景,但却没有说它怎么使用。虽然我们学习的是模型的原理,但不了解使用场景有的时候会让理论的学习变得很困难。所以有必要花一点篇幅先来简单说明一下线性回归的使用场景。
我们之前的文章说了,线性回归本质上是一个线性方程组:
如果我们写成矩阵相乘的形式则是:
那么,什么情况下我们需要用到这个方程呢?
很简单,当X很容易求,而Y不容易求的时候,这里的不容易求往往指的是后验。也就是说无法事先知道,举个例子,比如说波士顿房价预测问题。一个房屋的售卖价格是未知,显然要是知道那就不用预测了。但是对于像是房间数量,房屋面积,距离地铁站的距离这些变量是可以通过测量、计算等方法得到的。
我们要做的其实是根据参数矩阵W,以及我们采集到的特征信息X,让模型预测后验结果Y。这句话看不懂也没关系,简单理解,就是有了W之后,我们就可以根据X来预测Y。
所以问题的核心就是W,有了W就可以预测Y,而且要想办法找到最好的W,使得预测出来的Y最接近。
最小二乘法
想要有W,那么怎么才能得到W呢?
首先我们要先采集一批历史数据当做参照,所谓的历史数据,也就是包含了X和Y的完整数据。
X矩阵包含了所有的自变量,被称为特征矩阵。每一个变量被称为一个特征(feature)。Y是结果矩阵,被称为label。我们要做的就是根据已有的X和Y,想办法求出W。
前文说过,W矩阵是一个1 x n的矩阵。在n很小的时候,我们当然可以人工来算。实际上早年在机器学习诞生之前,很多场景的确都是人工来算的。当时的数学分析师,很大一部分工作就是来手工算各个模型的参数。计算量巨大,巨枯燥巨无聊……
多说一句,凡是和数据分析扯上关系的岗位往往都是这样,工作量很大,很枯燥很无聊……大家千万不要觉得数据分析师和大数据打交道,很高大上,这些大部分情况是错觉。
当n很大的时候,显然就不能人工来计算了。我们必须要通过程序来算,想要通过程序来算,就得先设计一个算法。好在这个算法并不难设计,因为数学界的大佬已经替我们想好了,这个算法就是最小二乘法。
最小二乘法当初是天文学家拿来计算小行星轨道的,最小二乘法这个名字不太好理解,其实翻译成最小平方法更好理解一些。本质上就是一个衡量模型误差的思想。
假设我们现在模型的预测结果是hat{y} ,搜集到的label矩阵是y 。显然由于种种误差的存在,这两个结果不可能完全相等。但问题来了,不相等可以,那么怎么衡量结果好坏呢?比较容易想到可以做差,如果得到的差值越小,那么就说明模型效果越好,也就是对应的W越好。
我们把差值写成公式就是:
这里的m是样本的数量,模型的整体误差e就是每一个样本的差值之和。但是这种做法有一个小bug,就是绝对值的计算非常麻烦,主要是不方便求导。至于为什么要求导,我们后面再说。
绝对值很麻烦怎么办呢?很简单,我们可以平方求和:
因为我们并不需要误差e,只不过拿它当做衡量模型效果的指标,所以求平方并不影响。所以我们要做的就是求出误差e最小时对应的参数矩阵W。
那怎么根据这个误差的式子求对应的W呢?数学好一点的同学估计已经猜到了,没错,就是求导。这虽然不是一个二次方程,但也可以当做二次方程一样,进行求导求极值。
求导之前,我们先把公式写全:
,接着我们先对这个公式做一个简单的变形:我们想办法把b处理掉,让式子尽可能简洁。
首先,我们在X当中增加一列1,也就是将X变成m * (n 1)的矩阵,它的第一列是常数1,新的矩阵写成
同样,我们在
中也增加一列,它的第一列写成b,我们将新的矩阵写成
,我们可以得到:
之后,我们把
和
代入原式得到:
这里的m是样本的数量,是一个常数,我们除以这个系数并不会影响
的取值。这个
就是我们常说的模型的损失函数。
这里的损失其实是误差的意思,损失函数的值越小,说明模型的误差越小,和真实结果越接近。为了方便书写,我们把
还是写成
。
我们计算
对
的导数:
我们令导数等于0,由于m是常数,可以消掉,得到:
上面的推导过程初看可能觉得复杂,但实际上并不困难。只是一个简单的求偏导的推导,我们就可以写出最优的
的取值。
从这个公式来看并不难计算,实际上是否真的是这么简单呢?我们试着用代码来实验一下。
代码实验
为了简单起见,我们针对最简单的场景:样本只有一个特征,我们首先先试着自己生产一批数据:
代码语言:javascript复制import numpy as np
X = 2 * np.random.rand(100, 1)
y = 4 3 * X np.random.randn(100, 1)
import matplotlib.pyplot as plt
plt.scatter(X, y)
我们先生成一百个0~2范围内的x,然后
,为了模拟真实存在误差的场景,我们再对y加上一个0~1范围内的误差。
我们把上面的点通过plt画出来可以得到这样一张图:
接着,我们就可以通过上面的公式直接计算出theta 的取值了:
代码语言:javascript复制def get_theta(x, y):
m = len(y)
# x中新增一列常数1
x = np.c_[np.ones(m).T, x]
# 通过公式计算theta
theta = np.dot(np.dot(np.linalg.inv(np.dot(x.T, x)), x.T), y)
return theta
我们传入x和y得到theta,打印出来看,会发现和我们设置的非常接近:
最后,我们把模型拟合的结果和真实样本的分布都画在一张图上:
代码语言:javascript复制# 我们画出模型x在0到2区间内的值
X_new = np.array([[0],[2]])
# 新增一列常数1的结果
X_new_b = np.c_[np.ones((2, 1)), X_new]
# 预测的端点值
y_predict = X_new_b.dot(theta)
# 画出模型拟合的结果
plt.plot(X_new, y_predict,"r-")
# 画出原来的样本
plt.scatter(X,y)
plt.show()
得到的结果如下:
从结果上来看模型的效果非常棒,和我们的预期非常吻合,并且实现的代码实在是太简单了,只有短短几行。
但先别高兴得太早,有一点必须说清楚,虽然上面的例子非常完美,但是实际场景当中,大家并不会这么干。而是会采用其他的办法来求解W。
这个就很奇怪了,明明三行代码可以求出结果,为什么非要用其他办法绕个弯子算呢?
原因其实很简单,如果你线性代数还没还给老师的话,应该很容易就能发现。
首先是我们计算Theta 的公式当中用到了逆矩阵的操作。线性代数当中说过,只有满秩矩阵才有逆矩阵。如果
是奇异矩阵,那么它是没有逆矩阵的,自然这个公式也用不了了。
当然这个问题并不是不能解决的,X^T cdot X 是奇异矩阵的条件是矩阵X 当中存在一行或者一列全为0。我们通过特征预处理,是可以避免这样的事情发生的。所以样本无法计算逆矩阵只是原因之一,并不是最关键的问题。
最关键的问题是复杂度,虽然我们看起来上面核心的代码只有一行,但实际上由于我们用到了逆矩阵的计算,它背后的开销非常大。
X^T cdot X 的结果是一个n * n的矩阵,这里的n是特征的维度。
这样一个矩阵计算逆矩阵的复杂度大概在n^{2.4} 到n^3 之间。当n很小的时候当然没有关系,如果n很大,则求解起来非常耗时。在现实场景当中,我们的n往往动辄好几千,甚至好几万、好几十万。显然在这么大的量级下,想要求解逆矩阵非常非常困难,甚至是几乎不可能的。
正是因为以上这些原因,所以通常我们并不会使用直接通过公式计算的方法来求模型的参数。
那么,既然直接计算开销太大,我们又该如何求解呢?
别着急,我们在下篇文章为大家揭晓。