【Pytorch基础】线性模型

2023-02-27 16:58:15 浏览数 (1)

线性模型

一般流程

  1. 准备数据集(训练集,开发集,测试集)
  2. 选择模型(泛化能力,防止过拟合)
  3. 训练模型
  4. 测试模型

例子

学生每周学习时间与期末得分的关系

x(hours)

y(points)

1

2

2

4

3

6

4

?

设计模型

观察数据分布可得应采用线性模型:

hat y = x * w b

其中 hat y 为预测值,不妨简化一下模型为:

hat y = x* w

我们的目的就是得到一个尽可能好的 w 值。使模型的预测值越 接近 真实值,因此我们需要一个衡量接近程度的指标 loss,可用绝对值或差的平方表示单 g 个样本预测的损失为(Training Loss):

loos = (hat y - y)^2 = (x*w - y)^2 geq 0

这里使用差的平方,其中 y 为真实值。

因此,对于多样本预测的平均损失函数为(Mean Square Error):

MSE = frac{sum_{i=0}^{n}(hat y_i - y_i)^2}{n}
代码语言:javascript复制
# 定义模型函数
def forward(x):
    return x * w;

# 定义损失函数
def loss(x, y):
    y_predict = forward(x)
    return (y - y_predict) ** 2

过程模拟

由于不知道 w 的具体值因此我们给它一个随机初始值,假设 w = 3

x(hours)

y(points)

y_predict

loss

1

2

3

1

2

4

6

4

3

6

9

9

MSE=14/3

可知本轮预测平均损失为 14/3

为找到最佳权重,可枚举权重值判断损失,损失最小为最佳

代码语言:javascript复制
# 存放枚举到的权重 w 的取值
w_list = []
# 对应权重的平均误差
mse_list = []

# 枚举权重,步长为 0.1
for w in np.arange(0.0, 4.1, 0.1): # 从 0.0 到 4.1
    print("w=", w)
    loss_sum = 0 # 损失和
    for x_val, y_val in zip(x_data, y_data): # zip 函数传入可迭代对象
        y_predict_val = forward(x_val) # 计算预测值
        loss_val = loss(x_val, y_val) # 计算单样本损失
        loss_sum  = loss_val # 更新损失和
        print('tt',x_val, y_val, format(y_predict_val, '0.2f'),format(loss_val,'0.2f'))
    print('MSE=',loss_sum / len(x_data))
    w_list.append(w)
    mse_list.append(loss_sum / len(x_data))

具体实现

代码语言:javascript复制
import numpy as np
import matplotlib.pyplot as plt

# 准备数据集
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

# 定义模型函数
def forward(x):
    return x * w;

# 定义损失函数
def loss(x, y):
    y_predict = forward(x)
    return (y - y_predict) ** 2

# 权重 w 的取值
w_list = []
# 对应权重的平均误差
mse_list = []

# 枚举权重,步长为 0.1
for w in np.arange(0.0, 4.1, 0.1): # 从 0.0 到 4.1
    print("w=", w)
    loss_sum = 0 # 损失和
    for x_val, y_val in zip(x_data, y_data): # zip 函数传入可迭代对象
        y_predict_val = forward(x_val) # 计算预测值
        loss_val = loss(x_val, y_val) # 计算单样本损失
        loss_sum  = loss_val # 更新损失和
        print('tt',x_val, y_val, format(y_predict_val, '0.2f'),format(loss_val,'0.2f'))
    print('MSE=',loss_sum / len(x_data))
    w_list.append(w)
    mse_list.append(loss_sum / len(x_data))

得到每轮的预测结果

代码语言:javascript复制
w= 0.0
         1.0 2.0 0.00 4.00
         2.0 4.0 0.00 16.00
         3.0 6.0 0.00 36.00
MSE= 18.666666666666668
w= 0.1
         1.0 2.0 0.10 3.61
         2.0 4.0 0.20 14.44
         3.0 6.0 0.30 32.49
MSE= 16.846666666666668
w= 0.2
         1.0 2.0 0.20 3.24
         2.0 4.0 0.40 12.96
         3.0 6.0 0.60 29.16
MSE= 15.120000000000003
w= 0.30000000000000004
         1.0 2.0 0.30 2.89
         2.0 4.0 0.60 11.56
         3.0 6.0 0.90 26.01
MSE= 13.486666666666665
w= 0.4
         1.0 2.0 0.40 2.56
         2.0 4.0 0.80 10.24
         3.0 6.0 1.20 23.04
MSE= 11.946666666666667
w= 0.5
         1.0 2.0 0.50 2.25
         2.0 4.0 1.00 9.00
         3.0 6.0 1.50 20.25
MSE= 10.5
w= 0.6000000000000001
         1.0 2.0 0.60 1.96
         2.0 4.0 1.20 7.84
         3.0 6.0 1.80 17.64
MSE= 9.146666666666663
w= 0.7000000000000001
         1.0 2.0 0.70 1.69
         2.0 4.0 1.40 6.76
         3.0 6.0 2.10 15.21
MSE= 7.886666666666666
w= 0.8
         1.0 2.0 0.80 1.44
         2.0 4.0 1.60 5.76
         3.0 6.0 2.40 12.96
MSE= 6.719999999999999
w= 0.9
         1.0 2.0 0.90 1.21
         2.0 4.0 1.80 4.84
         3.0 6.0 2.70 10.89
MSE= 5.646666666666666
w= 1.0
         1.0 2.0 1.00 1.00
         2.0 4.0 2.00 4.00
         3.0 6.0 3.00 9.00
MSE= 4.666666666666667
w= 1.1
         1.0 2.0 1.10 0.81
         2.0 4.0 2.20 3.24
         3.0 6.0 3.30 7.29
MSE= 3.779999999999999
w= 1.2000000000000002
         1.0 2.0 1.20 0.64
         2.0 4.0 2.40 2.56
         3.0 6.0 3.60 5.76
MSE= 2.986666666666665
w= 1.3
         1.0 2.0 1.30 0.49
         2.0 4.0 2.60 1.96
         3.0 6.0 3.90 4.41
MSE= 2.2866666666666657
w= 1.4000000000000001
         1.0 2.0 1.40 0.36
         2.0 4.0 2.80 1.44
         3.0 6.0 4.20 3.24
MSE= 1.6799999999999995
w= 1.5
         1.0 2.0 1.50 0.25
         2.0 4.0 3.00 1.00
         3.0 6.0 4.50 2.25
MSE= 1.1666666666666667
w= 1.6
         1.0 2.0 1.60 0.16
         2.0 4.0 3.20 0.64
         3.0 6.0 4.80 1.44
MSE= 0.746666666666666
w= 1.7000000000000002
         1.0 2.0 1.70 0.09
         2.0 4.0 3.40 0.36
         3.0 6.0 5.10 0.81
MSE= 0.4199999999999995
w= 1.8
         1.0 2.0 1.80 0.04
         2.0 4.0 3.60 0.16
         3.0 6.0 5.40 0.36
MSE= 0.1866666666666665
w= 1.9000000000000001
         1.0 2.0 1.90 0.01
         2.0 4.0 3.80 0.04
         3.0 6.0 5.70 0.09
MSE= 0.046666666666666586
w= 2.0
         1.0 2.0 2.00 0.00
         2.0 4.0 4.00 0.00
         3.0 6.0 6.00 0.00
MSE= 0.0
w= 2.1
         1.0 2.0 2.10 0.01
         2.0 4.0 4.20 0.04
         3.0 6.0 6.30 0.09
MSE= 0.046666666666666835
w= 2.2
         1.0 2.0 2.20 0.04
         2.0 4.0 4.40 0.16
         3.0 6.0 6.60 0.36
MSE= 0.18666666666666698
w= 2.3000000000000003
         1.0 2.0 2.30 0.09
         2.0 4.0 4.60 0.36
         3.0 6.0 6.90 0.81
MSE= 0.42000000000000054
w= 2.4000000000000004
         1.0 2.0 2.40 0.16
         2.0 4.0 4.80 0.64
         3.0 6.0 7.20 1.44
MSE= 0.7466666666666679
w= 2.5
         1.0 2.0 2.50 0.25
         2.0 4.0 5.00 1.00
         3.0 6.0 7.50 2.25
MSE= 1.1666666666666667
w= 2.6
         1.0 2.0 2.60 0.36
         2.0 4.0 5.20 1.44
         3.0 6.0 7.80 3.24
MSE= 1.6800000000000008
w= 2.7
         1.0 2.0 2.70 0.49
         2.0 4.0 5.40 1.96
         3.0 6.0 8.10 4.41
MSE= 2.2866666666666693
w= 2.8000000000000003
         1.0 2.0 2.80 0.64
         2.0 4.0 5.60 2.56
         3.0 6.0 8.40 5.76
MSE= 2.986666666666668
w= 2.9000000000000004
         1.0 2.0 2.90 0.81
         2.0 4.0 5.80 3.24
         3.0 6.0 8.70 7.29
MSE= 3.780000000000003
w= 3.0
         1.0 2.0 3.00 1.00
         2.0 4.0 6.00 4.00
         3.0 6.0 9.00 9.00
MSE= 4.666666666666667
w= 3.1
         1.0 2.0 3.10 1.21
         2.0 4.0 6.20 4.84
         3.0 6.0 9.30 10.89
MSE= 5.646666666666668
w= 3.2
         1.0 2.0 3.20 1.44
         2.0 4.0 6.40 5.76
         3.0 6.0 9.60 12.96
MSE= 6.720000000000003
w= 3.3000000000000003
         1.0 2.0 3.30 1.69
         2.0 4.0 6.60 6.76
         3.0 6.0 9.90 15.21
MSE= 7.886666666666668
w= 3.4000000000000004
         1.0 2.0 3.40 1.96
         2.0 4.0 6.80 7.84
         3.0 6.0 10.20 17.64
MSE= 9.14666666666667
w= 3.5
         1.0 2.0 3.50 2.25
         2.0 4.0 7.00 9.00
         3.0 6.0 10.50 20.25
MSE= 10.5
w= 3.6
         1.0 2.0 3.60 2.56
         2.0 4.0 7.20 10.24
         3.0 6.0 10.80 23.04
MSE= 11.94666666666667
w= 3.7
         1.0 2.0 3.70 2.89
         2.0 4.0 7.40 11.56
         3.0 6.0 11.10 26.01
MSE= 13.486666666666673
w= 3.8000000000000003
         1.0 2.0 3.80 3.24
         2.0 4.0 7.60 12.96
         3.0 6.0 11.40 29.16
MSE= 15.120000000000005
w= 3.9000000000000004
         1.0 2.0 3.90 3.61
         2.0 4.0 7.80 14.44
         3.0 6.0 11.70 32.49
MSE= 16.84666666666667
w= 4.0
         1.0 2.0 4.00 4.00
         2.0 4.0 8.00 16.00
         3.0 6.0 12.00 36.00
MSE= 18.666666666666668

画出权重与平均损失的关系图

代码语言:javascript复制
# 绘图(权重与平均损失的关系)
plt.plot(w_list, mse_list)
plt.ylabel('Loss')
plt.xlabel('W')
plt.show()

由上图可知,但 w = 2.0 时损失最小,该点也是损失函数图像的最小值。

0 人点赞