The attention mechanism discussed here generally refers to soft attention.
One family of attention mechanisms weights the individual feature-map channels; see the SENet part of Tensorflow的图像操作(四) for an example. Here we focus on Self-Attention.
The figure above shows the basic structure of the Self-Attention module. The feature maps on the far left come from the downsampled output of a convolutional layer, typically at 1/8 of the original input image resolution. Three 1×1 convolutions are then applied separately to this feature map: the first, f, and the second, g, reduce the number of channels to 1/8 of the original, while the third, h, keeps the channel count unchanged.
The output of f is transposed and matrix-multiplied with the output of g to compute pairwise similarities, which are normalized with a softmax to obtain the attention weights. These normalized weights are then used to take a weighted sum over h, giving the final attention output.
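In equation form (a compact restatement of the figure; f, g, h denote the three 1×1 convolutions, γ is the learnable scale used in the code below, and N = width × height is the number of spatial positions):

$$s_{j,i} = f(x_j)^{\top} g(x_i), \qquad \beta_{j,i} = \frac{\exp(s_{j,i})}{\sum_{i=1}^{N}\exp(s_{j,i})}$$

$$o_j = \sum_{i=1}^{N} \beta_{j,i}\, h(x_i), \qquad y_j = \gamma\, o_j + x_j$$

Here β_{j,i} is the attention that position j pays to position i, and y_j adds the attended value back onto the input as a residual.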
PyTorch implementation
import torch.nn as nn
import torch


class Self_Attn(nn.Module):
    """ Self attention Layer"""

    def __init__(self, in_dim):
        super(Self_Attn, self).__init__()
        self.channel_in = in_dim
        # f and g reduce the channel count to in_dim // 8, h keeps it unchanged
        self.f = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.g = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.h = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        # learnable scale, initialized to zero and registered as a module parameter
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
            inputs :
                x : input feature maps (B x C x H x W)
            returns :
                out : self attention value + input feature
                attention : B x N x N (N is Height * Width)
        """
        m_batchsize, C, height, width = x.size()
        f1 = self.f(x).view(m_batchsize, -1, height * width).permute(0, 2, 1)  # B x N x C//8
        g1 = self.g(x).view(m_batchsize, -1, height * width)                   # B x C//8 x N
        energy = torch.bmm(f1, g1)                                             # pairwise similarities, B x N x N
        attention = self.softmax(energy)                                       # B x N x N
        h1 = self.h(x).view(m_batchsize, -1, height * width)                   # B x C x N
        out = torch.bmm(h1, attention.permute(0, 2, 1))                        # weighted sum over all positions
        out = out.view(m_batchsize, C, height, width)
        out = self.gamma * out + x                                             # residual connection scaled by gamma
        return out, attention


if __name__ == '__main__':
    a = torch.rand(1, 512, 64, 64)
    self_atten = Self_Attn(512)
    out, atten = self_atten(a)
    print(out)
    print(atten)
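To show where such a layer typically sits, here is a minimal sketch of wiring Self_Attn into a small network; the surrounding ToyNet, its layer sizes and the 10-channel head are made up for this example and are not part of the implementation above. The attention is applied after three stride-2 convolutions, i.e. at 1/8 of the input resolution, matching the description earlier.

# Hypothetical toy network for illustration only: strided convolutions downsample
# the input, then Self_Attn is applied to the reduced feature map.
class ToyNet(nn.Module):
    def __init__(self):
        super(ToyNet, self).__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1), nn.ReLU(),
        )
        self.attn = Self_Attn(256)
        self.head = nn.Conv2d(256, 10, kernel_size=1)

    def forward(self, x):
        x = self.backbone(x)           # feature map at 1/8 of the input resolution
        x, attention = self.attn(x)    # the attention map can be kept for visualization
        return self.head(x), attention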
TensorFlow implementation
import tensorflow as tf
from tensorflow.keras import models, layers
import numpy as np
import Scale  # local module defining the learnable gamma scale shown below


class Self_Attn(models.Model):
    def __init__(self, in_dim):
        super(Self_Attn, self).__init__()
        # f and g reduce the channel count to in_dim // 8, h keeps it unchanged
        self.f = layers.Conv2D(in_dim // 8, (1, 1))
        self.g = layers.Conv2D(in_dim // 8, (1, 1))
        self.h = layers.Conv2D(in_dim, (1, 1))
        # create the scale layer once so its gamma weight is tracked and trained
        self.scale = Scale.Scale()

    def call(self, x):
        batchsize, width, height, channel = x.shape
        # channels-last layout: flatten the spatial dimensions so each row is one position
        f1 = self.f(x)
        f1 = tf.reshape(f1, (batchsize, width * height, -1))  # B x N x C//8
        g1 = self.g(x)
        g1 = tf.reshape(g1, (batchsize, width * height, -1))  # B x N x C//8
        h1 = self.h(x)
        h1 = tf.reshape(h1, (batchsize, width * height, -1))  # B x N x C
        energy = tf.matmul(f1, g1, transpose_b=True)          # pairwise similarities, B x N x N
        atten = tf.nn.softmax(energy, axis=-1)
        out = tf.matmul(atten, h1)                            # weighted sum over all positions, B x N x C
        out = tf.reshape(out, (batchsize, width, height, channel))
        out = self.scale(out) + x                             # residual connection scaled by gamma
        return out, atten


if __name__ == '__main__':
    a = tf.constant(np.random.rand(1, 64, 64, 512), dtype=tf.float32)
    self_atten = Self_Attn(512)
    out, atten = self_atten(a)
    print(out)
    print(atten)
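The returned attention map can also be inspected directly. Continuing from the __main__ block above (the query index is chosen arbitrarily for illustration), each row of atten tells you how strongly one spatial position attends to every other position and can be reshaped back onto the 64×64 grid:

# Illustrative only: take the attention row of a single query position and
# reshape it back into a spatial map for visualization.
query_index = 0                                        # e.g. the top-left position
heatmap = tf.reshape(atten[0, query_index], (64, 64))
print(heatmap.shape)                                   # (64, 64)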
The Scale layer used above is defined as follows:
from tensorflow.keras import layers


class Scale(layers.Layer):
    def __init__(self, **kwargs):
        super(Scale, self).__init__(**kwargs)

    def build(self, input_shape):
        # a single learnable scalar gamma, initialized to zero
        self.gamma = self.add_weight(name='gamma', shape=(1,), initializer='zeros', trainable=True)
        super(Scale, self).build(input_shape)

    def call(self, x, mask=None):
        return self.gamma * x
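Because gamma is created with a zero initializer, the scaled attention branch contributes nothing at the start of training; the module initially passes the residual input x through unchanged and learns how much attention to mix in as gamma is updated. A quick sanity check (assuming the Scale class above) that the scalar is actually tracked as a trainable weight:

# Sanity check: gamma should appear among the trainable weights and start at 0.
import tensorflow as tf

layer = Scale()
_ = layer(tf.ones((1, 4, 4, 8)))                   # the first call triggers build()
print([w.name for w in layer.trainable_weights])   # the list contains the 'gamma' weight
print(layer.gamma.numpy())                         # [0.]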