请选择 进入手机版 | 继续访问电脑版

Hello Mat

 找回密码
 立即注册

QQ登录

只需一步,快速开始

查看: 127|回复: 0

Pytorch_Unet图像分割

[复制链接]

13

主题

25

帖子

895

积分

管理员

Rank: 9Rank: 9Rank: 9

积分
895
发表于 7 天前 | 显示全部楼层 |阅读模式
Pytorch_Unet图像分割:
  1. # -*- coding: utf-8 -*-
  2. import sys
  3. sys.path.append(r"D:\2-LearningCode\999-AI-Pytorch\3_AI_nets\u_net_liver-master\u_net_liver-master")
  4. import numpy as np
  5. import torch
  6. import argparse
  7. from torch.utils.data import DataLoader
  8. from torch import autograd, optim
  9. from torchvision.transforms import transforms
  10. from unet import Unet
  11. from dataset import LiverDataset
  12. import cv2
  13. import random

  14. # 是否使用cuda
  15. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  16. #device = torch.device("cpu")
  17. #print(device)

  18. x_transforms = transforms.Compose([
  19.     transforms.ToTensor(),
  20.     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
  21. ])

  22. # mask只需要转换为tensor
  23. y_transforms = transforms.ToTensor()

  24. #参数解析
  25. parse=argparse.ArgumentParser()

  26. def train_model(model, criterion, optimizer, dataload, num_epochs=20):
  27.     model.cuda(0)
  28.     for epoch in range(num_epochs):
  29.         print('Epoch {}/{}'.format(epoch, num_epochs - 1))
  30.         print('-' * 10)
  31.         dt_size = len(dataload.dataset)
  32.         epoch_loss = 0
  33.         step = 0
  34.         #with torch.no_grad():
  35.         for x, y in dataload:
  36.             step += 1
  37.             inputs = x.to(device)
  38.             labels = y.to(device)
  39.             del x,y
  40.             # zero the parameter gradients
  41.             optimizer.zero_grad()
  42.             # forward
  43.             outputs = model(inputs)
  44.             loss = criterion(outputs, labels)
  45.             loss.backward()
  46.             optimizer.step()
  47.             
  48.             epoch_loss += loss.item()
  49.             print("%d/%d,train_loss:%0.3f" % (step, (dt_size - 1) // dataload.batch_size + 1, loss.item()))
  50.             
  51.             del inputs,labels,outputs
  52.             torch.cuda.empty_cache()
  53.             
  54.         print("epoch %d loss:%0.3f" % (epoch, epoch_loss))
  55.     torch.save(model.state_dict(), 'weights_%d.pth' % epoch)
  56.     return model

  57. #训练模型
  58. def train():
  59.     model = Unet(3, 1).to(device)
  60.     batch_size = args.batch_size
  61.     criterion = torch.nn.BCELoss()
  62.     optimizer = optim.Adam(model.parameters())
  63.     liver_dataset = LiverDataset("data/train",transform=x_transforms,target_transform=y_transforms)
  64.     dataloaders = DataLoader(liver_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
  65.     train_model(model, criterion, optimizer, dataloaders)

  66. #显示模型的输出结果
  67. def test_gpu():
  68.     model = Unet(3, 1).to(device)
  69. #    model.load_state_dict(torch.load(args.ckp,map_location='cpu'))
  70.     model.load_state_dict(torch.load(args.ckp, map_location='cuda'))
  71.     liver_dataset = LiverDataset("data/val", transform=x_transforms,target_transform=y_transforms)
  72.     dataloaders = DataLoader(liver_dataset, batch_size=1)
  73.     model.cuda(0)
  74.     model.eval()
  75.     import matplotlib.pyplot as plt
  76.     plt.ion()
  77.     with torch.no_grad():
  78.         for x, _ in dataloaders:
  79.             y=model(x.to(device))
  80.             img_y=torch.squeeze(y).cpu().numpy()
  81.             plt.imshow(img_y)
  82.             plt.pause(0.01)
  83.             
  84.             torch.cuda.empty_cache()
  85.             
  86.             cv2.imwrite('./resultsImages/'+str(random.random())+'.tiff', img_y)

  87.         plt.show()
  88.     torch.cuda.empty_cache()
  89.    
  90. def test_cpu():
  91.     model = Unet(3, 1)
  92.     model.load_state_dict(torch.load(args.ckp,map_location='cpu'))
  93.     liver_dataset = LiverDataset("data/val", transform=x_transforms,target_transform=y_transforms)
  94.     dataloaders = DataLoader(liver_dataset, batch_size=1)
  95.     model.eval()
  96.     import matplotlib.pyplot as plt
  97.     plt.ion()
  98.     with torch.no_grad():
  99.         for x, _ in dataloaders:
  100.             y=model(x)
  101.             img_y=torch.squeeze(y).numpy()
  102.             plt.imshow(img_y)
  103.             plt.pause(0.01)
  104.         plt.show()
  105.         
  106.     torch.cuda.empty_cache()
  107.    
  108. if __name__ == '__main__':
  109.     parse = argparse.ArgumentParser()
  110. #    parse.add_argument("action", type=str, help="train or test")
  111.     parse.add_argument("--action", type=str, help="train or test", default="train")
  112.     parse.add_argument("--batch_size", type=int, default=4)
  113. #    parse.add_argument("--ckp", type=str, help="the path of model weight file")
  114.     parse.add_argument("--ckp", type=str, help="the path of model weight file", default="weights_19.pth")
  115.     args = parse.parse_args()

  116.     print(args.action)

  117.     if args.action=="train":
  118.         train()
  119.     elif args.action=="test":
  120.         test_gpu()
  121.         
  122.     test_gpu()
  123.    
复制代码
尤其要注意的是:
del inputs,labels,outputs
torch.cuda.empty_cache()
清理GPU缓存,不然报GPU内存不够错误。

unet.py如下:
  1. import torch.nn as nn
  2. import torch
  3. from torch import autograd

  4. class DoubleConv(nn.Module):
  5.     def __init__(self, in_ch, out_ch):
  6.         super(DoubleConv, self).__init__()
  7.         self.conv = nn.Sequential(
  8.             nn.Conv2d(in_ch, out_ch, 3, padding=1),
  9.             nn.BatchNorm2d(out_ch),
  10.             nn.ReLU(inplace=True),
  11.             nn.Conv2d(out_ch, out_ch, 3, padding=1),
  12.             nn.BatchNorm2d(out_ch),
  13.             nn.ReLU(inplace=True)
  14.         )

  15.     def forward(self, input):
  16.         return self.conv(input)


  17. class Unet(nn.Module):
  18.     def __init__(self,in_ch,out_ch):
  19.         super(Unet, self).__init__()

  20.         self.conv1 = DoubleConv(in_ch, 64)
  21.         self.pool1 = nn.MaxPool2d(2)
  22.         self.conv2 = DoubleConv(64, 128)
  23.         self.pool2 = nn.MaxPool2d(2)
  24.         self.conv3 = DoubleConv(128, 256)
  25.         self.pool3 = nn.MaxPool2d(2)
  26.         self.conv4 = DoubleConv(256, 512)
  27.         self.pool4 = nn.MaxPool2d(2)
  28.         self.conv5 = DoubleConv(512, 1024)
  29.         self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
  30.         self.conv6 = DoubleConv(1024, 512)
  31.         self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
  32.         self.conv7 = DoubleConv(512, 256)
  33.         self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
  34.         self.conv8 = DoubleConv(256, 128)
  35.         self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
  36.         self.conv9 = DoubleConv(128, 64)
  37.         self.conv10 = nn.Conv2d(64,out_ch, 1)

  38.     def forward(self,x):
  39.         c1=self.conv1(x)
  40.         p1=self.pool1(c1)
  41.         c2=self.conv2(p1)
  42.         p2=self.pool2(c2)
  43.         c3=self.conv3(p2)
  44.         p3=self.pool3(c3)
  45.         c4=self.conv4(p3)
  46.         p4=self.pool4(c4)
  47.         c5=self.conv5(p4)
  48.         up_6= self.up6(c5)
  49.         merge6 = torch.cat([up_6, c4], dim=1)
  50.         c6=self.conv6(merge6)
  51.         up_7=self.up7(c6)
  52.         merge7 = torch.cat([up_7, c3], dim=1)
  53.         c7=self.conv7(merge7)
  54.         up_8=self.up8(c7)
  55.         merge8 = torch.cat([up_8, c2], dim=1)
  56.         c8=self.conv8(merge8)
  57.         up_9=self.up9(c8)
  58.         merge9=torch.cat([up_9,c1],dim=1)
  59.         c9=self.conv9(merge9)
  60.         c10=self.conv10(c9)
  61.         out = nn.Sigmoid()(c10)
  62.         return out
复制代码








回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则


Python|Opencv|MATLAB|Halcom.cn  

GMT+8, 2019-5-19 20:43 , Processed in 0.132084 second(s), 25 queries .

Powered by Discuz! X3.2

© 2001-2013 Comsenz Inc.

快速回复 返回顶部 返回列表