明月深度学习实践006:SENet的升级架构SKNet

2021-10-28 14:20:25 浏览数 (1)

SENet之后,还有一个升级的架构SKNet,先放些相关资料:

  • 论文:https://arxiv.org/pdf/1903.06586.pdf
  • 作者解读:https://zhuanlan.zhihu.com/p/59690223
  • 源码:https://github.com/implus/SKNet (基于caffe,没学习过,看起来比较费劲)
  • 第三方实现:https://github.com/pppLang/SKNet

(说明:以下图像如无特殊说明,均来自作者论文)

1. SKNet介绍


SKNet中的SK是“Selective Kernel”的缩写,作者在知乎的文章提到这个架构设计的思路:

大家知道抛开attention的引入,此前比较plain并且work的两大架构:ResNeXt 和 Inception。前者特点是用group卷积轻量化了>1 kernel的卷积核;后者特点是多路的multiple kernel设计。我们设计的一大出发点就是看是否能够combine两者的特色。

不过我理解时,觉得SKNet是SENet架构的一个扩展版本,集成了多卷积核学习,来提升模型整体效果。

2. SKNet卷积计算


直接看图:

如图所示,输入X,输出V,中间的部分就是SKBlock了。

计算过程,可以分为三个阶段:

2.1 Split分离特征图

其实就是使用不同size的卷积核来分离特征图,如图使用了3*3和5*5两种size,实际上还可以使用更多size的卷积核,如7*7等。注意,这里不同卷积核输出的特征图的shape都是H*W*C,要满足这个并不难:

代码语言:javascript复制
self.convs = nn.ModuleList([])
# M为卷积核的数量
for i in range(M):
    self.convs.append(nn.Sequential(
         nn.Conv2d(C, C, kernel_size=3 i*2, stride=1, padding=1 i, groups=G),
         nn.BatchNorm2d(C),
         nn.ReLU(inplace=False)
    ))

其实只要步长一致,控制好padding的值,就能保证输出的特征图的shape是保持一致的。

将不同卷积核集成学习,能更学习到更多不同感受野的特征,对于提升模型的效果很可能是有效的,这个思想在机器学习深度学习领域可谓屡试不爽。

2.2 Fuse融合

Fuse这个单词不知道该怎么去翻译,估计理解为融合是比较好的吧。

在split阶段,我们已经分离出多个分支的特征图,这时我们如果直接将SENet嫁接过来,是否可以了呢?如图:

关于SENet,可以看前一篇文章。

如果我们再把上图融合一下,就可能变成下面的架构:

这到底有没有效,不知道,因为没有做过实验,不过我估计是有效的,有时间的时候可以尝试尝试。SKNet的整体思想,我觉得跟这个是很类似的,这也是我说SKNet是SENet的升级版本的原因。

我们回到SKNet,Fuse计算也分成几个步骤:

2.2.1 将不同卷积核的特征图融合成一个单一的特征图

这个融合很简单:

2.2.2 将融合后的特征图压缩成1*1*C的张量

这个步骤和SENet中的Squeeze类似:

这样同一通道上的特征图就压缩成了一个实数。

2.2.3 继续将前面的输出压缩成1*1*d

这里的z就是1*1*d的张量,类似SENet中的Excitation的第一个全连接层,所做的运算就是一次全连接 BN ReLU。不过这里的d的取值就不再是通道数除以一个固定值,而是:d = max(C/r, L)。在作者的论文里,L=32。

2.2.4 将1*1*d的张量z分离成两个1*1*C的张量

论文里并没有明确写到这个步骤,但是实际是需要这个步骤,类似SENet中Excitation的第二个全连接层。实现很简单:

代码语言:javascript复制
self.fcs = nn.ModuleList([])
for i in range(M):
    self.fcs.append(nn.Linear(d, C))

虽然使用的是M个结构一样的全连接层,但是在训练过程中,它们的参数矩阵是不一样的,也就是下面的A和B。

很显然,从2.2.2到2.2.4这三个步骤跟SENet几乎是一致的。

2.3 Select

前面我们已经得到了一个1*1*d的张量z,按照SKNet的架构,接下来需要做的是:

  1. 将一个1*1*d的张量z,分离成两个1*1*C的张量a和b。
  2. 将a和b分别和不同卷积核的特征图相乘
  3. 最后将相乘后的两个特征图相加,得到最后的输出特征图V

后面两个步骤比较简单,第一个步骤作者的实现是这样的:

显然a b=1。如果只是看这里,会不清楚A和B这两个张量是哪里来的,其实就是2.2.4中的全连接层的参数矩阵,而Az或者Bz其实就是一次全连接计算。

实现其实就是一次softmax:

代码语言:javascript复制
self.softmax = nn.Softmax(dim=1)

第二第三个步骤很简单:

这样就能得到最后的特征图V。

3. 结合ResNet的应用


嵌入ResNet的结构:

可以看到参数和计算量都是有所增加的。

作者的实验数据还是很不错的。

4. SKNet小结


对于SKNet,我觉得理解它的关键点是:

  • 使用不同的卷积核,集成学习不同感受野的特征;
  • 整合SENet的思想。

写于2020-10-06

0 人点赞