运用伪逆矩阵求最小二乘解

2023-04-09 10:27:13 浏览数 (1)

之前分析过最小二乘的理论,记录了 Scipy 库求解的方法,但无法求解多元自变量模型,本文记录更加通用的伪逆矩阵求解最小二乘解的方法。

背景

我已经反复研习很多关于最小二乘的内容,虽然朴素但是着实花了一番功夫:

  • 介绍过最小二乘在线性回归中的公式推导;
  • 分析了最小二乘的来源和其与高斯分布的紧密关系;
  • 学习了伪逆矩阵在最小二乘求解过程中的理论应用;
  • 记录了 Scipy 用于求解最小二乘解的函数;

已经有工具可以解很多最小二乘的模型参数了,但是几个专用的最小二乘方法最多支持一元函数的求解,难以计算多元函数最小二乘解,此时就可以用伪逆矩阵求解了。

多元多项式形式模型

这个概念可能不够准确,我要描述的是形如如下函数的一类模型:

f( {bf x} )=sum _{i=1}^{n}a_if_i(x_i)

其中模型

最小二乘的损失函数为:

L= sum_{i=1}left(fleft(x_{i}right)-y_{i}right){2}

对于上述模型,可以利用伪逆求最小二乘解的方法可以用于求解类似线性多项式形式的模型参数,这样就可以求解多元、更加复杂的模型参数。

  • 本质上来说,就是因为这种形式的模型可以凑出形如 A x=b 的矩阵表示,因此可以用这种方法求解。

伪逆求解

在介绍伪逆的文章中其实已经把理论说完了,这里搬运结论:

  • 方程组 A x=b 的最佳最小二乘解为 x=A^{ } b ,并且最佳最小二乘解是唯一的。

实例应用

Python 求逆矩阵
矩阵求逆
代码语言:javascript复制
import numpy as np

a  = np.array([[1, 2], [3, 4]])  # 初始化一个非奇异矩阵(数组)
print(np.linalg.inv(a))  # 对应于MATLAB中 inv() 函数

# 矩阵对象可以通过 .I 更方便的求逆
A = np.matrix(a)
print(A.I)


-->
[[-2.   1. ]
 [ 1.5 -0.5]]
[[-2.   1. ]
 [ 1.5 -0.5]]

矩阵求伪逆
代码语言:javascript复制
import numpy as np

# 定义一个奇异阵 A
A = np.zeros((4, 4))
A[0, -1] = 1
A[-1, 0] = -1
A = np.matrix(A)
print(A)
# print(A.I)  将报错,矩阵 A 为奇异矩阵,不可逆
print(np.linalg.pinv(A))   # 求矩阵 A 的伪逆(广义逆矩阵),对应于MATLAB中 pinv() 函数


-->
[[ 0.  0.  0.  1.]
 [ 0.  0.  0.  0.]
 [ 0.  0.  0.  0.]
 [-1.  0.  0.  0.]]
[[ 0.  0.  0. -1.]
 [ 0.  0.  0.  0.]
 [ 0.  0.  0.  0.]
 [ 1.  0.  0.  0.]]

应用示例

假设我们需要拟合一个多元的复杂的但是参数为多项式形式的模型参数,模型为:

f( {bf x} )=p_1frac{e^{x_1}}{sqrt{x_1}} p_2 x_2^{1.5} p_3 sin x_3

模型真实参数为

代码语言:javascript复制
import numpy as np

# 定义函数
def f1(x):
    return (np.e ** x) / (x ** 0.5)

def f2(x):
    return (x ** 1.5)

def f3(x):
    return np.sin(x)

# 真实参数
gt_p = [7, 3, 12]

# 真实模型
def f(x1, x2, x3):
    return gt_p[0] * f1(x1)   gt_p[1] * f2(x2)   gt_p[2] * f3(x3)

# 三组自变量数据
X1 = np.arange(1, 3, 0.1)
X2 = X1 * 3
X3 = X1 ** 2

# 生成带噪声的观测值 b
b = np.matrix(f(X1, X2, X3)   (np.random.rand(len(X1)) - 0.5)).T

# 生成矩阵 A
A0 = f1(X1)
A1 = f2(X2)
A2 = f3(X3)

A = np.matrix(np.vstack([A0, A1, A2]).T)

# 逆矩阵求解
para = np.linalg.pinv(A) * b

# 输出结果
print(f"ground truth: {gt_p}")
print(f"got: {para.tolist()}")

输出结果:

代码语言:javascript复制
ground truth: [7, 3, 12]
got: [[7.046011821943054], [2.9831510054344843], [11.989895579628328]]

参考资料

  • https://cloud.tencent.com/developer/article/2220516
  • https://cloud.tencent.com/developer/article/2066930
  • https://cloud.tencent.com/developer/article/2260272
  • https://www.zywvvd.com/notes/study/linear-algebra/inverse-matrix/gen-inverse-matrix/

文章链接: https://cloud.tencent.com/developer/article/2260562

0 人点赞