线性模型
一般流程
- 准备数据集(训练集,开发集,测试集)
- 选择模型(泛化能力,防止过拟合)
- 训练模型
- 测试模型
例子
学生每周学习时间与期末得分的关系
x(hours) | y(points) |
---|---|
1 | 2 |
2 | 4 |
3 | 6 |
4 | ? |
设计模型
观察数据分布可得应采用线性模型:
其中 hat y 为预测值,不妨简化一下模型为:
我们的目的就是得到一个尽可能好的 w 值。使模型的预测值越 接近 真实值,因此我们需要一个衡量接近程度的指标 loss,可用绝对值或差的平方表示单 g 个样本预测的损失为(Training Loss):
这里使用差的平方,其中 y 为真实值。
因此,对于多样本预测的平均损失函数为(Mean Square Error):
代码语言:javascript复制# 定义模型函数
def forward(x):
return x * w;
# 定义损失函数
def loss(x, y):
y_predict = forward(x)
return (y - y_predict) ** 2
过程模拟
由于不知道 w 的具体值因此我们给它一个随机初始值,假设 w = 3
x(hours) | y(points) | y_predict | loss |
---|---|---|---|
1 | 2 | 3 | 1 |
2 | 4 | 6 | 4 |
3 | 6 | 9 | 9 |
MSE=14/3 |
可知本轮预测平均损失为 14/3
为找到最佳权重,可枚举权重值判断损失,损失最小为最佳
代码语言:javascript复制# 存放枚举到的权重 w 的取值
w_list = []
# 对应权重的平均误差
mse_list = []
# 枚举权重,步长为 0.1
for w in np.arange(0.0, 4.1, 0.1): # 从 0.0 到 4.1
print("w=", w)
loss_sum = 0 # 损失和
for x_val, y_val in zip(x_data, y_data): # zip 函数传入可迭代对象
y_predict_val = forward(x_val) # 计算预测值
loss_val = loss(x_val, y_val) # 计算单样本损失
loss_sum = loss_val # 更新损失和
print('tt',x_val, y_val, format(y_predict_val, '0.2f'),format(loss_val,'0.2f'))
print('MSE=',loss_sum / len(x_data))
w_list.append(w)
mse_list.append(loss_sum / len(x_data))
具体实现
代码语言:javascript复制import numpy as np
import matplotlib.pyplot as plt
# 准备数据集
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
# 定义模型函数
def forward(x):
return x * w;
# 定义损失函数
def loss(x, y):
y_predict = forward(x)
return (y - y_predict) ** 2
# 权重 w 的取值
w_list = []
# 对应权重的平均误差
mse_list = []
# 枚举权重,步长为 0.1
for w in np.arange(0.0, 4.1, 0.1): # 从 0.0 到 4.1
print("w=", w)
loss_sum = 0 # 损失和
for x_val, y_val in zip(x_data, y_data): # zip 函数传入可迭代对象
y_predict_val = forward(x_val) # 计算预测值
loss_val = loss(x_val, y_val) # 计算单样本损失
loss_sum = loss_val # 更新损失和
print('tt',x_val, y_val, format(y_predict_val, '0.2f'),format(loss_val,'0.2f'))
print('MSE=',loss_sum / len(x_data))
w_list.append(w)
mse_list.append(loss_sum / len(x_data))
得到每轮的预测结果
代码语言:javascript复制w= 0.0
1.0 2.0 0.00 4.00
2.0 4.0 0.00 16.00
3.0 6.0 0.00 36.00
MSE= 18.666666666666668
w= 0.1
1.0 2.0 0.10 3.61
2.0 4.0 0.20 14.44
3.0 6.0 0.30 32.49
MSE= 16.846666666666668
w= 0.2
1.0 2.0 0.20 3.24
2.0 4.0 0.40 12.96
3.0 6.0 0.60 29.16
MSE= 15.120000000000003
w= 0.30000000000000004
1.0 2.0 0.30 2.89
2.0 4.0 0.60 11.56
3.0 6.0 0.90 26.01
MSE= 13.486666666666665
w= 0.4
1.0 2.0 0.40 2.56
2.0 4.0 0.80 10.24
3.0 6.0 1.20 23.04
MSE= 11.946666666666667
w= 0.5
1.0 2.0 0.50 2.25
2.0 4.0 1.00 9.00
3.0 6.0 1.50 20.25
MSE= 10.5
w= 0.6000000000000001
1.0 2.0 0.60 1.96
2.0 4.0 1.20 7.84
3.0 6.0 1.80 17.64
MSE= 9.146666666666663
w= 0.7000000000000001
1.0 2.0 0.70 1.69
2.0 4.0 1.40 6.76
3.0 6.0 2.10 15.21
MSE= 7.886666666666666
w= 0.8
1.0 2.0 0.80 1.44
2.0 4.0 1.60 5.76
3.0 6.0 2.40 12.96
MSE= 6.719999999999999
w= 0.9
1.0 2.0 0.90 1.21
2.0 4.0 1.80 4.84
3.0 6.0 2.70 10.89
MSE= 5.646666666666666
w= 1.0
1.0 2.0 1.00 1.00
2.0 4.0 2.00 4.00
3.0 6.0 3.00 9.00
MSE= 4.666666666666667
w= 1.1
1.0 2.0 1.10 0.81
2.0 4.0 2.20 3.24
3.0 6.0 3.30 7.29
MSE= 3.779999999999999
w= 1.2000000000000002
1.0 2.0 1.20 0.64
2.0 4.0 2.40 2.56
3.0 6.0 3.60 5.76
MSE= 2.986666666666665
w= 1.3
1.0 2.0 1.30 0.49
2.0 4.0 2.60 1.96
3.0 6.0 3.90 4.41
MSE= 2.2866666666666657
w= 1.4000000000000001
1.0 2.0 1.40 0.36
2.0 4.0 2.80 1.44
3.0 6.0 4.20 3.24
MSE= 1.6799999999999995
w= 1.5
1.0 2.0 1.50 0.25
2.0 4.0 3.00 1.00
3.0 6.0 4.50 2.25
MSE= 1.1666666666666667
w= 1.6
1.0 2.0 1.60 0.16
2.0 4.0 3.20 0.64
3.0 6.0 4.80 1.44
MSE= 0.746666666666666
w= 1.7000000000000002
1.0 2.0 1.70 0.09
2.0 4.0 3.40 0.36
3.0 6.0 5.10 0.81
MSE= 0.4199999999999995
w= 1.8
1.0 2.0 1.80 0.04
2.0 4.0 3.60 0.16
3.0 6.0 5.40 0.36
MSE= 0.1866666666666665
w= 1.9000000000000001
1.0 2.0 1.90 0.01
2.0 4.0 3.80 0.04
3.0 6.0 5.70 0.09
MSE= 0.046666666666666586
w= 2.0
1.0 2.0 2.00 0.00
2.0 4.0 4.00 0.00
3.0 6.0 6.00 0.00
MSE= 0.0
w= 2.1
1.0 2.0 2.10 0.01
2.0 4.0 4.20 0.04
3.0 6.0 6.30 0.09
MSE= 0.046666666666666835
w= 2.2
1.0 2.0 2.20 0.04
2.0 4.0 4.40 0.16
3.0 6.0 6.60 0.36
MSE= 0.18666666666666698
w= 2.3000000000000003
1.0 2.0 2.30 0.09
2.0 4.0 4.60 0.36
3.0 6.0 6.90 0.81
MSE= 0.42000000000000054
w= 2.4000000000000004
1.0 2.0 2.40 0.16
2.0 4.0 4.80 0.64
3.0 6.0 7.20 1.44
MSE= 0.7466666666666679
w= 2.5
1.0 2.0 2.50 0.25
2.0 4.0 5.00 1.00
3.0 6.0 7.50 2.25
MSE= 1.1666666666666667
w= 2.6
1.0 2.0 2.60 0.36
2.0 4.0 5.20 1.44
3.0 6.0 7.80 3.24
MSE= 1.6800000000000008
w= 2.7
1.0 2.0 2.70 0.49
2.0 4.0 5.40 1.96
3.0 6.0 8.10 4.41
MSE= 2.2866666666666693
w= 2.8000000000000003
1.0 2.0 2.80 0.64
2.0 4.0 5.60 2.56
3.0 6.0 8.40 5.76
MSE= 2.986666666666668
w= 2.9000000000000004
1.0 2.0 2.90 0.81
2.0 4.0 5.80 3.24
3.0 6.0 8.70 7.29
MSE= 3.780000000000003
w= 3.0
1.0 2.0 3.00 1.00
2.0 4.0 6.00 4.00
3.0 6.0 9.00 9.00
MSE= 4.666666666666667
w= 3.1
1.0 2.0 3.10 1.21
2.0 4.0 6.20 4.84
3.0 6.0 9.30 10.89
MSE= 5.646666666666668
w= 3.2
1.0 2.0 3.20 1.44
2.0 4.0 6.40 5.76
3.0 6.0 9.60 12.96
MSE= 6.720000000000003
w= 3.3000000000000003
1.0 2.0 3.30 1.69
2.0 4.0 6.60 6.76
3.0 6.0 9.90 15.21
MSE= 7.886666666666668
w= 3.4000000000000004
1.0 2.0 3.40 1.96
2.0 4.0 6.80 7.84
3.0 6.0 10.20 17.64
MSE= 9.14666666666667
w= 3.5
1.0 2.0 3.50 2.25
2.0 4.0 7.00 9.00
3.0 6.0 10.50 20.25
MSE= 10.5
w= 3.6
1.0 2.0 3.60 2.56
2.0 4.0 7.20 10.24
3.0 6.0 10.80 23.04
MSE= 11.94666666666667
w= 3.7
1.0 2.0 3.70 2.89
2.0 4.0 7.40 11.56
3.0 6.0 11.10 26.01
MSE= 13.486666666666673
w= 3.8000000000000003
1.0 2.0 3.80 3.24
2.0 4.0 7.60 12.96
3.0 6.0 11.40 29.16
MSE= 15.120000000000005
w= 3.9000000000000004
1.0 2.0 3.90 3.61
2.0 4.0 7.80 14.44
3.0 6.0 11.70 32.49
MSE= 16.84666666666667
w= 4.0
1.0 2.0 4.00 4.00
2.0 4.0 8.00 16.00
3.0 6.0 12.00 36.00
MSE= 18.666666666666668
画出权重与平均损失的关系图
代码语言:javascript复制# 绘图(权重与平均损失的关系)
plt.plot(w_list, mse_list)
plt.ylabel('Loss')
plt.xlabel('W')
plt.show()
由上图可知,但 w = 2.0 时损失最小,该点也是损失函数图像的最小值。