大家好,又见面了,我是你们的朋友全栈君。
ResNet18的网络架构图
首先将网络分为四层(layers),每层由两个模块组成:第一层由两个普通的残差块组成,其余三层各由一个下采样的残差块和一个普通的残差块组成。输入图像为3x224x224格式,经过卷积后为64x112x112,再经过最大池化后为64x56x56格式进入主网络架构。
代码如下:
代码语言:python
import torch
from torch import nn
from torch.nn import functional as F
class BasicBlock(nn.Module):
    """Residual block with an identity shortcut.

    The input goes through conv -> BN -> ReLU -> conv -> BN, the block input
    is added to that result, and the sum is passed through a final ReLU.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by both convolutions. It
            must equal ``in_channels`` so the identity shortcut can be added.
        kernel_size: Kernel size of both convolutions (int or pair). Odd
            sizes keep the spatial size unchanged.
        stride: Stride of both convolutions. It must be 1, otherwise the
            main branch no longer matches the identity shortcut in size.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        # "Same" padding for odd kernels; for kernel_size=3 this is the
        # previously hard-coded padding of 1.
        if isinstance(kernel_size, int):
            padding = kernel_size // 2
        else:
            padding = tuple(k // 2 for k in kernel_size)
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding=padding
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size, stride, padding=padding
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Return ``relu(x + F(x))`` where ``F`` is the two-stage main branch."""
        # Non-linearity between the two stages, as in the standard ResNet
        # basic block; without it the stages are a purely affine chain.
        output = F.relu(self.bn1(self.conv1(x)))
        output = self.bn2(self.conv2(output))
        return F.relu(x + output)
class BasicDownBlock(nn.Module):
    """Residual block that downsamples and changes the channel count.

    Main branch: conv -> BN -> ReLU -> conv -> BN. Shortcut branch: a single
    conv -> BN applied to the block input with the same downsampling stride.
    The two branches are added and passed through a final ReLU.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by both branches.
        kernel_size: Pair ``[main, shortcut]`` of integer kernel sizes:
            ``main`` is used by both main-branch convolutions, ``shortcut``
            by the shortcut convolution.
        stride: Pair ``[down, keep]`` of strides: ``down`` is used by the
            first main-branch convolution and by the shortcut convolution,
            ``keep`` by the second main-branch convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        main_kernel, shortcut_kernel = kernel_size[0], kernel_size[1]
        down_stride, keep_stride = stride[0], stride[1]

        # Padding of kernel // 2 makes both branches produce the same spatial
        # size for any odd kernel. With a fixed padding of 1, a larger main
        # kernel shrinks the main branch, which can then broadcast silently
        # against the shortcut instead of raising an error.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, main_kernel, down_stride,
            padding=main_kernel // 2,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, main_kernel, keep_stride,
            padding=main_kernel // 2,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Projection shortcut; for a 1x1 kernel the padding is 0.
        self.conv3 = nn.Conv2d(
            in_channels, out_channels, shortcut_kernel, down_stride,
            padding=shortcut_kernel // 2,
        )
        self.bn3 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Return ``relu(shortcut(x) + main(x))``."""
        # Non-linearity between the two main-branch stages, as in the
        # standard ResNet basic block.
        output = F.relu(self.bn1(self.conv1(x)))
        output = self.bn2(self.conv2(output))
        shortcut = self.bn3(self.conv3(x))
        return F.relu(shortcut + output)
class ResNet18(nn.Module):
    """ResNet-18 style classifier built from BasicBlock / BasicDownBlock.

    The shape comments below assume a 3x224x224 input. The final spatial size
    does not have to be 7x7, because the head uses adaptive pooling.

    Args:
        num_classes: Size of the output logits vector. Defaults to 10.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Stem: 3x224x224 --> 64x112x112
        self.conv1 = nn.Conv2d(
            in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3
        )
        self.bn1 = nn.BatchNorm2d(64)
        # 64x112x112 --> 64x56x56
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # 64x56x56 --> 64x56x56
        self.layer1 = nn.Sequential(
            BasicBlock(64, 64, 3, 1),
            BasicBlock(64, 64, 3, 1),
        )
        # 64x56x56 --> 128x28x28
        self.layer2 = nn.Sequential(
            BasicDownBlock(64, 128, [3, 1], [2, 1]),
            BasicBlock(128, 128, 3, 1),
        )
        # 128x28x28 --> 256x14x14
        self.layer3 = nn.Sequential(
            BasicDownBlock(128, 256, [3, 1], [2, 1]),
            BasicBlock(256, 256, 3, 1),
        )
        # 256x14x14 --> 512x7x7
        # 3x3 main-branch kernels like the other stages. A 7x7 kernel with
        # padding 1 reduces the main branch to 1x1, which then broadcasts
        # against the 7x7 shortcut instead of matching it.
        self.layer4 = nn.Sequential(
            BasicDownBlock(256, 512, [3, 1], [2, 1]),
            BasicBlock(512, 512, 3, 1),
        )
        # 512x7x7 --> 512x1x1 (global average pooling, as the name says)
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.flat = nn.Flatten()
        self.linear = nn.Linear(512, num_classes)

    def forward(self, x):
        """Map a batch of images (N, 3, H, W) to logits (N, num_classes)."""
        output = self.pool1(F.relu(self.bn1(self.conv1(x))))
        output = self.layer1(output)
        output = self.layer2(output)
        output = self.layer3(output)
        output = self.layer4(output)
        output = self.avgpool(output)
        output = self.flat(output)
        return self.linear(output)
# Smoke test: push one random batch through the network and print the
# input and output shapes.
net = ResNet18()
x = torch.randn((32, 3, 224, 224))  # 32 random RGB images of size 224x224
print(x.shape)
y = net(x)  # one logits row per image
print(y.shape)
代码中BasicBlock为普通的残差块,注意步长和卷积核的大小;BasicDownBlock为下采样的残差块。然后将四层的网络依次表示出来,最后进行验证:x.shape为torch.Size([32, 3, 224, 224]),y.shape为torch.Size([32, 10])。
发布者:全栈程序员栈长,转载请注明出处:https://javaforall.cn/141545.html 原文链接:https://javaforall.cn