【动手学深度学习笔记】之线性回归

2020-08-20 07:51:48 浏览数 (1)


第一种模型:线性回归

线性回归适用于回归问题,回归问题主要包括:房屋价格预测,温度等等。

线性回归是单层神经网络,设计的概念和技术适用于大多数深度学习模型;因此,我们以线性回归为例,学习深度学习模型的基本要素和表示方法。

线性回归的基本要素

以预测房屋价格预测为例:目标是预测一栋房屋的售出价格;简单起见,假设之取决于两个因素,即面积(平方米)和房龄(年)

模型定义

设房屋面积X1,房龄X2,售出价格y。建立输入X1和X2来计算输出y的表达式。

线性关系为:

其中 W1和W2是权重,b是偏差,这三个变量都是标量,他们是线性回归模型的参数。模型的输出y是对真实值的预测或估计,与真实值存在一定误差。

模型训练

需要通过数据来寻找特定的模型参数值(W1,W2,b),使模型在数据上的误差尽可能小。这个搜索的过程叫做模型训练。

以下是模型训练所设计的3个要素。

(1)训练数据

通常收集一系列的真实数据(数据集),例如:多栋房屋的真实收储价格和它们对应的面积和房龄。目的是希望在这个数据上面寻找模型参数来使模型的预测价格与真实价格的误差最小

术语中,这套数据集叫做训练数据集训练集,一栋房屋被称作一个样本,其真实售出价格(y)叫做标签(label),用来预测的两个因素叫做特征(feature)。特征用来表征样本的特点(例如:房龄和面积)

假设我们采集的样本数为n,对于索引为i的房屋。线性回归模型的房屋价格预测表达式为

(2)损失函数

深度学习过程中需要衡量价格预测值与真实值之间的误差。通常会选取一个非负数作为误差,且数值越小表示误差越小。一个常用的选择是平方函数。它在索引为i的样本误差的表达式为

这个式子其实很简单,只是看着唬人而已,右边括号中的:通过目前参数(w_1,w_2,b)计算得到的估计值hat{y}减去真实价格,为了让这个值为正数,对他进行平方是为了使求导后的系数为1,所以乘其中常数1/2使这个表达式对平方项求导后常数系数为1。

这个衡量误差的函数称为损失函数(loss function)。这个使用平方误差函数也称为平方损失。

通常我们用训练数据集中所有样本误差的平均来衡量模型预测的质量,即

这个式子和上面的式子一样,上面的是求一个值的误差,这个式子是一组值的平均误差. 在训练模型中,希望找出一组模型参数满足下式(样本平均损失最小),这组参数就是训练模型的目标解

(3)优化算法

解析解:直接通过模型和损失函数能够得到的解。

数值解:通过优化算法有限次迭代来尽可能降低损失函数的值的解。

优化算法:这本书介绍的是小批量随机梯度下降

  1. 随机选取一组初始值
  2. 多次迭代,使每一次迭代的值都更小一点

迭代过程

假定这一个小批量有B个数,学习率为eta。(这两个是人为设定的超参数)

每次迭代对三个参数进行分别计算。

  1. 首先是对损失函数进行求导
  2. 然后求小批量对应的损失函数的导数的平均值,最后乘以学习率,得到减小量
  3. 根据减小量得到迭代后的w1、w2、b

模型预测

确定训练完成后,将模型参数再优化算法停止时的值分别记作

这几个参数并不一定是最优解,而是一个最优解的近似,这样就可以使用线性回归模型

来估算数据训练集以外的任意一栋面积为x1和房龄为x2的房子的价格了。

线性回归的表示方法

这节解释线性回归和神经网络的联系以及线性回归的矢量表达式。

神经网络图

深度学习中,我们使用神经网络图直观地表现模型结构。

该图使用神经网络图表示本届中介绍的线性回归网络,倒数隐去了权重和偏差。

输入个数也叫做特征数或特征向量维度。由该图可以看出,线性回归是一个单层神经网络。

输出层中负责计算o的单元叫神经元,这个模型中输出层的神经元和输入层中各个输入完全链接,因此这个输入层又叫全连接层或稠密层。

矢量计算表达式

在训练或使用模型的过程中,需要对多个样本进行操作。操作的方法是矢量计算(向量运算)。

使用矢量计算:

代码语言:javascript复制
#首先定义两个一千纬的向量import torch from time import time  #导入这个库的目的是比较两种方法哪个更快
a = torch.ones(1000)b = torch.ones(1000)
#相加的第一种方式---标量计算start = time()      #开始计时c = torch.zeros(1000)for i in range(1000):    c[i] = a[i]   b[i]print(time()-start)    #输出运行时间1
#相加的第二种方式---矢量计算start = time()d = a   bprint(time()-start)    #输出运行时间2
#输出结果:#方法一耗时0.02039s#方法二耗时0.0008330s

0 人点赞