SENet之后,还有一个升级的架构SKNet,先放些相关资料:
- 论文:https://arxiv.org/pdf/1903.06586.pdf
- 作者解读:https://zhuanlan.zhihu.com/p/59690223
- 源码:https://github.com/implus/SKNet (基于caffe,没学习过,看起来比较费劲)
- 第三方实现:https://github.com/pppLang/SKNet
(说明:以下图像如无特殊说明,均来自作者论文)
1. SKNet介绍
SKNet中的SK是“Selective Kernel”的缩写,作者在知乎的文章提到这个架构设计的思路:
大家知道抛开attention的引入,此前比较plain并且work的两大架构:ResNeXt 和 Inception。前者特点是用group卷积轻量化了>1 kernel的卷积核;后者特点是多路的multiple kernel设计。我们设计的一大出发点就是看是否能够combine两者的特色。
不过我理解时,觉得SKNet是SENet架构的一个扩展版本,集成了多卷积核学习,来提升模型整体效果。
2. SKNet卷积计算
直接看图:
如图所示,输入X,输出V,中间的部分就是SKBlock了。
计算过程,可以分为三个阶段:
2.1 Split分离特征图
其实就是使用不同size的卷积核来分离特征图,如图使用了3*3和5*5两种size,实际上还可以使用更多size的卷积核,如7*7等。注意,这里不同卷积核输出的特征图的shape都是H*W*C,要满足这个并不难:
代码语言:javascript复制self.convs = nn.ModuleList([])
# M为卷积核的数量
for i in range(M):
self.convs.append(nn.Sequential(
nn.Conv2d(C, C, kernel_size=3 i*2, stride=1, padding=1 i, groups=G),
nn.BatchNorm2d(C),
nn.ReLU(inplace=False)
))
其实只要步长一致,控制好padding的值,就能保证输出的特征图的shape是保持一致的。
将不同卷积核集成学习,能更学习到更多不同感受野的特征,对于提升模型的效果很可能是有效的,这个思想在机器学习深度学习领域可谓屡试不爽。
2.2 Fuse融合
Fuse这个单词不知道该怎么去翻译,估计理解为融合是比较好的吧。
在split阶段,我们已经分离出多个分支的特征图,这时我们如果直接将SENet嫁接过来,是否可以了呢?如图:
关于SENet,可以看前一篇文章。
如果我们再把上图融合一下,就可能变成下面的架构:
这到底有没有效,不知道,因为没有做过实验,不过我估计是有效的,有时间的时候可以尝试尝试。SKNet的整体思想,我觉得跟这个是很类似的,这也是我说SKNet是SENet的升级版本的原因。
我们回到SKNet,Fuse计算也分成几个步骤:
2.2.1 将不同卷积核的特征图融合成一个单一的特征图
这个融合很简单:
2.2.2 将融合后的特征图压缩成1*1*C的张量
这个步骤和SENet中的Squeeze类似:
这样同一通道上的特征图就压缩成了一个实数。
2.2.3 继续将前面的输出压缩成1*1*d
这里的z就是1*1*d的张量,类似SENet中的Excitation的第一个全连接层,所做的运算就是一次全连接 BN ReLU。不过这里的d的取值就不再是通道数除以一个固定值,而是:d = max(C/r, L)。在作者的论文里,L=32。
2.2.4 将1*1*d的张量z分离成两个1*1*C的张量
论文里并没有明确写到这个步骤,但是实际是需要这个步骤,类似SENet中Excitation的第二个全连接层。实现很简单:
代码语言:javascript复制self.fcs = nn.ModuleList([])
for i in range(M):
self.fcs.append(nn.Linear(d, C))
虽然使用的是M个结构一样的全连接层,但是在训练过程中,它们的参数矩阵是不一样的,也就是下面的A和B。
很显然,从2.2.2到2.2.4这三个步骤跟SENet几乎是一致的。
2.3 Select
前面我们已经得到了一个1*1*d的张量z,按照SKNet的架构,接下来需要做的是:
- 将一个1*1*d的张量z,分离成两个1*1*C的张量a和b。
- 将a和b分别和不同卷积核的特征图相乘
- 最后将相乘后的两个特征图相加,得到最后的输出特征图V
后面两个步骤比较简单,第一个步骤作者的实现是这样的:
显然a b=1。如果只是看这里,会不清楚A和B这两个张量是哪里来的,其实就是2.2.4中的全连接层的参数矩阵,而Az或者Bz其实就是一次全连接计算。
实现其实就是一次softmax:
代码语言:javascript复制self.softmax = nn.Softmax(dim=1)
第二第三个步骤很简单:
这样就能得到最后的特征图V。
3. 结合ResNet的应用
嵌入ResNet的结构:
可以看到参数和计算量都是有所增加的。
作者的实验数据还是很不错的。
4. SKNet小结
对于SKNet,我觉得理解它的关键点是:
- 使用不同的卷积核,集成学习不同感受野的特征;
- 整合SENet的思想。
写于2020-10-06