阅读背景:

U-Net 网络结构搭建

来源:互联网 

from typing import Dict

import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Sequential):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the second convolution.
        mid_channels: channels produced by the first convolution; defaults to
            ``out_channels`` when ``None``.

    Both convolutions use ``padding=1`` with a 3x3 kernel, so H and W are
    preserved.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        if mid_channels is None:
            mid_channels = out_channels
        super().__init__(
            # bias=False: each conv is immediately followed by a BatchNorm,
            # whose affine shift makes a conv bias redundant.
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )


class Down(nn.Sequential):
    """Encoder stage: 2x2 max-pool (halves H and W, flooring odd sizes), then DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__(
            nn.MaxPool2d(2, stride=2),
            DoubleConv(in_channels, out_channels),
        )


class Up(nn.Module):
    """Decoder stage: upsample, concatenate with the skip connection, DoubleConv.

    Args:
        in_channels: channels of the concatenated tensor fed to DoubleConv.
            With ``bilinear=False`` this must also be the channel count of
            ``x1``, because ConvTranspose2d maps in_channels -> in_channels // 2.
        out_channels: channels produced by this stage.
        bilinear: if True, upsample with parameter-free bilinear interpolation
            (channel count unchanged) and use ``in_channels // 2`` as the
            DoubleConv's middle width; otherwise use a learned ConvTranspose2d.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Upsample ``x1``, pad it to ``x2``'s H/W, concatenate (x2 first), convolve.

        Args:
            x1: the deeper feature map to be upsampled, in (N, C, H, W) layout.
            x2: the skip-connection feature map from the encoder.

        Returns:
            The DoubleConv output, with ``x2``'s spatial size.
        """
        x1 = self.up(x1)
        # If x2's size was odd when it entered the matching Down stage, the
        # max-pool floored it and a 2x upsample comes back one pixel short.
        # Pad x1 (split as evenly as possible) so torch.cat sees equal H/W.
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)
        if diff_y or diff_x:
            # F.pad order for the last two dims: [left, right, top, bottom].
            x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                            diff_y // 2, diff_y - diff_y // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Sequential):
    """1x1 convolution mapping ``in_channels`` features to ``num_classes`` logits."""

    def __init__(self, in_channels, num_classes):
        super().__init__(
            nn.Conv2d(in_channels, num_classes, kernel_size=1)
        )


class UNet(nn.Module):
    """U-Net: input DoubleConv, 4 Down stages, 4 Up stages with skips, 1x1 head.

    Args:
        in_channels: channels of the input image.
        num_classes: channels of the output logits.
        bilinear: upsampling mode passed to every Up stage.
        base_c: channels produced by the first DoubleConv; deeper stages use
            multiples of it.

    Because Up pads the upsampled tensor to the skip connection's size, the
    input H and W do not have to be multiples of 16.
    """

    def __init__(self, in_channels: int = 3, num_classes: int = 2,
                 bilinear: bool = True, base_c: int = 64):
        super().__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.bilinear = bilinear

        self.in_conv = DoubleConv(in_channels, base_c)
        self.down1 = Down(base_c, base_c * 2)
        self.down2 = Down(base_c * 2, base_c * 4)
        self.down3 = Down(base_c * 4, base_c * 8)
        # Bilinear upsampling does not reduce channels, so the bottleneck and
        # each Up output are halved to keep the concatenated widths matching
        # the in_channels given to the next Up stage.
        factor = 2 if bilinear else 1
        self.down4 = Down(base_c * 8, base_c * 16 // factor)
        self.up1 = Up(base_c * 16, base_c * 8 // factor, bilinear)
        self.up2 = Up(base_c * 8, base_c * 4 // factor, bilinear)
        self.up3 = Up(base_c * 4, base_c * 2 // factor, bilinear)
        self.up4 = Up(base_c * 2, base_c, bilinear)
        self.out_conv = OutConv(base_c, num_classes)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Run the network and return ``{"out": logits}``.

        Shape comments are H*W*C for a 480*480*3 input with base_c=64 and
        bilinear=True.
        """
        x1 = self.in_conv(x)     # 480*480*3 -> 480*480*64 -> 480*480*64
        x2 = self.down1(x1)      # 240*240*64 -> 240*240*128 -> 240*240*128
        x3 = self.down2(x2)      # 120*120*128 -> 120*120*256 -> 120*120*256
        x4 = self.down3(x3)      # 60*60*256 -> 60*60*512 -> 60*60*512
        x5 = self.down4(x4)      # 30*30*512 -> 30*30*512 -> 30*30*512
        x = self.up1(x5, x4)     # 60*60*512 -> 60*60*1024 -> 60*60*512 -> 60*60*256
        x = self.up2(x, x3)      # 120*120*256 -> 120*120*512 -> 120*120*256 -> 120*120*128
        x = self.up3(x, x2)      # 240*240*128 -> 240*240*256 -> 240*240*128 -> 240*240*64
        x = self.up4(x, x1)      # 480*480*64 -> 480*480*128 -> 480*480*64 -> 480*480*64
        logits = self.out_conv(x)  # 480*480*64 -> 480*480*num_classes
        return {"out": logits}



你的当前访问异常，请进行认证后继续阅读剩余内容。

分享到: