聊一聊损失函数

2022-10-31 17:35:25 浏览数 (1)

聊一聊损失函数

前言

损失函数,具体来说就是计算神经网络每次迭代的前向计算结果与真实值的差距,从而指导下一步的训练向正确的方向进行。下面主要介绍一些常见的损失函数:均方差损失函数交叉熵损失函数

均方差损失函数

均方误差损失(Mean Square Error,MSE)又称为二次损失、L2 损失,常用于回归预测任务中。均方误差函数通过计算预测值和实际值之间距离(即误差)的平方来衡量模型优劣。即预测值和真实值越接近,两者的均方差就越小。

均方差函数常用于线性回归(linear regrWession),即函数拟合(function fitting)。公式如下:

均方差函数比较简单,也较为常见,这里就不多说了。

交叉熵损失函数

交叉熵(Cross Entropy)是 Shannon 信息论中一个重要概念,主要用于度量两个概率分布间的差异性信息。在信息论中,交叉熵是表示两个概率分布 p,qp,qp,q 的差异,其中 ppp 表示真实分布,qqq 表示预测分布,那么 H(p,q)H(p,q)H(p,q) 就称为交叉熵:

交叉熵可在神经网络中作为损失函数,ppp 表示真实标记的分布,qqq 则为训练后的模型的预测标记分布,交叉熵损失函数可以衡量 ppp 与 qqq 的相似性。

交叉熵函数常用于逻辑回归(logistic regression),也就是分类(classification)。

信息量

信息量来衡量一个事件的不确定性,一个事件发生的概率越大,不确定性越小,则其携带的信息量就越小。在信息论中,可以通过如下方式表示:

其中xjx_jxj​表示一个事件,p(xj)p(x_j)p(xj​)表示xjx_jxj​发生的概率。

举个例子,对于下面这三个事件,可以通过概率计算其信息量:

事件编号

事件

概率

信息量

x1x_1x1​

优秀

p=0.7p=0.7p=0.7

I=−ln⁡(0.7)=0.36I=-ln(0.7)=0.36I=−ln(0.7)=0.36

x2x_2x2​

及格

p=0.2p=0.2p=0.2

I=−ln⁡(0.2)=1.61I=-ln(0.2)=1.61I=−ln(0.2)=1.61

x3x_3x3​

不及格

p=0.1p=0.1p=0.1

I=−ln⁡(0.1)=2.30I=-ln(0.1)=2.30I=−ln(0.1)=2.30

事件发生的概率越小,其信息量越大。

熵用来衡量一个系统的混乱程度,代表系统中信息量的总和;熵值越大,表明这个系统的不确定性就越大。具体来说:

其中p(xj)p(x_j)p(xj​)表示xjx_jxj​发生的概率,−ln⁡(p(xj))-ln(p(x_j))−ln(p(xj​))表示事件的信息量。

信息量是衡量某个事件的不确定性,而熵是衡量一个系统(所有事件)的不确定性。

对于上面的例子,我们可以计算其熵:

相对熵(KL 散度)

相对熵也称为 KL 散度(Kullback-Leibler divergence),表示同一个随机变量的两个不同分布间的距离,相当于信息论范畴的均方差。

设p(x),q(x)p(x),q(x)p(x),q(x)分别是随机变量xxx的两个概率分布,则ppp对qqq的相对熵计算如下:

其中nnn为事件的所有可能性。相对熵DDD的值越小,表示两个分布越接近。在实际应用中,假如p(x)p(x)p(x)是目标真实的分布,而q(x)q(x)q(x)是预测得来的分布,为了让这两个分布尽可能的相同的,就需要最小化 KL 散度。

交叉熵

将上述公式变形:

其中,等式的前一部分就是ppp的熵,后一部分就是交叉熵:

在机器学习中,我们需要评估标签值yyy和预测值aaa之间的差距,就可以计算DKL(p∥q)D_{KL}(p Vert q)DKL​(p∥q),由于H(y)H(y)H(y)不变,因此在优化过程中只需要考虑交叉熵即可。对于单样本计算如下:

对于批量样本的交叉熵计算如下:

其中mmm为样本数,nnn为分类数。

二分类问题交叉熵

在二分的情况下,通常使用sigmoid 将输出映射为正样本的概率,对于每个类别我们的预测的到的概率为aaa和1−a1-a1−a,所以交叉熵可以简化为:

二分类对于批量样本的交叉熵计算公式:

简单分析一下公式,可以发现,当y=1y=1y=1时为正样本,loss=−ln⁡(a)loss=-ln(a)loss=−ln(a);当y=0y=0y=0时为负样本,

事件编号

预测值aaa

真实值yyy

x1x_1x1​

0.6

1

x2x_2x2​

0.7

1

举个例子,对于上面的情况,我们分别计算其交叉熵损失:

计算得到

​,相应的loss2loss_2loss2​反向传播的力度也会小。

多分类问题交叉熵

多分类问题也是类似的,考虑下面的优秀、及格、不及格分类:

事件编号

p(x1)=优秀p(x_1)=优秀p(x1​)=优秀

p(x1)=及格p(x_1)=及格p(x1​)=及格

p(x1)=不及格p(x_1)=不及格p(x1​)=不及格

真实值yyy

x1x_1x1​

0.2

0.5

0.3

不及格

x2x_2x2​

0.2

0.2

0.6

不及格

举个例子,对于上面的情况,我们分别计算其交叉熵损失:

计算得到

,相应的loss2loss_2loss2​反向传播的力度也会小。

PyTorch 实现

在 PyTorch 中,常用的损失函数我们可以直接调用:

  • nn.MSELoss()
  • nn.CrossEntropyLoss()

但有时我们会需要自定义损失函数,这时我们可以将其当作神经网络的一层来对待,同样地,我们的损失函数类就需要继承自nn.Module类。

代码语言:javascript复制
import torch
import torch.nn as nn

class myLoss(nn.Module):
    def __init__(self,parameters)
        self.params = self.parameters

    def forward(self)
        loss = cal_loss(self.params)
        return loss

参考资料

  • 一文搞懂交叉熵损失
  • Ai-edu

0 人点赞