1.BAM介绍
论文:https://arxiv.org/pdf/1807.06514.pdf
摘要:提出了一种简单有效的注意力模块,称为瓶颈注意力模块(BAM),可以与任何前馈卷积神经网络集成。我们的模块沿着两条独立的路径,通道和空间,推断出一张注意力图。我们将我们的模块放置在模型的每个瓶颈处,在那里会发生特征图的下采样。我们的模块用许多参数在瓶颈处构建了分层注意力,并且它可以以端到端的方式与任何前馈模型联合训练。我们通过在CIFAR-100、ImageNet-1K、VOC 2007和MS COCO基准上进行大量实验来验证我们的BAM。我们的实验表明,各种模型在分类和检测性能上都有持续的改进,证明了BAM的广泛适用性。
作者将BAM放在了Resnet网络中每个stage之间。有趣的是,通过可视化我们可以看到多层BAMs形成了一个分层的注意力机制,这有点像人类的感知机制。BAM在每个stage之间消除了像背景语义特征这样的低层次特征,然后逐渐聚焦于高级的语义–明确的目标。
作者提出了新的Attention模型——瓶颈注意模块,通过分离的两个路径channel和spatial得到attention map,减少计算开销和参数开销。
2.BAM引入到yolov5
2.1 加入common.py
中:
代码语言:javascript复制###################### BAM attention #### START by AI&CV ###############################
import torch
from torch import nn
import torch.nn.functional as F
class ChannelGate(nn.Module):
def __init__(self, channel, reduction=16):
super().__init__()
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.mlp = nn.Sequential(
nn.Linear(channel, channel // reduction),
nn.ReLU(inplace=True),
nn.Linear(channel // reduction, channel)
)
self.bn = nn.BatchNorm1d(channel)
def forward(self, x):
b, c, h, w = x.shape
y = self.avgpool(x).view(b, c)
y = self.mlp(y)
y = self.bn(y).view(b, c, 1, 1)
return y.expand_as(x)
class SpatialGate(nn.Module):
def __init__(self, channel, reduction=16, kernel_size=3, dilation_val=4):
super().__init__()
self.conv1 = nn.Conv2d(channel, channel // reduction, kernel_size=1)
self.conv2 = nn.Sequential(
nn.Conv2d(channel // reduction, channel // reduction, kernel_size, padding=dilation_val,
dilation=dilation_val),
nn.BatchNorm2d(channel // reduction),
nn.ReLU(inplace=True),
nn.Conv2d(channel // reduction, channel // reduction, kernel_size, padding=dilation_val,
dilation=dilation_val),
nn.BatchNorm2d(channel // reduction),
nn.ReLU(inplace=True)
)
self.conv3 = nn.Conv2d(channel // reduction, 1, kernel_size=1)
self.bn = nn.BatchNorm2d(1)
def forward(self, x):
b, c, h, w = x.shape
y = self.conv1(x)
y = self.conv2(y)
y = self.conv3(y)
y = self.bn(y)
return y.expand_as(x)
class BAM(nn.Module):
def __init__(self, channel):
super(BAM, self).__init__()
self.channel_attn = ChannelGate(channel)
self.spatial_attn = SpatialGate(channel)
def forward(self, x):
attn = F.sigmoid(self.channel_attn(x) self.spatial_attn(x))
return x x * attn
###################### BAM attention #### END by AI&CV ###############################
详见:https://blog.csdn.net/m0_63774211/article/details/131541363
by CSDN AI小怪兽
我正在参与2023腾讯技术创作特训营第三期有奖征文,组队打卡瓜分大奖!