|
extra_large 版 Unet 网络结构（模型大小：1944756 KB，约 1.85 GB）
- import torch.nn as nn
- import torch
- from torch import autograd
- import torchvision.models as models
- # 导入模型结构
- #resnet50 = models.resnet50(pretrained=True)
- # 加载预先下载好的预训练参数到resnet50
- #resnet50.load_state_dict(torch.load('resnet50-5c106cde.pth'))
class DoubleConv(nn.Module):
    """Two stacked ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU`` stages.

    The convolutions use ``padding=1`` and the default stride of 1, so the
    spatial size of the input is preserved and only the channel count changes.

    Args:
        in_ch: number of channels of the input feature map.
        out_ch: number of channels produced by both convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # First stage maps in_ch -> out_ch, the second keeps out_ch -> out_ch.
        for stage_in in (in_ch, out_ch):
            stages.extend([
                nn.Conv2d(stage_in, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        # One Sequential named `conv` keeps the state_dict keys
        # (conv.0.*, conv.1.*, conv.3.*, conv.4.*) compatible with checkpoints.
        self.conv = nn.Sequential(*stages)

    def forward(self, input):
        """Apply both conv/BN/ReLU stages to ``input`` and return the result."""
        features = self.conv(input)
        return features
class Unet(nn.Module):
    """Deep U-Net with a sigmoid output for dense per-pixel prediction.

    With the defaults (``base_ch=16``, ``depth=8``) this is the "extra_large"
    configuration. The encoder has eight 2x max-pool downsampling stages
    (16 -> 4096 channels at the bottleneck). The decoder has eight
    transposed-conv upsampling stages, each concatenated with the matching
    encoder output (skip connection).

    Submodules are registered as ``conv1`` .. ``conv{2*depth+2}``,
    ``pool1`` .. ``pool{depth}`` and ``up{depth+2}`` .. ``up{2*depth+1}``.
    For the defaults that is conv1-conv18, pool1-pool8 and up10-up17. They are
    created in the same order as the original hand-written layer list, so
    existing ``state_dict`` checkpoints still load.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels (one sigmoid map per channel).
        base_ch: channel width of the first encoder stage; every deeper
            stage doubles it. Defaults to 16.
        depth: number of downsampling (and upsampling) stages. Defaults to 8.

    Raises:
        ValueError: if ``depth`` is negative.

    Input height and width must both be divisible by ``2 ** depth`` (256 for
    the defaults); otherwise the skip connections cannot be concatenated.
    """

    def __init__(self, in_ch, out_ch, base_ch=16, depth=8):
        super().__init__()
        if depth < 0:
            raise ValueError(f"depth must be non-negative, got {depth!r}")
        # Encoder: conv1 .. conv{depth+1}. Each level except the last (the
        # bottleneck) is followed by a 2x2 max-pool; channels double per level.
        # Modules are created in the original order, so parameter names,
        # state_dict order and RNG-driven initialisation all match.
        ch = in_ch
        for level in range(1, depth + 2):
            width = base_ch * 2 ** (level - 1)
            setattr(self, f"conv{level}", DoubleConv(ch, width))
            if level <= depth:
                setattr(self, f"pool{level}", nn.MaxPool2d(2))
            ch = width
        # Decoder: up{L} halves the channels and doubles H/W. Its output is
        # concatenated with the encoder feature map that has the same channel
        # count, so conv{L} sees `ch` channels again and halves them.
        for level in range(depth + 2, 2 * depth + 2):
            setattr(self, f"up{level}", nn.ConvTranspose2d(ch, ch // 2, 2, stride=2))
            setattr(self, f"conv{level}", DoubleConv(ch, ch // 2))
            ch //= 2
        # Final 1x1 projection to the requested number of output channels.
        setattr(self, f"conv{2 * depth + 2}", nn.Conv2d(ch, out_ch, 1))

    def _layer(self, kind, level):
        """Return the submodule registered under the name ``{kind}{level}``."""
        return getattr(self, f"{kind}{level}")

    def _depth(self):
        """Number of down/up-sampling stages, inferred from the pool layers."""
        # Derived from the registered modules instead of stored, so there is no
        # extra state to keep in sync. Instances created from the old class
        # definition (e.g. whole-model pickles) keep working.
        return sum(1 for name, _ in self.named_children() if name.startswith("pool"))

    def forward(self, x):
        """Run the U-Net on ``x``.

        Args:
            x: input batch. BatchNorm2d requires a 4-D tensor, presumably
                (N, in_ch, H, W).

        Returns:
            Tensor with ``out_ch`` channels and the same spatial size as ``x``,
            squashed into [0, 1] by a sigmoid.

        Raises:
            ValueError: if H or W is not divisible by ``2 ** depth``. Failing
                early avoids running the whole encoder before ``torch.cat``
                hits a size mismatch.
        """
        depth = self._depth()
        factor = 2 ** depth
        height, width = x.shape[-2], x.shape[-1]
        if height % factor or width % factor:
            raise ValueError(
                f"Unet input height and width must be divisible by {factor}, "
                f"got {height}x{width}"
            )

        skips = []
        for level in range(1, depth + 1):
            x = self._layer("conv", level)(x)
            skips.append(x)
            x = self._layer("pool", level)(x)
        x = self._layer("conv", depth + 1)(x)  # bottleneck

        for level in range(depth + 2, 2 * depth + 2):
            x = self._layer("up", level)(x)
            # The most recent unused encoder output matches this level's
            # spatial size and channel count; order is [upsampled, skip].
            x = torch.cat([x, skips.pop()], dim=1)
            x = self._layer("conv", level)(x)

        x = self._layer("conv", 2 * depth + 2)(x)
        # Functional sigmoid: same result as nn.Sigmoid() without building a
        # new module on every call.
        return torch.sigmoid(x)
复制代码
|
|