YOLOv8改进:全网原创首发 | 多尺度空洞注意力(MSDA) | 中科院一区顶刊 DilateFormer 2023.9

2023-10-20 08:59:17 浏览数 (1)

本文全网首发独家改进多尺度空洞注意力(MSDA)采用多头的设计,在不同的头部使用不同的空洞率执行滑动窗口膨胀注意力(SWDA),全网独家首发,创新力度十足,适合科研

1)与C2f结合;2)作为注意力MSDA使用;

1.DilateFormer介绍

本文提出了一种新颖的多尺度空洞 Transformer,简称DilateFormer,以用于视觉识别任务。原有的 ViT 模型在计算复杂性和感受野大小之间的权衡上存在矛盾。众所周知,ViT 模型使用全局注意力机制,能够在任意图像块之间建立长远距离上下文依赖关系,但是全局感受野带来的是平方级别的计算代价。同时,有些研究表明,在浅层特征上,直接进行全局依赖性建模可能存在冗余,因此是没必要的。

为了克服这些问题,作者提出了一种新的注意力机制——多尺度空洞注意力(MSDA)。MSDA 能够模拟小范围内的局部和稀疏的图像块交互,这些发现源自于对 ViTs 在浅层次上全局注意力中图像块交互的分析。作者发现在浅层次上,注意力矩阵具有局部性稀疏性两个关键属性,这表明在浅层次的语义建模中,远离查询块的块大部分无关,因此全局注意力模块中存在大量的冗余。

DilateFormer 是一个以金字塔结构为基础的深度学习模型,它主要设计用来处理基础的视觉任务。DilateFormer 的关键设计概念是利用多尺度空洞注意力(Multi-Scale Dilated Attention, MSDA)来有效捕捉多尺度的语义信息,并减少自注意力机制的冗余。

如下图所示,MSDA 模块同样采用多头的设计,将特征图的通道分为 n 个不同的头部,并在不同的头部使用不同的空洞率执行滑动窗口膨胀注意力(SWDA)。这样可以在被关注的感受野内的各个尺度上聚合语义信息,并有效地减少自注意力机制的冗余,无需复杂的操作和额外的计算成本。

总体来说,DilateFormer 通过这种混合使用多尺度空洞注意力和多头自注意力的方式,成功地处理了长距离依赖问题,同时保持了计算效率,并能够适应不同尺度和分辨率的输入。

2.MSDA引入到YOLOv8

2.1 MSDA加入ultralytics/nn/attention/dilateformer.py

代码语言:javascript复制
class MultiDilatelocalAttention(nn.Module):
    "Implementation of Dilate-attention"

    def __init__(self, dim, num_heads=4, qkv_bias=False, qk_scale=None,
                 attn_drop=0.,proj_drop=0., kernel_size=3, dilation=[1, 2]):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.dilation = dilation
        self.kernel_size = kernel_size
        self.scale = qk_scale or head_dim ** -0.5
        self.num_dilation = len(dilation)

        assert num_heads % self.num_dilation == 0, f"num_heads{num_heads} must be the times of num_dilation{self.num_dilation}!!"
        self.qkv = nn.Conv2d(dim, dim * 3, 1, bias=qkv_bias)
        self.dilate_attention = nn.ModuleList(
            [DilateAttention(head_dim, qk_scale, attn_drop, kernel_size, dilation[i])
             for i in range(self.num_dilation)])
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):

        x = x.permute(0, 3, 1, 2)  # B, C, H, W
        B, C, H, W = x.shape

        qkv = self.qkv(x).reshape(B, 3, self.num_dilation, C //self.num_dilation, H, W).permute(2, 1, 0, 3, 4, 5)
        #num_dilation,3,B,C//num_dilation,H,W
        x = x.reshape(B, self.num_dilation, C//self.num_dilation, H, W).permute(1, 0, 3, 4, 2 )
        # num_dilation, B, H, W, C//num_dilation
        for i in range(self.num_dilation):
            x[i] = self.dilate_attention[i](qkv[i][0], qkv[i][1], qkv[i][2])# B, H, W,C//num_dilation
        x = x.permute(1, 2, 3, 0, 4).reshape(B, H, W, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

0 人点赞