今天聊最小二乘法的实现。
都知道线性回归模型要求解权重向量w,最传统的做法就是使用最小二乘法。根据在scikit-learn的文档,模型sklearn.linear_model.LinearRegression,使用的就是最小二乘法(least squares ):
可是,最小二乘法在哪实现呢?
光看Api肯定是看不出来的,要深入到源码中去。不过,要找最小二乘法,首先我们得要知道她长什么样。
这个问题有点复杂。准确来说,最小二乘法是一种解法,用来求当均方误差最小时,权重向量w的闭式解。不过好在,我们知道闭式解长这样:
如果用Python来实现,对应的代码应该长这样:
代码语言:javascript复制np.linalg.inv(X.T.dot(X)).dot(X.T).dot(y)
好了,可以开始按图索骥了。
Api具体文件路径/sklearn/linear_model/_base.py,这是个近600行的大文件,我们要找的LinearRegression类,在不同版本位置略有不同,目前最新的0.22.1版在375行,起头长这样:
LinearRegression类内容也不少,不过大多数都是各种分支判断,一行行看找得太慢。好在我们知道,最小二乘法是线性回归的优化方法,只是在模型的训练阶段时候登场。
对应到Api当中,就是最小二乘法的fit方法了,在467行:
不过,代码还是很长......
没关系,还有办法。根据Api文档,模型的权重向量w,是保存在属性coef_(英文coefficients的缩写,意为“系数”)中:
既然在类中,就找self.coef_的赋值好了。很快定位到532行:
这里出现了X和y,主角都登场了,可是舞台却是numpy的线性代数工具库linalg,为什么没找到想要找的那段代码呢?
因为,这里的lstsq,就是numpy提供的最小二乘法计算工具:
看来scikit-learn选择的是直接调用现成工具,不打算重复造轮子了。如果还不放心,可以用这段代码反复比较一下,w1和w2的值是完全相等的:
代码语言:javascript复制import numpy as np
X =np.random.rand(4,3)
y =np.random.rand(4)
w1=np.linalg.lstsq(X, y)[0]
w2=np.linalg.inv(X.T.dot(X)).dot(X.T).dot(y)
下回再聊。