Generating a random matrix
```python
import numpy as np

x = np.random.rand(2, 3)  # 2x3 matrix of uniform values in [0, 1); differs on every run
```
array([[0.10786477, 0.56611762, 0.10557245], [0.4596513 , 0.13174377, 0.82373043]])
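Because `np.random.rand` draws fresh values on every call, your output will not match the array above exactly. If you want reproducible inputs, one option is NumPy's seeded `Generator`. This is a minimal sketch; the seed value 0 is an arbitrary choice.

```python
import numpy as np

# A seeded Generator makes the "random" matrix reproducible across runs.
rng = np.random.default_rng(0)  # seed 0 is arbitrary
x = rng.random((2, 3))          # uniform floats in [0, 1), shape (2, 3)

# Re-creating a generator with the same seed reproduces the same matrix.
x_again = np.random.default_rng(0).random((2, 3))
print(np.array_equal(x, x_again))  # True
```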
Computing softmax
In NumPy
```python
# Softmax along each row: exp(x_ij) / sum_k exp(x_ik)
y = np.exp(x) / np.sum(np.exp(x), axis=1, keepdims=True)
```
array([[0.27940617, 0.44182742, 0.27876641], [0.31649398, 0.22801164, 0.45549437]])
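Calling `np.exp(x)` directly overflows to `inf` once entries exceed roughly 710 in float64. A common remedy is to subtract each row's maximum first. Softmax is unchanged when a constant is added to a whole row, so for small inputs like ours the result is identical. The helper name `stable_softmax` is hypothetical; this is a sketch, not a library function.

```python
import numpy as np

def stable_softmax(x, axis=-1):
    """Softmax along `axis`, shifted by the row max to avoid overflow (hypothetical helper)."""
    shifted = x - np.max(x, axis=axis, keepdims=True)  # largest entry per row becomes 0
    e = np.exp(shifted)
    return e / np.sum(e, axis=axis, keepdims=True)

z = np.array([[0.1, 0.5, 0.1],
              [1000.0, 1001.0, 1002.0]])  # np.exp on the second row alone would give inf
p = stable_softmax(z)
print(p.sum(axis=1))  # each row sums to 1
```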
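In PyTorch
```python
import torch
import torch.nn as nn

torch_x = torch.from_numpy(x)          # float64 tensor sharing memory with x
torch_y = nn.Softmax(dim=-1)(torch_x)  # dim=-1 is the last axis, i.e. axis=1 here
```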
tensor([[0.2794, 0.4418, 0.2788], [0.3165, 0.2280, 0.4555]], dtype=torch.float64)
Computing log_softmax
In NumPy
```python
import numpy as np

# Reuse the random 2x3 matrix x from above
y = np.exp(x) / np.sum(np.exp(x), axis=1, keepdims=True)  # softmax
y = np.log(y)  # log of each probability
```
array([[-1.27508877, -0.81683591, -1.27738109], [-1.15045104, -1.47835858, -0.78637192]])
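Taking `np.log` of the softmax works for this input. However, if a probability underflows to 0, its log becomes `-inf`. The usual alternative computes log_softmax(x) = x - logsumexp(x) directly, with the row max subtracted so nothing overflows. This is a sketch, and the helper name is hypothetical.

```python
import numpy as np

def log_softmax(x, axis=-1):
    """log(softmax(x)) computed as x - logsumexp(x) (hypothetical helper)."""
    m = np.max(x, axis=axis, keepdims=True)
    # logsumexp(x) = m + log(sum(exp(x - m))); the shifted exponents are all <= 0
    lse = m + np.log(np.sum(np.exp(x - m), axis=axis, keepdims=True))
    return x - lse

z = np.array([[0.1, 0.5, 0.1],
              [0.0, 0.0, -800.0]])  # exp(-800) underflows to 0 in float64
out = log_softmax(z)
# np.log(softmax(z)) would give -inf for the last entry of row 2; this stays finite.
print(out[1])
```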
In PyTorch
```python
torch_x = torch.from_numpy(x)
torch_y = nn.LogSoftmax(dim=-1)(torch_x)  # softmax followed by log, in one module
```
tensor([[-1.2751, -0.8168, -1.2774], [-1.1505, -1.4784, -0.7864]], dtype=torch.float64)
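Computing NLLLoss
NLLLoss takes the log_softmax output and, for each sample, picks the value at that sample's true label. By default every class weight is 1, and the picked values are averaged over the samples and negated. Here the loss is -(-1.27508877 - 0.78637192) / 2: the element in column 0 of row 0 and the element in column 2 of row 1, which is exactly what the targets [0, 2] select.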
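Here is the same rule written as a formula. Let $z_{n,c}$ be the log_softmax output for sample $n$ and class $c$, $t_n$ the target class of sample $n$, and $w_c$ the weight of class $c$ (1 by default). PyTorch's `NLLLoss` documents the mean-reduced loss as a weighted mean:

$$
\ell = -\frac{\sum_{n=1}^{N} w_{t_n}\, z_{n,t_n}}{\sum_{n=1}^{N} w_{t_n}}
$$

With all weights equal to 1 this reduces to the plain mean, which for our two samples gives

$$
\ell = -\frac{1}{N}\sum_{n=1}^{N} z_{n,t_n} = -\frac{(-1.27508877) + (-0.78637192)}{2} \approx 1.03073
$$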
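In NumPy
```python
targets = np.array([0, 2])  # true class index of each of the 2 samples
# Pick y[i, targets[i]] for every row i, average over the samples, and negate
nll_loss = -np.mean(y[np.arange(y.shape[0]), targets])
```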
1.0307303437846973
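The following NumPy sketch makes the "weights default to 1, then average" behaviour concrete by adding optional per-class weights. The function `nll_loss_np` and its arguments are hypothetical. Its weighted mean (dividing by the sum of the weights of the chosen targets) follows the formula in the PyTorch `NLLLoss` documentation. It is still only an illustration: options such as `ignore_index` are not handled.

```python
import numpy as np

def nll_loss_np(log_probs, targets, weight=None, reduction="mean"):
    """NLL loss on log-probabilities of shape (N, C); hypothetical helper."""
    n = log_probs.shape[0]
    if weight is None:
        weight = np.ones(log_probs.shape[1])    # default: every class weighs 1
    w = weight[targets]                         # weight of each sample's target class
    picked = log_probs[np.arange(n), targets]   # z[n, t_n]
    losses = -w * picked                        # per-sample weighted loss
    if reduction == "mean":
        return losses.sum() / w.sum()           # weighted mean over samples
    if reduction == "sum":
        return losses.sum()
    return losses                               # reduction == "none"

# log_softmax values from the 2x3 example above
log_probs = np.array([[-1.27508877, -0.81683591, -1.27738109],
                      [-1.15045104, -1.47835858, -0.78637192]])
targets = np.array([0, 2])
print(nll_loss_np(log_probs, targets))  # ≈ 1.03073
```

Passing `weight=np.array([1.0, 1.0, 3.0])` would count the second sample three times as heavily. The denominator then also becomes 4 instead of 2.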
In PyTorch
First, let's look at the example from the official documentation:
```python
import torch
import torch.nn as nn

m = nn.LogSoftmax(dim=1)
loss = nn.NLLLoss()
# input is of size N x C = 3 x 5
input = torch.randn(3, 5, requires_grad=True)
# each element in target has to have 0 <= value < C
target = torch.tensor([1, 0, 4])
output = loss(m(input), target)
output.backward()
```
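The example likewise applies LogSoftmax first and then NLLLoss. Now let's check PyTorch's result on our data:
```python
torch_targets = torch.tensor([0, 2])
# torch_y holds the log_softmax output computed above
torch_nll_loss = nn.NLLLoss()(torch_y, torch_targets)
```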
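tensor(1.0307, dtype=torch.float64)
This matches the value we computed step by step with NumPy. Finally, here is the same loss written more directly with the functional API:
```python
import torch.nn.functional as F

# log_softmax over the class dimension, then NLL averaged over the batch
output = F.nll_loss(F.log_softmax(torch_x, dim=1), torch_targets, reduction='mean')
```
tensor(1.0307, dtype=torch.float64)
The result again matches what we expect.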
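PyTorch also provides `nn.CrossEntropyLoss` / `F.cross_entropy`. Its documentation describes this loss as combining LogSoftmax and NLLLoss, so it takes raw scores rather than log-probabilities. Below is a NumPy sketch of that composition. The helper names are hypothetical, and the input reuses the random matrix printed at the top of this post.

```python
import numpy as np

def log_softmax(x, axis=-1):
    """Row-wise log-softmax with max subtraction (hypothetical helper)."""
    m = np.max(x, axis=axis, keepdims=True)
    return x - m - np.log(np.sum(np.exp(x - m), axis=axis, keepdims=True))

def cross_entropy(logits, targets):
    """Mean cross-entropy = NLL of log_softmax(logits) (hypothetical helper)."""
    z = log_softmax(logits)
    return -np.mean(z[np.arange(len(targets)), targets])

# The random matrix printed at the top of this post (rounded to 8 digits)
logits = np.array([[0.10786477, 0.56611762, 0.10557245],
                   [0.4596513, 0.13174377, 0.82373043]])
print(cross_entropy(logits, np.array([0, 2])))  # ≈ 1.0307
```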