三行代码求出线性回归,但为什么大家不这么用呢?

2022-08-26 18:20:39 浏览数 (1)

作者 | 梁唐

大家好,我是梁唐。

今天我们继续来聊线性回归,我们上一期聊了线性回归的背景,今天来聊聊它的求解。

原创技术文章,恳求大家多多支持,给我一点继续更新的动力。

线性回归怎么用?

上一期我们虽然聊了线性回归的背景,但却没有说它怎么使用。虽然我们学习的是模型的原理,但不了解使用场景有的时候会让理论的学习变得很困难。所以有必要花一点篇幅先来简单说明一下线性回归的使用场景。

我们之前的文章说了,线性回归本质上是一个线性方程组:

begin{aligned} y_0 &= w_0x_{(0,0)} w_1x_{(0,1)} cdots w_nx_{(0,n)} b \ y_1 &= w_0x_{(1,0)} w_1x_{(1,1)} cdots w_nx_{(1,n)} b \ vdots \ y_m &= w_0x_{(m,0)} w_1x_{(m,1)} cdots w_nx_{(m,n)} b end{aligned}

如果我们写成矩阵相乘的形式则是:

Y=XW^T b

那么,什么情况下我们需要用到这个方程呢?

很简单,当X很容易求,而Y不容易求的时候,这里的不容易求往往指的是后验。也就是说无法事先知道,举个例子,比如说波士顿房价预测问题。一个房屋的售卖价格是未知,显然要是知道那就不用预测了。但是对于像是房间数量,房屋面积,距离地铁站的距离这些变量是可以通过测量、计算等方法得到的。

我们要做的其实是根据参数矩阵W,以及我们采集到的特征信息X,让模型预测后验结果Y。这句话看不懂也没关系,简单理解,就是有了W之后,我们就可以根据X来预测Y。

所以问题的核心就是W,有了W就可以预测Y,而且要想办法找到最好的W,使得预测出来的Y最接近。

最小二乘法

想要有W,那么怎么才能得到W呢?

首先我们要先采集一批历史数据当做参照,所谓的历史数据,也就是包含了X和Y的完整数据。

X矩阵包含了所有的自变量,被称为特征矩阵。每一个变量被称为一个特征(feature)。Y是结果矩阵,被称为label。我们要做的就是根据已有的X和Y,想办法求出W。

前文说过,W矩阵是一个1 x n的矩阵。在n很小的时候,我们当然可以人工来算。实际上早年在机器学习诞生之前,很多场景的确都是人工来算的。当时的数学分析师,很大一部分工作就是来手工算各个模型的参数。计算量巨大,巨枯燥巨无聊……

多说一句,凡是和数据分析扯上关系的岗位往往都是这样,工作量很大,很枯燥很无聊……大家千万不要觉得数据分析师和大数据打交道,很高大上,这些大部分情况是错觉。

当n很大的时候,显然就不能人工来计算了。我们必须要通过程序来算,想要通过程序来算,就得先设计一个算法。好在这个算法并不难设计,因为数学界的大佬已经替我们想好了,这个算法就是最小二乘法。

最小二乘法当初是天文学家拿来计算小行星轨道的,最小二乘法这个名字不太好理解,其实翻译成最小平方法更好理解一些。本质上就是一个衡量模型误差的思想。

假设我们现在模型的预测结果是hat{y} ,搜集到的label矩阵是y 。显然由于种种误差的存在,这两个结果不可能完全相等。但问题来了,不相等可以,那么怎么衡量结果好坏呢?比较容易想到可以做差,如果得到的差值越小,那么就说明模型效果越好,也就是对应的W越好。

我们把差值写成公式就是:

e = sum_{i=1}^m |y_i - hat{y_i}|

这里的m是样本的数量,模型的整体误差e就是每一个样本的差值之和。但是这种做法有一个小bug,就是绝对值的计算非常麻烦,主要是不方便求导。至于为什么要求导,我们后面再说。

绝对值很麻烦怎么办呢?很简单,我们可以平方求和:

e=sum_{i=1}^m (y_i - hat{y_i})^2

因为我们并不需要误差e,只不过拿它当做衡量模型效果的指标,所以求平方并不影响。所以我们要做的就是求出误差e最小时对应的参数矩阵W。

那怎么根据这个误差的式子求对应的W呢?数学好一点的同学估计已经猜到了,没错,就是求导。这虽然不是一个二次方程,但也可以当做二次方程一样,进行求导求极值。

求导之前,我们先把公式写全:

(Y-(XW^T b))^2

,接着我们先对这个公式做一个简单的变形:我们想办法把b处理掉,让式子尽可能简洁。

首先,我们在X当中增加一列1,也就是将X变成m * (n 1)的矩阵,它的第一列是常数1,新的矩阵写成

X_{new}

同样,我们在

W^T

中也增加一列,它的第一列写成b,我们将新的矩阵写成

Theta

,我们可以得到:

XW b = X_{new}Theta=left[ begin{matrix} 1 & x_{11} & x_{12} & cdots & x_{1n}\ 1 & x_{21} & x_{22} & cdots & x_{2n}\ vdots & vdots & vdots & vdots & vdots\ 1 & x_{m1} & x_{m2} & cdots & x_{mn} end{matrix} right] cdot left[ begin{matrix} b & w_1 & w_2 cdots w_n end{matrix} right]^T

之后,我们把

X_{new}

theta

代入原式得到:

J(Theta) = frac{1}{2m}sum_{i=1}^m(x_icdot theta - y_i)^2 = frac1 {2m} (X_{new}Theta - Y)^2

这里的m是样本的数量,是一个常数,我们除以这个系数并不会影响

Theta

的取值。这个

J(Theta)

就是我们常说的模型的损失函数。

这里的损失其实是误差的意思,损失函数的值越小,说明模型的误差越小,和真实结果越接近。为了方便书写,我们把

X_{new}

还是写成

X

我们计算

J(Theta)

Theta

的导数:

frac{partial J(Theta)}{partial Theta}=frac{1}{m}(XTheta - Y)^T X

我们令导数等于0,由于m是常数,可以消掉,得到:

begin{aligned} (XTheta - Y)^TX &= 0 \ X^TXTheta - X^TY &= 0 \ X^TXTheta &= X^TY \ Theta &= (X^T cdot X)^{-1}X^TY end{aligned}

上面的推导过程初看可能觉得复杂,但实际上并不困难。只是一个简单的求偏导的推导,我们就可以写出最优的

Theta

的取值。

从这个公式来看并不难计算,实际上是否真的是这么简单呢?我们试着用代码来实验一下。

代码实验

为了简单起见,我们针对最简单的场景:样本只有一个特征,我们首先先试着自己生产一批数据:

代码语言:javascript复制
import numpy as np
X = 2 * np.random.rand(100, 1)
y = 4   3 * X   np.random.randn(100, 1)

import matplotlib.pyplot as plt
plt.scatter(X, y)

我们先生成一百个0~2范围内的x,然后

y= 3x 4

,为了模拟真实存在误差的场景,我们再对y加上一个0~1范围内的误差。

我们把上面的点通过plt画出来可以得到这样一张图:

接着,我们就可以通过上面的公式直接计算出theta 的取值了:

代码语言:javascript复制
def get_theta(x, y):
    m = len(y)
    # x中新增一列常数1
    x = np.c_[np.ones(m).T, x]
    # 通过公式计算theta
    theta = np.dot(np.dot(np.linalg.inv(np.dot(x.T, x)), x.T), y)
    return theta

我们传入x和y得到theta,打印出来看,会发现和我们设置的非常接近:

最后,我们把模型拟合的结果和真实样本的分布都画在一张图上:

代码语言:javascript复制
# 我们画出模型x在0到2区间内的值
X_new = np.array([[0],[2]])
# 新增一列常数1的结果
X_new_b = np.c_[np.ones((2, 1)), X_new]
# 预测的端点值
y_predict = X_new_b.dot(theta)

# 画出模型拟合的结果
plt.plot(X_new, y_predict,"r-")
# 画出原来的样本
plt.scatter(X,y)
plt.show()

得到的结果如下:

从结果上来看模型的效果非常棒,和我们的预期非常吻合,并且实现的代码实在是太简单了,只有短短几行。

但先别高兴得太早,有一点必须说清楚,虽然上面的例子非常完美,但是实际场景当中,大家并不会这么干。而是会采用其他的办法来求解W。

这个就很奇怪了,明明三行代码可以求出结果,为什么非要用其他办法绕个弯子算呢?

原因其实很简单,如果你线性代数还没还给老师的话,应该很容易就能发现。

首先是我们计算Theta 的公式当中用到了逆矩阵的操作。线性代数当中说过,只有满秩矩阵才有逆矩阵。如果

X^T cdot X

是奇异矩阵,那么它是没有逆矩阵的,自然这个公式也用不了了。

当然这个问题并不是不能解决的,X^T cdot X 是奇异矩阵的条件是矩阵X 当中存在一行或者一列全为0。我们通过特征预处理,是可以避免这样的事情发生的。所以样本无法计算逆矩阵只是原因之一,并不是最关键的问题。

最关键的问题是复杂度,虽然我们看起来上面核心的代码只有一行,但实际上由于我们用到了逆矩阵的计算,它背后的开销非常大。

X^T cdot X 的结果是一个n * n的矩阵,这里的n是特征的维度。

这样一个矩阵计算逆矩阵的复杂度大概在n^{2.4}n^3 之间。当n很小的时候当然没有关系,如果n很大,则求解起来非常耗时。在现实场景当中,我们的n往往动辄好几千,甚至好几万、好几十万。显然在这么大的量级下,想要求解逆矩阵非常非常困难,甚至是几乎不可能的。

正是因为以上这些原因,所以通常我们并不会使用直接通过公式计算的方法来求模型的参数。

那么,既然直接计算开销太大,我们又该如何求解呢?

别着急,我们在下篇文章为大家揭晓。

0 人点赞