使用PyTorch实现论文中的U-Net网络

2020-10-21 14:53:19 浏览数 (1)

设计神经网络的一般步骤:

1. 设计框架

2. 设计骨干网络

Unet网络设计的步骤:

1. 设计Unet网络工厂模式

2. 设计编解码结构

3. 设计卷积模块

4. unet实例模块

Unet网络最重要的特征:

1. 编解码结构。

2. 解码结构比FCN更加完善:通过跳跃连接(skip connection)将编码器对应层的特征图与解码器特征图在通道维度上拼接(concatenate),而不是像FCN那样逐元素相加。

3. 本质上是一个框架,编码部分可以替换为多种图像分类网络(如VGG、ResNet)的骨干结构。

示例代码:

代码语言:python
import torch
import torch.nn as nn
class Unet(nn.Module):
    """Generic U-Net assembled from an encoder, an optional bridge and a decoder.

    Args:
        Encoder: non-empty list of encoder stages; it is wrapped in the
            module-level ``Encoder`` class. (The parameter name is kept from
            the original API even though it shadows that class.)
        Decoder: list of decoder stages; it is wrapped in the module-level
            ``Decoder`` class.
        bridge: optional module applied to the deepest encoder output before
            decoding. ``None`` (the default) means no bridge.
    """

    def __init__(self, Encoder, Decoder, bridge=None):
        super(Unet, self).__init__()
        # The parameters shadow the module-level Encoder/Decoder classes, so
        # look the classes up in the module namespace explicitly.
        self.encoder = globals()["Encoder"](Encoder)
        self.decoder = globals()["Decoder"](Decoder)
        # Assigning an nn.Module attribute registers it as a submodule.
        self.bridge = bridge

    def forward(self, x):
        res = self.encoder(x)
        # The encoder returns [deepest, skip_0, skip_1, ...]. A Python list
        # cannot be indexed with `[1, :]`, so slice it instead.
        out, skips = res[0], res[1:]
        if self.bridge is not None:
            out = self.bridge(out)
        # Decoder.forward takes (skips, x) in that order.
        return self.decoder(skips, out)
# Encoder module
class Encoder(nn.Module):
    """Run a sequence of encoder stages and collect their intermediate outputs.

    Args:
        blocks: non-empty sequence of ``nn.Module`` stages applied in order.

    ``forward`` returns the list ``[deepest, skip_0, skip_1, ...]``.
    ``deepest`` is the output of the last stage. ``skip_i`` is the output of
    stage ``i`` (shallow to deep) for every stage except the last.
    """

    def __init__(self, blocks):
        super(Encoder, self).__init__()
        # Validate explicitly: an `assert` is stripped under `python -O`.
        if len(blocks) == 0:
            raise ValueError("Encoder needs at least one block")
        # nn.ModuleList registers every block's parameters with this module.
        # Unlike nn.Sequential, it has no forward of its own.
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        skips = []
        for block in self.blocks[:-1]:
            x = block(x)
            skips.append(x)
        # Lists concatenate with `+`: deepest feature map first, then the skips.
        return [self.blocks[-1](x)] + skips
# Decoder module
class Decoder(nn.Module):
    """Upsampling path of a U-Net that fuses encoder skip connections.

    Args:
        blocks: non-empty sequence of ``nn.Module`` stages. ``blocks[0]`` is
            applied to the decoder input alone. Each later block receives the
            channel-wise concatenation ``[skip, x]`` of one skip tensor and
            the current feature map.
    """

    def __init__(self, blocks):
        super(Decoder, self).__init__()
        # Validate explicitly: an `assert` is stripped under `python -O`.
        if len(blocks) == 0:
            raise ValueError("Decoder needs at least one block")
        self.blocks = nn.ModuleList(blocks)

    def ceter_crop(self, skips, x):
        """Center-crop two NCHW tensors to their common spatial size.

        The crop keeps size parameters consistent for concatenation.

        Args:
            skips: skip-connection tensor (indexed as N, C, H, W).
            x: current decoder feature map (indexed as N, C, H, W).

        Returns:
            ``(skips_cropped, x_cropped)``, both with height
            ``min(H1, H2)`` and width ``min(W1, W2)``.
        """
        # `.shape` is an attribute (torch.Size), not a method.
        _, _, height1, width1 = skips.shape
        _, _, height2, width2 = x.shape
        ht, wt = min(height1, height2), min(width1, width2)
        # Offset of the centered window in each tensor (0 for the smaller one).
        dh1, dw1 = max(height1 - height2, 0) // 2, max(width1 - width2, 0) // 2
        dh2, dw2 = max(height2 - height1, 0) // 2, max(width2 - width1, 0) // 2
        return (
            skips[:, :, dh1:dh1 + ht, dw1:dw1 + wt],
            x[:, :, dh2:dh2 + ht, dw2:dw2 + wt],
        )

    # Correctly spelled alias; `ceter_crop` is kept for backward compatibility.
    center_crop = ceter_crop

    def forward(self, skips, x, reverse_skips=True):
        """Decode ``x`` using ``skips``.

        Args:
            skips: sequence of exactly ``len(self.blocks) - 1`` skip tensors.
            x: input feature map for ``blocks[0]``.
            reverse_skips: if true, consume ``skips`` in reverse order. Pass
                this when skips are ordered shallow-to-deep, as the encoder
                produces them.

        Raises:
            ValueError: if the number of skips does not match the blocks.
        """
        if len(skips) != len(self.blocks) - 1:
            raise ValueError(
                f"expected {len(self.blocks) - 1} skip tensors, got {len(skips)}"
            )
        if reverse_skips:
            skips = skips[::-1]
        x = self.blocks[0](x)
        for block, skip in zip(self.blocks[1:], skips):
            # Valid (unpadded) convolutions shrink the maps, so align sizes first.
            skip, x = self.ceter_crop(skip, x)
            x = torch.cat([skip, x], dim=1)
            x = block(x)
        return x
# Convolution block: two (3x3 conv -> BatchNorm -> ReLU) stages
def unet_convs(in_channels, out_channels, padding=0):
    """Return the double-convolution block used at every U-Net level.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by both convolutions.
        padding: padding of each 3x3 conv. With the default 0 ("valid"
            convs), height and width shrink by 4. With 1, they are preserved.

    Returns:
        An ``nn.Sequential``, which, unlike ``nn.ModuleList``, has a forward.
    """
    return nn.Sequential(
        # bias=False: the following BatchNorm has its own learnable shift.
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=padding, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        # The second conv consumes the first conv's output channels.
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=padding, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )
# Instantiate the U-Net model
def unet(in_channels, out_channels, padding=0):
    """Build the 4-level U-Net (64-128-256-512 skips, 1024-channel bridge).

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output map (final 1x1 convolution).
        padding: padding of every 3x3 conv (0 = valid convs as in the paper).

    Returns:
        A ``Unet`` built from encoder/decoder block lists and a bridge.
        With ``padding=0``, a 572x572 input yields a 388x388 output.
    """

    def down(cin, cout):
        # Halve the resolution, then apply a double conv.
        return nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True),
            unet_convs(cin, cout, padding=padding),
        )

    def up(cin, cout):
        # Fuse [skip, upsampled] (cin channels), then upsample to cout // 2.
        return nn.Sequential(
            unet_convs(cin, cout, padding=padding),
            nn.ConvTranspose2d(cout, cout // 2, kernel_size=2, stride=2),
        )

    # The outputs of all but the last block become skips (64, 128, 256, 512).
    # The last block only pools, so the bridge works at the lowest resolution.
    encoder_blocks = [
        unet_convs(in_channels, 64, padding=padding),
        down(64, 128),
        down(128, 256),
        down(256, 512),
        nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True),
    ]
    bridge = unet_convs(512, 1024, padding=padding)
    # One more block than there are skips. Every concatenation sees matching
    # channel counts, e.g. 512 (skip) + 512 (up) = 1024.
    decoder_blocks = [
        nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2),
        up(1024, 512),
        up(512, 256),
        up(256, 128),
        nn.Sequential(
            unet_convs(128, 64, padding=padding),
            nn.Conv2d(64, out_channels, kernel_size=1),
        ),
    ]
    return Unet(encoder_blocks, decoder_blocks, bridge)

补充知识:Pytorch搭建U-Net网络

U-Net: Convolutional Networks for Biomedical Image Segmentation

代码语言:python
import torch.nn as nn
import torch
from torch import autograd
from torchsummary import summary
class DoubleConv(nn.Module):
    """Two consecutive (3x3 conv -> BatchNorm -> ReLU) stages.

    Args:
        in_ch: input channels.
        out_ch: output channels of both convolutions.
        padding: padding of each 3x3 conv. With the default 0 ("valid", as in
            the paper), height and width shrink by 4. With 1, they are kept.
    """

    def __init__(self, in_ch, out_ch, padding=0):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=padding),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=padding),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, input):
        return self.conv(input)
class Unet(nn.Module):
    """U-Net from "Convolutional Networks for Biomedical Image Segmentation".

    Built with valid (unpadded) DoubleConv blocks, so the output is smaller
    than the input; e.g. a 572x572 input yields a 388x388 output.

    Args:
        in_ch: input image channels.
        out_ch: output map channels. A sigmoid is applied to the output.
    """

    def __init__(self, in_ch, out_ch):
        super(Unet, self).__init__()
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)
        # Transposed convolution; upsampling would also work.
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        self.conv10 = nn.Conv2d(64, out_ch, 1)

    @staticmethod
    def _center_crop(feature, target):
        """Crop ``feature`` (indexed N, C, H, W) to ``target``'s height/width, centered.

        Assumes ``feature`` is at least as large as ``target`` spatially,
        which holds for the skip connections of this valid-conv U-Net.
        """
        height, width = target.shape[-2:]
        top = (feature.shape[-2] - height) // 2
        left = (feature.shape[-1] - width) // 2
        return feature[:, :, top:top + height, left:left + width]

    def forward(self, x):
        c1 = self.conv1(x)
        p1 = self.pool1(c1)
        c2 = self.conv2(p1)
        p2 = self.pool2(c2)
        c3 = self.conv3(p2)
        p3 = self.pool3(c3)
        c4 = self.conv4(p3)
        p4 = self.pool4(c4)
        c5 = self.conv5(p4)
        # The crops are computed from the actual tensor sizes. The former
        # hard-coded slices (e.g. 88:480) were only valid for 572x572 inputs;
        # for that size these crops are identical.
        up_6 = self.up6(c5)
        merge6 = torch.cat([up_6, self._center_crop(c4, up_6)], dim=1)
        c6 = self.conv6(merge6)
        up_7 = self.up7(c6)
        merge7 = torch.cat([up_7, self._center_crop(c3, up_7)], dim=1)
        c7 = self.conv7(merge7)
        up_8 = self.up8(c7)
        merge8 = torch.cat([up_8, self._center_crop(c2, up_8)], dim=1)
        c8 = self.conv8(merge8)
        up_9 = self.up9(c8)
        merge9 = torch.cat([up_9, self._center_crop(c1, up_9)], dim=1)
        c9 = self.conv9(merge9)
        c10 = self.conv10(c9)
        # Functional sigmoid: same values, no module created per forward call.
        out = torch.sigmoid(c10)
        return out
if __name__ == "__main__":
    # Smoke test: a 1x1x572x572 input should produce a 1x2x388x388 output.
    test_input = torch.rand(1, 1, 572, 572)
    model = Unet(in_ch=1, out_ch=2)
    # torchsummary defaults to device="cuda"; the model lives on the CPU, so
    # request CPU inputs explicitly to avoid a device mismatch on GPU machines.
    summary(model, (1, 572, 572), device="cpu")
    # No gradients are needed for a shape check; skip building the graph.
    with torch.no_grad():
        output = model(test_input)
    print(output.size())
运行输出:
----------------------------------------------------------------
Layer (type)    Output Shape   Param #
================================================================
Conv2d-1   [-1, 64, 570, 570]    640
BatchNorm2d-2   [-1, 64, 570, 570]    128
ReLU-3   [-1, 64, 570, 570]    0
Conv2d-4   [-1, 64, 568, 568]   36,928
BatchNorm2d-5   [-1, 64, 568, 568]    128
ReLU-6   [-1, 64, 568, 568]    0
DoubleConv-7   [-1, 64, 568, 568]    0
MaxPool2d-8   [-1, 64, 284, 284]    0
Conv2d-9  [-1, 128, 282, 282]   73,856
BatchNorm2d-10  [-1, 128, 282, 282]    256
ReLU-11  [-1, 128, 282, 282]    0
Conv2d-12  [-1, 128, 280, 280]   147,584
BatchNorm2d-13  [-1, 128, 280, 280]    256
ReLU-14  [-1, 128, 280, 280]    0
DoubleConv-15  [-1, 128, 280, 280]    0
MaxPool2d-16  [-1, 128, 140, 140]    0
Conv2d-17  [-1, 256, 138, 138]   295,168
BatchNorm2d-18  [-1, 256, 138, 138]    512
ReLU-19  [-1, 256, 138, 138]    0
Conv2d-20  [-1, 256, 136, 136]   590,080
BatchNorm2d-21  [-1, 256, 136, 136]    512
ReLU-22  [-1, 256, 136, 136]    0
DoubleConv-23  [-1, 256, 136, 136]    0
MaxPool2d-24   [-1, 256, 68, 68]    0
Conv2d-25   [-1, 512, 66, 66]  1,180,160
BatchNorm2d-26   [-1, 512, 66, 66]   1,024
ReLU-27   [-1, 512, 66, 66]    0
Conv2d-28   [-1, 512, 64, 64]  2,359,808
BatchNorm2d-29   [-1, 512, 64, 64]   1,024
ReLU-30   [-1, 512, 64, 64]    0
DoubleConv-31   [-1, 512, 64, 64]    0
MaxPool2d-32   [-1, 512, 32, 32]    0
Conv2d-33   [-1, 1024, 30, 30]  4,719,616
BatchNorm2d-34   [-1, 1024, 30, 30]   2,048
ReLU-35   [-1, 1024, 30, 30]    0
Conv2d-36   [-1, 1024, 28, 28]  9,438,208
BatchNorm2d-37   [-1, 1024, 28, 28]   2,048
ReLU-38   [-1, 1024, 28, 28]    0
DoubleConv-39   [-1, 1024, 28, 28]    0
ConvTranspose2d-40   [-1, 512, 56, 56]  2,097,664
Conv2d-41   [-1, 512, 54, 54]  4,719,104
BatchNorm2d-42   [-1, 512, 54, 54]   1,024
ReLU-43   [-1, 512, 54, 54]    0
Conv2d-44   [-1, 512, 52, 52]  2,359,808
BatchNorm2d-45   [-1, 512, 52, 52]   1,024
ReLU-46   [-1, 512, 52, 52]    0
DoubleConv-47   [-1, 512, 52, 52]    0
ConvTranspose2d-48  [-1, 256, 104, 104]   524,544
Conv2d-49  [-1, 256, 102, 102]  1,179,904
BatchNorm2d-50  [-1, 256, 102, 102]    512
ReLU-51  [-1, 256, 102, 102]    0
Conv2d-52  [-1, 256, 100, 100]   590,080
BatchNorm2d-53  [-1, 256, 100, 100]    512
ReLU-54  [-1, 256, 100, 100]    0
DoubleConv-55  [-1, 256, 100, 100]    0
ConvTranspose2d-56  [-1, 128, 200, 200]   131,200
Conv2d-57  [-1, 128, 198, 198]   295,040
BatchNorm2d-58  [-1, 128, 198, 198]    256
ReLU-59  [-1, 128, 198, 198]    0
Conv2d-60  [-1, 128, 196, 196]   147,584
BatchNorm2d-61  [-1, 128, 196, 196]    256
ReLU-62  [-1, 128, 196, 196]    0
DoubleConv-63  [-1, 128, 196, 196]    0
ConvTranspose2d-64   [-1, 64, 392, 392]   32,832
Conv2d-65   [-1, 64, 390, 390]   73,792
BatchNorm2d-66   [-1, 64, 390, 390]    128
ReLU-67   [-1, 64, 390, 390]    0
Conv2d-68   [-1, 64, 388, 388]   36,928
BatchNorm2d-69   [-1, 64, 388, 388]    128
ReLU-70   [-1, 64, 388, 388]    0
DoubleConv-71   [-1, 64, 388, 388]    0
Conv2d-72   [-1, 2, 388, 388]    130
================================================================
Total params: 31,042,434
Trainable params: 31,042,434
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.25
Forward/backward pass size (MB): 3280.59
Params size (MB): 118.42
Estimated Total Size (MB): 3400.26
----------------------------------------------------------------
torch.Size([1, 2, 388, 388])

以上这篇使用pytorch实现论文中的unet网络就是小编分享给大家的全部内容了,希望能给大家一个参考。

0 人点赞