Swin-Unet最强分割网络

2022-06-12 14:41:56 浏览数 (2)

Swin-Unet是基于Swin Transformer为基础(可参考Swin Transformer介绍 ),结合了U-Net网络的特点(可参考Tensorflow深度学习算法整理(三) 中的U-Net)组合而成的新的分割网络

它与Swin Transformer不同的地方在于,在编码器(Encoder)这边虽然跟Swin Transformer一样的4个Stage,但Swin Transformer Block的数量为2,2,2,1,而不是Swin Transformer的2,2,6,2。而在解码器(Decoder)这边,由于是升采样,使用的不再是Patch Embedding和Patch Merging,而使用的是Patch Expanding,它是Patch Merging的逆过程。

我们来看一下Patch Expanding的代码实现

代码语言:javascript复制
from einops import rearrange
代码语言:javascript复制
class PatchExpand(nn.Module):
    """
    块状扩充,尺寸翻倍,通道数减半
    """
    def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
        """
        Args:
            input_resolution: 解码过程的feature map的宽高
            dim: frature map通道数
            dim_scale: 通道数扩充的倍数
            norm_layer: 通道方向归一化
        """
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        # 通过全连接层来扩大通道数
        self.expand = nn.Linear(dim, 2 * dim, bias=False) if dim_scale == 2 else nn.Identity()
        self.norm = norm_layer(dim // dim_scale)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        # 先把通道数翻倍
        x = self.expand(x)
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)
        # 将各个通道分开,再将所有通道拼成一个feature map
        # 增大了feature map的尺寸
        x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C // 4)
        # 通道翻倍后再除以4,实际相当于通道数减半
        x = x.view(B, -1, C // 4)
        x = self.norm(x)

        return x

0 人点赞