YoloV5/YoloV7改进---注意力机制:引入瓶颈注意力模块BAM,对标CBAM

2023-11-30 16:42:40 浏览数 (1)

1.BAM介绍

论文:https://arxiv.org/pdf/1807.06514.pdf

摘要:提出了一种简单有效的注意力模块,称为瓶颈注意力模块(BAM),可以与任何前馈卷积神经网络集成。我们的模块沿着两条独立的路径,通道和空间,推断出一张注意力图。我们将我们的模块放置在模型的每个瓶颈处,在那里会发生特征图的下采样。我们的模块用许多参数在瓶颈处构建了分层注意力,并且它可以以端到端的方式与任何前馈模型联合训练。我们通过在CIFAR-100、ImageNet-1K、VOC 2007和MS COCO基准上进行大量实验来验证我们的BAM。我们的实验表明,各种模型在分类和检测性能上都有持续的改进,证明了BAM的广泛适用性。

作者将BAM放在了Resnet网络中每个stage之间。有趣的是,通过可视化我们可以看到多层BAMs形成了一个分层的注意力机制,这有点像人类的感知机制。BAM在每个stage之间消除了像背景语义特征这样的低层次特征,然后逐渐聚焦于高级的语义–明确的目标。

作者提出了新的Attention模型——瓶颈注意模块,通过分离的两个路径channel和spatial得到attention map,减少计算开销和参数开销。

2.BAM引入到yolov5

2.1 加入common.py中:

代码语言:javascript复制
###################### BAM  attention  ####     START   by  AI&CV  ###############################

import torch
from torch import nn
import torch.nn.functional as F


class ChannelGate(nn.Module):
    def __init__(self, channel, reduction=16):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.mlp = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel)
        )
        self.bn = nn.BatchNorm1d(channel)

    def forward(self, x):
        b, c, h, w = x.shape
        y = self.avgpool(x).view(b, c)
        y = self.mlp(y)
        y = self.bn(y).view(b, c, 1, 1)
        return y.expand_as(x)


class SpatialGate(nn.Module):
    def __init__(self, channel, reduction=16, kernel_size=3, dilation_val=4):
        super().__init__()
        self.conv1 = nn.Conv2d(channel, channel // reduction, kernel_size=1)
        self.conv2 = nn.Sequential(
            nn.Conv2d(channel // reduction, channel // reduction, kernel_size, padding=dilation_val,
                      dilation=dilation_val),
            nn.BatchNorm2d(channel // reduction),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel // reduction, kernel_size, padding=dilation_val,
                      dilation=dilation_val),
            nn.BatchNorm2d(channel // reduction),
            nn.ReLU(inplace=True)
        )
        self.conv3 = nn.Conv2d(channel // reduction, 1, kernel_size=1)
        self.bn = nn.BatchNorm2d(1)

    def forward(self, x):
        b, c, h, w = x.shape
        y = self.conv1(x)
        y = self.conv2(y)
        y = self.conv3(y)
        y = self.bn(y)
        return y.expand_as(x)


class BAM(nn.Module):
    def __init__(self, channel):
        super(BAM, self).__init__()
        self.channel_attn = ChannelGate(channel)
        self.spatial_attn = SpatialGate(channel)

    def forward(self, x):
        attn = F.sigmoid(self.channel_attn(x)   self.spatial_attn(x))
        return x   x * attn

###################### BAM  attention  ####     END   by  AI&CV  ###############################

详见:https://blog.csdn.net/m0_63774211/article/details/131541363

by CSDN AI小怪兽

我正在参与2023腾讯技术创作特训营第三期有奖征文,组队打卡瓜分大奖!

0 人点赞