Swin-Unet是基于Swin Transformer为基础(可参考Swin Transformer介绍 ),结合了U-Net网络的特点(可参考Tensorflow深度学习算法整理(三) 中的U-Net)组合而成的新的分割网络
它与Swin Transformer不同的地方在于,在编码器(Encoder)这边虽然跟Swin Transformer一样的4个Stage,但Swin Transformer Block的数量为2,2,2,1,而不是Swin Transformer的2,2,6,2。而在解码器(Decoder)这边,由于是升采样,使用的不再是Patch Embedding和Patch Merging,而使用的是Patch Expanding,它是Patch Merging的逆过程。
我们来看一下Patch Expanding的代码实现
代码语言:javascript复制from einops import rearrange
代码语言:javascript复制class PatchExpand(nn.Module):
"""
块状扩充,尺寸翻倍,通道数减半
"""
def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
"""
Args:
input_resolution: 解码过程的feature map的宽高
dim: frature map通道数
dim_scale: 通道数扩充的倍数
norm_layer: 通道方向归一化
"""
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
# 通过全连接层来扩大通道数
self.expand = nn.Linear(dim, 2 * dim, bias=False) if dim_scale == 2 else nn.Identity()
self.norm = norm_layer(dim // dim_scale)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
# 先把通道数翻倍
x = self.expand(x)
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
x = x.view(B, H, W, C)
# 将各个通道分开,再将所有通道拼成一个feature map
# 增大了feature map的尺寸
x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C // 4)
# 通道翻倍后再除以4,实际相当于通道数减半
x = x.view(B, -1, C // 4)
x = self.norm(x)
return x