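This article grew out of SPACES: "Extract-then-Generate" Long-Text Summarization (a summary of the CAIL / 法研杯 competition). That post is a competition write-up that mentions many tricks. One of them, called Sparse Softmax, caught my attention, and after reading through a number of sources I have collected my notes here.
The idea behind Sparse Softmax comes from papers such as "From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification" and "Sparse Sequence-to-Sequence Models", in which the authors propose making Softmax sparse to improve interpretability and, in some cases, performance.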
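Softmax Is Not Sparse Enough
As mentioned above, Sparse Softmax essentially makes the output of Softmax sparse. Why should that help? We believe sparsification prevents Softmax from over-learning. Suppose a sample is already classified correctly, so that $s_{\max}=s_t$ (the target class has the largest score). We can then derive the following inequality for the original cross-entropy: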
$$
\begin{aligned}
\log\Big(\sum_{i=1}^n e^{s_i}\Big)-s_{\max} &= \log\Big(e^{s_t} + \sum_{i\neq t}e^{s_i}\Big)-s_{\max}\\
&= \log\Big(e^{s_{\max}} + \sum_{i\neq t}e^{s_i}\Big)-\log\big(e^{s_{\max}}\big)\\
&= \log\left(\frac{e^{s_{\max}} + \sum_{i\neq t}e^{s_i}}{e^{s_{\max}}}\right)\\
&= \log\Big(1 + \sum_{i \neq t}e^{s_i - s_{\max}}\Big)\\
&\ge \log\Big(1 + (n - 1)e^{s_{\min}-s_{\max}}\Big)
\end{aligned}\tag{1}
$$
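As a quick numerical sanity check of inequality (1), the sketch below draws random logits, forces the target class to hold the largest score (the "already classified correctly" assumption), and compares the cross-entropy with the lower bound $\log(1+(n-1)e^{s_{\min}-s_{\max}})$. The helper name `ce_and_bound` is hypothetical, chosen only for this illustration.

```python
import math
import torch


def ce_and_bound(scores: torch.Tensor, target: int):
    """Return (cross-entropy, lower bound from inequality (1)) for one logit vector.

    Assumes scores[target] is the maximum, i.e. the sample is classified correctly.
    """
    n = scores.numel()
    ce = (torch.logsumexp(scores, dim=0) - scores[target]).item()
    gap = (scores.max() - scores.min()).item()  # s_max - s_min
    bound = math.log(1 + (n - 1) * math.exp(-gap))
    return ce, bound


torch.manual_seed(0)
for _ in range(5):
    s = torch.randn(10, dtype=torch.float64)
    t = int(torch.randint(0, 10, (1,)))
    s[t] = s.max() + 0.5  # make the target logit the largest one
    ce, bound = ce_and_bound(s, t)
    print(f"CE = {ce:.4f} >= bound = {bound:.4f}")

# Equality holds when every non-target logit equals s_min.
print(ce_and_bound(torch.tensor([3.0, 0.0, 0.0, 0.0], dtype=torch.float64), 0))
```

The bound is tight exactly when all non-target logits are equal, which is the case the derivation below relies on.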
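Let the current cross-entropy value be $\varepsilon$. Since the left-hand side of (1) is exactly the cross-entropy when $s_t = s_{\max}$, we have
$$
\varepsilon \ge \log\Big(1 + (n - 1)e^{s_{\min}-s_{\max}}\Big)
$$
Solving for the gap between the largest and smallest logits gives
$$
s_{\max} - s_{\min} \ge \log(n-1) - \log\big(e^{\varepsilon} - 1\big)
$$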
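Take $\varepsilon = \ln 2 \approx 0.69$ as an example. Then $\log(e^{\varepsilon} - 1)=0$, so $s_{\max} - s_{\min}\ge \log(n-1)$. In other words, to bring the loss down to 0.69, the gap between the largest and the smallest logit must be at least $\log(n-1)$. When $n$ is large, this is an unnecessarily large margin for a classification problem: we only need the target logit to be slightly larger than every non-target logit, not larger by as much as $\log(n-1)$. Standard cross-entropy therefore tends to keep pushing the logits apart beyond what classification requires (over-learning), which can lead to overfitting.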
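To see how quickly this margin grows with the number of classes, the sketch below evaluates the bound $\log(n-1)-\log(e^{\varepsilon}-1)$ and checks it against the equality case, in which every non-target logit equals $s_{\min}$. The helper names `required_gap` and `equality_case_loss` are illustrative only.

```python
import math


def required_gap(n: int, eps: float) -> float:
    """Smallest s_max - s_min allowed by the bound when the loss equals eps (n classes)."""
    return math.log(n - 1) - math.log(math.expm1(eps))  # expm1(eps) = e^eps - 1


def equality_case_loss(n: int, gap: float) -> float:
    """Cross-entropy when the target logit is s_max and all n-1 others equal s_max - gap."""
    return math.log1p((n - 1) * math.exp(-gap))


eps = math.log(2)  # the epsilon = ln 2 example from the text
for n in (2, 10, 1000, 50000):
    gap = required_gap(n, eps)
    print(f"n = {n:>6}: required gap = {gap:.3f}, loss at that gap = {equality_case_loss(n, gap):.3f}")
```

With $\varepsilon=\ln 2$ the required gap is just $\log(n-1)$, so it keeps growing as the label space gets larger even though the decision itself is already correct.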
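Sparsemax, the Sparse Alternative
Having said this much about Softmax, how does Sparse Softmax (or Sparsemax) actually produce a sparse distribution? The formulation in the original paper is rather involved, so I refer interested readers to the paper itself; here I present a simpler version designed by Su Jianlin (苏剑林):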
$$
\begin{array}{c|c|c}
\hline
 & \text{Original} & \text{Sparse} \\
\hline
\text{Softmax} & p_i = \frac{e^{s_i}}{\sum\limits_{j=1}^{n} e^{s_j}} & p_i=\left\{\begin{aligned}&\frac{e^{s_i}}{\sum\limits_{j\in\Omega_k} e^{s_j}},\quad i\in\Omega_k\\ &\quad 0,\quad i\notin\Omega_k\end{aligned}\right.\\
\hline
\text{CrossEntropy} & \log\left(\sum\limits_{i=1}^n e^{s_i}\right) - s_t & \log\left(\sum\limits_{i\in\Omega_k} e^{s_i}\right) - s_t\\
\hline
\end{array}
$$
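Here $\Omega_k$ is the set of indices of the first $k$ elements after sorting $s_1,s_2,\dots,s_n$ in descending order. Put simply, Su Jianlin's Sparse Softmax keeps only the $k$ largest scores when computing probabilities and sets all the others to zero, where $k$ is a hand-picked hyperparameter. Note that this top-$k$ variant is not the same as the Sparsemax of Martins and Astudillo, which computes the Euclidean projection of the scores onto the probability simplex and determines the size of the support adaptively.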
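The "Sparse" Softmax column of the table can be sketched in a few lines of PyTorch: keep the $k$ largest scores, apply an ordinary Softmax over them only, and write zeros everywhere else. This is a minimal sketch over the last dimension; the function name `topk_sparse_softmax` is hypothetical.

```python
import torch


def topk_sparse_softmax(scores: torch.Tensor, k: int) -> torch.Tensor:
    """Top-k sparse softmax over the last dimension.

    Entries outside the k largest scores (the index set Omega_k) get probability 0;
    the remaining k entries are renormalised with an ordinary softmax.
    """
    topk_vals, topk_idx = scores.topk(k, dim=-1)
    probs = torch.zeros_like(scores)
    # Softmax over Omega_k only, scattered back to the original positions.
    probs.scatter_(-1, topk_idx, torch.softmax(topk_vals, dim=-1))
    return probs


s = torch.tensor([[2.0, 1.0, 0.5, -1.0, 3.0]])
print(topk_sparse_softmax(s, k=2))  # only the scores 3.0 and 2.0 keep non-zero mass
print(torch.softmax(s, dim=-1))     # dense softmax, for comparison
```

Setting `k` to the number of classes recovers the ordinary Softmax, which is a convenient way to sanity-check an implementation.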
Code
First, following Su Jianlin's idea, here is a simple PyTorch implementation of the sparse cross-entropy loss:
import torch
import torch.nn as nn
class Sparsemax(nn.Module):
    """Sparse Softmax cross-entropy loss (the top-k version above, not Martins' sparsemax loss)"""
    def __init__(self, k_sparse=1):
        super(Sparsemax, self).__init__()
        # k_sparse: number of largest logits kept, i.e. k in Omega_k (a hand-picked hyperparameter)
        self.k_sparse = k_sparse
    def forward(self, preds, labels):
        """
        Args:
            preds (torch.Tensor): [batch_size, number_of_logits]
            labels (torch.Tensor): [batch_size] class indices, not one-hot
        Returns:
            torch.Tensor: loss summed over the batch
        """
        preds = preds.reshape(preds.size(0), -1)  # [batch_size, -1]
        topk = preds.topk(self.k_sparse, dim=1)[0]  # [batch_size, k_sparse]
        # log(sum(exp(topk))): log-sum-exp over Omega_k only
        pos_loss = torch.logsumexp(topk, dim=1)
        # s_t: the logit of the target class for each sample
        neg_loss = torch.gather(preds, 1, labels[:, None]).squeeze(1)
        return (pos_loss - neg_loss).sum()
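One way to check a loss like the one above is to compare it with PyTorch's built-in cross-entropy: when $k = n$, $\Omega_k$ contains every index, so the sparse loss must reduce to the ordinary one (with `reduction="sum"`, matching the `.sum()` in the listing). The sketch below re-implements the loss as a standalone function, hypothetically named `sparse_softmax_ce`, so that it runs on its own.

```python
import torch
import torch.nn.functional as F


def sparse_softmax_ce(logits: torch.Tensor, labels: torch.Tensor, k: int) -> torch.Tensor:
    """Summed top-k sparse cross-entropy: logsumexp over the k largest logits minus s_t."""
    topk_vals = logits.topk(k, dim=1).values
    target_logit = logits.gather(1, labels[:, None]).squeeze(1)
    return (torch.logsumexp(topk_vals, dim=1) - target_logit).sum()


torch.manual_seed(0)
logits = torch.randn(4, 6)
labels = torch.tensor([0, 3, 5, 1])
dense = F.cross_entropy(logits, labels, reduction="sum")
# Should match up to floating-point error, since Omega_k then covers every class.
print(dense, sparse_softmax_ce(logits, labels, k=6))
# Never larger than the dense loss: fewer terms inside the logsumexp.
print(sparse_softmax_ce(logits, labels, k=2))
```

Because the log-sum-exp only grows as more terms are added, the sparse loss is monotonically non-decreasing in $k$ and is bounded above by the standard cross-entropy.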
Next, here is a PyTorch implementation of the original Sparsemax that I found on GitHub:
"""Sparsemax activation function.
PyTorch implementation of Sparsemax function from:
-- "From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification"
-- André F. T. Martins, Ramón Fernandez Astudillo (http://arxiv.org/abs/1602.02068)
"""
import torch
import torch.nn as nn
class Sparsemax(nn.Module):
    """Sparsemax function."""
    def __init__(self, dim=None):
        """Initialize sparsemax activation
        Args:
            dim (int, optional): The dimension over which to apply the sparsemax function.
        """
        super(Sparsemax, self).__init__()
        self.dim = -1 if dim is None else dim
    def forward(self, input):
        """Forward function.
        Args:
            input (torch.Tensor): Input tensor. First dimension should be the batch size
        Returns:
            torch.Tensor: [batch_size x number_of_logits] Output tensor
        """
        # Sparsemax currently only handles 2-dim tensors,
        # so we reshape to a convenient shape and reshape back after sparsemax
        input = input.transpose(0, self.dim)
        original_size = input.size()
        input = input.reshape(input.size(0), -1)
        input = input.transpose(0, 1)
        dim = 1
        number_of_logits = input.size(dim)
        # Translate input by max for numerical stability
        input = input - torch.max(input, dim=dim, keepdim=True)[0].expand_as(input)
        # Sort input in descending order.
        # (NOTE: Can be replaced with linear time selection method described here:
        # http://stanford.edu/~jduchi/projects/DuchiShSiCh08.html)
        zs = torch.sort(input=input, dim=dim, descending=True)[0]
        # Ranks 1..n, created on the input's own device and dtype (no global device needed)
        ranks = torch.arange(start=1, end=number_of_logits + 1, step=1,
                             device=input.device, dtype=input.dtype).view(1, -1)
        ranks = ranks.expand_as(zs)
        # Determine sparsity of projection
        bound = 1 + ranks * zs
        cumulative_sum_zs = torch.cumsum(zs, dim)
        is_gt = torch.gt(bound, cumulative_sum_zs).type(input.type())
        k = torch.max(is_gt * ranks, dim, keepdim=True)[0]
        # Compute threshold function
        zs_sparse = is_gt * zs
        # Compute taus
        taus = (torch.sum(zs_sparse, dim, keepdim=True) - 1) / k
        taus = taus.expand_as(input)
        # Sparsemax
        self.output = torch.max(torch.zeros_like(input), input - taus)
        # Reshape back to original shape
        output = self.output
        output = output.transpose(0, 1)
        output = output.reshape(original_size)
        output = output.transpose(0, self.dim)
        return output
    def backward(self, grad_output):
        """Backward function.
        Note: autograd never calls nn.Module.backward; the gradients of forward() are
        computed automatically from its tensor operations. This method only documents the
        closed-form rule and assumes a 2-D input with dim=-1, so that self.output has the
        same layout as grad_output.
        """
        dim = 1
        nonzeros = torch.ne(self.output, 0)
        # keepdim=True so the per-row mean over the support broadcasts against grad_output
        mean_grad = torch.sum(grad_output * nonzeros, dim=dim, keepdim=True) / torch.sum(nonzeros, dim=dim, keepdim=True)
        self.grad_input = nonzeros * (grad_output - mean_grad)
        return self.grad_input
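To check the sorting-based algorithm in the listing above on a small example, here is a compact standalone re-implementation for 2-D inputs (the name `sparsemax_2d` is hypothetical). For the scores [1.0, 0.8, 0.1], the support is the first two entries and the threshold is τ = (1.0 + 0.8 - 1)/2 = 0.4, which gives [0.6, 0.4, 0.0]. The sketch also compares the gradient computed by autograd with the closed-form rule written in `backward` (on the support, subtract the mean of the incoming gradient over the support; off the support, the gradient is zero).

```python
import torch


def sparsemax_2d(z: torch.Tensor) -> torch.Tensor:
    """Sparsemax (Martins & Astudillo, 2016) over dim 1 of a [batch, n] tensor."""
    zs, _ = torch.sort(z, dim=1, descending=True)
    ranks = torch.arange(1, z.size(1) + 1, device=z.device, dtype=z.dtype).expand_as(zs)
    # Sorted entry j (1-based) is in the support if 1 + j * z_(j) > sum of the top-j scores.
    in_support = (1 + ranks * zs > torch.cumsum(zs, dim=1)).to(z.dtype)
    k = (in_support * ranks).max(dim=1, keepdim=True).values  # support size
    tau = ((in_support * zs).sum(dim=1, keepdim=True) - 1) / k  # threshold
    return torch.clamp(z - tau, min=0)


p = sparsemax_2d(torch.tensor([[1.0, 0.8, 0.1]]))
print(p)  # approximately [[0.6, 0.4, 0.0]]

# Autograd gradient vs. the closed-form rule from backward().
z = torch.tensor([[1.0, 0.8, 0.1], [0.2, 0.3, 2.0]], requires_grad=True)
g = torch.tensor([[0.5, -1.0, 2.0], [1.0, 0.0, -3.0]])
out = sparsemax_2d(z)
(autograd_grad,) = torch.autograd.grad(out, z, grad_outputs=g)
nz = (out != 0).to(g.dtype)
manual = nz * (g - (g * nz).sum(1, keepdim=True) / nz.sum(1, keepdim=True))
print(torch.allclose(autograd_grad, manual))
```

Unlike the top-k version, the number of non-zero outputs here is decided by the scores themselves: the second row above collapses onto a single class, while the first keeps two.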
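*Addendum
In many experiments, Su Jianlin found that Sparse Softmax only works well in settings with pretraining. A pretrained model has already been trained thoroughly, so the fine-tuning stage mainly needs to guard against overfitting. When a model is trained from scratch, however, Sparse Softmax degrades performance: at each step only the logits in the top k (plus the target) receive gradient, so the model ends up learning too little (underfitting).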
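The "only k classes get learned" argument can be seen directly in the gradient: for the sparse cross-entropy, the gradient with respect to any logit outside $\Omega_k$ is exactly zero, except for the target, which still receives a gradient from the $-s_t$ term. The sketch below, using a hypothetical single-sample helper `sparse_ce`, counts how many logits get a non-zero gradient under the ordinary loss ($k=n$) and under a top-$k$ loss.

```python
import torch


def sparse_ce(logits: torch.Tensor, target: int, k: int) -> torch.Tensor:
    """Top-k sparse cross-entropy for a single logit vector (k = n gives the usual loss)."""
    return torch.logsumexp(logits.topk(k).values, dim=0) - logits[target]


torch.manual_seed(0)
n, k = 1000, 10
logits = torch.randn(n, requires_grad=True)
target = 0
for kk in (n, k):
    (grad,) = torch.autograd.grad(sparse_ce(logits, target, kk), logits)
    print(f"k = {kk:>4}: logits with non-zero gradient = {(grad != 0).sum().item()}")
```

With the dense loss every logit is updated at every step, whereas the top-$k$ loss touches at most $k+1$ of them, which is consistent with the underfitting observed when training from scratch.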
References
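- SPACES: "Extract-then-Generate" Long-Text Summarization (CAIL Summary)
- Sparse Sequence-to-Sequence Models
- Deep Learning Activation Functions: From Softmax to Sparsemax
- GLU, sparsemax and GELU Activation Functions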