pytorch中的nn.CrossEntropyLoss()计算原理

2021-07-21 14:47:51 浏览数 (1)

生成随机矩阵

代码语言:javascript复制
x = np.random.rand(2,3) 

array([[0.10786477, 0.56611762, 0.10557245], [0.4596513 , 0.13174377, 0.82373043]])

计算softmax

在numpy中

代码语言:javascript复制
y = np.exp(x) / np.sum(np.exp(x), axis=1, keepdims=True)

array([[0.27940617, 0.44182742, 0.27876641], [0.31649398, 0.22801164, 0.45549437]])

在pytorch中

代码语言:javascript复制
torch_x = torch.from_numpy(x) torch_y = nn.Softmax(dim=-1)(torch_x)

tensor([[0.2794, 0.4418, 0.2788], [0.3165, 0.2280, 0.4555]], dtype=torch.float64)

计算log_softmax

在numpy中

代码语言:javascript复制
import numpy as np 
x = np.array([[-0.7715, -0.6205,-0.2562]]) 
y = np.exp(x) / np.sum(np.exp(x), axis=1, keepdims=True) 
y = np.log(y) 

array([[-1.27508877, -0.81683591, -1.27738109], [-1.15045104, -1.47835858, -0.78637192]])

在pytorch中

代码语言:javascript复制
torch_x = torch.from_numpy(x) 
torch_y = nn.LogSoftmax(dim=-1)(torch_x) 

tensor([[-1.2751, -0.8168, -1.2774], [-1.1505, -1.4784, -0.7864]], dtype=torch.float64)

计算NLLLoss

说明,就是在计算log_softmax之后,根据每个样本的真实标签取得其对应的值。默认权重都是1,而且采取求均值的方式。这里就是-(-1.27508877 -0.78637192) / 2,即取出第0行的第0个和第1行的第2个,正好对应[0, 2]。

在numpy中

代码语言:javascript复制
targets = np.array([0, 2]) 
nll_loss = -(np.sum(np.choose(targets, y.T)) / y.shape[0]) 

1.0307303437846973

在pytorch中

首先我们来看下官方代码:

代码语言:javascript复制
 |      >>> m = nn.LogSoftmax(dim=1)
 |      >>> loss = nn.NLLLoss()
 |      >>> # input is of size N x C = 3 x 5
 |      >>> input = torch.randn(3, 5, requires_grad=True)
 |      >>> # each element in target has to have 0 <= value < C
 |      >>> target = torch.tensor([1, 0, 4])
 |      >>> output = loss(m(input), target)
 |      >>> output.backward()

发现其也是在计算LogSoftmax之后计算NLLLoss()。 我们在看下pytorch的计算结果:

代码语言:javascript复制
torch_targets = torch.tensor([0, 2])
torch_nll_loss = nn.NLLLoss()(torch_y, torch_targets)

tensor(1.0307, dtype=torch.float64) 与我们一步一步利用numpy计算的保持一致。 最后我们在利用更直观的一种形式来看看:

代码语言:javascript复制
import torch.nn.functional as F 
output = F.nll_loss(F.log_softmax(torch_x, dim=1), torch_targets, reduction='mean')

tensor(1.0307, dtype=torch.float64) 结果也符合我们的预期。

https://www.gentlecp.com/articles/874.html https://blog.csdn.net/qq_28418387/article/details/95918829 https://blog.csdn.net/yyhhlancelot/article/details/83142255 https://blog.csdn.net/Jeremy_lf/article/details/102725285

0 人点赞