Hello Mat

 找回密码
 立即注册
查看: 6018|回复: 0

Pytorch_Unet图像分割

[复制链接]

84

主题

115

帖子

731

金钱

管理员

Rank: 9Rank: 9Rank: 9

积分
1467
发表于 2019-5-12 23:31:34 | 显示全部楼层 |阅读模式
Pytorch_Unet图像分割:
  1. # -*- coding: utf-8 -*-
  2. import sys
  3. sys.path.append(r"D:\2-LearningCode\999-AI-Pytorch\3_AI_nets\u_net_liver-master\u_net_liver-master")
  4. import numpy as np
  5. import torch
  6. import argparse
  7. from torch.utils.data import DataLoader
  8. from torch import autograd, optim
  9. from torchvision.transforms import transforms
  10. from unet import Unet
  11. from dataset import LiverDataset
  12. import cv2
  13. import random

  14. # 是否使用cuda
  15. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  16. #device = torch.device("cpu")
  17. #print(device)

  18. x_transforms = transforms.Compose([
  19.     transforms.ToTensor(),
  20.     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
  21. ])

  22. # mask只需要转换为tensor
  23. y_transforms = transforms.ToTensor()

  24. #参数解析
  25. parse=argparse.ArgumentParser()

  26. def train_model(model, criterion, optimizer, dataload, num_epochs=20):
  27.     model.cuda(0)
  28.     for epoch in range(num_epochs):
  29.         print('Epoch {}/{}'.format(epoch, num_epochs - 1))
  30.         print('-' * 10)
  31.         dt_size = len(dataload.dataset)
  32.         epoch_loss = 0
  33.         step = 0
  34.         #with torch.no_grad():
  35.         for x, y in dataload:
  36.             step += 1
  37.             inputs = x.to(device)
  38.             labels = y.to(device)
  39.             del x,y
  40.             # zero the parameter gradients
  41.             optimizer.zero_grad()
  42.             # forward
  43.             outputs = model(inputs)
  44.             loss = criterion(outputs, labels)
  45.             loss.backward()
  46.             optimizer.step()
  47.             
  48.             epoch_loss += loss.item()
  49.             print("%d/%d,train_loss:%0.3f" % (step, (dt_size - 1) // dataload.batch_size + 1, loss.item()))
  50.             
  51.             del inputs,labels,outputs
  52.             torch.cuda.empty_cache()
  53.             
  54.         print("epoch %d loss:%0.3f" % (epoch, epoch_loss))
  55.     torch.save(model.state_dict(), 'weights_%d.pth' % epoch)
  56.     return model

  57. #训练模型
  58. def train():
  59.     model = Unet(3, 1).to(device)
  60.     batch_size = args.batch_size
  61.     criterion = torch.nn.BCELoss()
  62.     optimizer = optim.Adam(model.parameters())
  63.     liver_dataset = LiverDataset("data/train",transform=x_transforms,target_transform=y_transforms)
  64.     dataloaders = DataLoader(liver_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
  65.     train_model(model, criterion, optimizer, dataloaders)

  66. #显示模型的输出结果
  67. def test_gpu():
  68.     model = Unet(3, 1).to(device)
  69. #    model.load_state_dict(torch.load(args.ckp,map_location='cpu'))
  70.     model.load_state_dict(torch.load(args.ckp, map_location='cuda'))
  71.     liver_dataset = LiverDataset("data/val", transform=x_transforms,target_transform=y_transforms)
  72.     dataloaders = DataLoader(liver_dataset, batch_size=1)
  73.     model.cuda(0)
  74.     model.eval()
  75.     import matplotlib.pyplot as plt
  76.     plt.ion()
  77.     with torch.no_grad():
  78.         for x, _ in dataloaders:
  79.             y=model(x.to(device))
  80.             img_y=torch.squeeze(y).cpu().numpy()
  81.             plt.imshow(img_y)
  82.             plt.pause(0.01)
  83.             
  84.             torch.cuda.empty_cache()
  85.             
  86.             cv2.imwrite('./resultsImages/'+str(random.random())+'.tiff', img_y)

  87.         plt.show()
  88.     torch.cuda.empty_cache()
  89.    
  90. def test_cpu():
  91.     model = Unet(3, 1)
  92.     model.load_state_dict(torch.load(args.ckp,map_location='cpu'))
  93.     liver_dataset = LiverDataset("data/val", transform=x_transforms,target_transform=y_transforms)
  94.     dataloaders = DataLoader(liver_dataset, batch_size=1)
  95.     model.eval()
  96.     import matplotlib.pyplot as plt
  97.     plt.ion()
  98.     with torch.no_grad():
  99.         for x, _ in dataloaders:
  100.             y=model(x)
  101.             img_y=torch.squeeze(y).numpy()
  102.             plt.imshow(img_y)
  103.             plt.pause(0.01)
  104.         plt.show()
  105.         
  106.     torch.cuda.empty_cache()
  107.    
  108. if __name__ == '__main__':
  109.     parse = argparse.ArgumentParser()
  110. #    parse.add_argument("action", type=str, help="train or test")
  111.     parse.add_argument("--action", type=str, help="train or test", default="train")
  112.     parse.add_argument("--batch_size", type=int, default=4)
  113. #    parse.add_argument("--ckp", type=str, help="the path of model weight file")
  114.     parse.add_argument("--ckp", type=str, help="the path of model weight file", default="weights_19.pth")
  115.     args = parse.parse_args()

  116.     print(args.action)

  117.     if args.action=="train":
  118.         train()
  119.     elif args.action=="test":
  120.         test_gpu()
  121.         
  122.     test_gpu()
  123.    
复制代码
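尤其要注意的是训练循环中每个 step 结束时的这两行:
del inputs,labels,outputs
torch.cuda.empty_cache()
del 释放对本批张量的引用,使其占用的显存在下一次前向计算前即可被复用,从而降低显存峰值,避免报 GPU 显存不足(CUDA out of memory)错误。torch.cuda.empty_cache() 则把缓存分配器中空闲的显存归还给驱动(便于其他程序使用)。按 PyTorch 文档,它并不会增加 PyTorch 自身可用的显存,起主要作用的是 del。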

unet.py如下:
  1. import torch.nn as nn
  2. import torch
  3. from torch import autograd

  4. class DoubleConv(nn.Module):
  5.     def __init__(self, in_ch, out_ch):
  6.         super(DoubleConv, self).__init__()
  7.         self.conv = nn.Sequential(
  8.             nn.Conv2d(in_ch, out_ch, 3, padding=1),
  9.             nn.BatchNorm2d(out_ch),
  10.             nn.ReLU(inplace=True),
  11.             nn.Conv2d(out_ch, out_ch, 3, padding=1),
  12.             nn.BatchNorm2d(out_ch),
  13.             nn.ReLU(inplace=True)
  14.         )

  15.     def forward(self, input):
  16.         return self.conv(input)


  17. class Unet(nn.Module):
  18.     def __init__(self,in_ch,out_ch):
  19.         super(Unet, self).__init__()

  20.         self.conv1 = DoubleConv(in_ch, 64)
  21.         self.pool1 = nn.MaxPool2d(2)
  22.         self.conv2 = DoubleConv(64, 128)
  23.         self.pool2 = nn.MaxPool2d(2)
  24.         self.conv3 = DoubleConv(128, 256)
  25.         self.pool3 = nn.MaxPool2d(2)
  26.         self.conv4 = DoubleConv(256, 512)
  27.         self.pool4 = nn.MaxPool2d(2)
  28.         self.conv5 = DoubleConv(512, 1024)
  29.         self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
  30.         self.conv6 = DoubleConv(1024, 512)
  31.         self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
  32.         self.conv7 = DoubleConv(512, 256)
  33.         self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
  34.         self.conv8 = DoubleConv(256, 128)
  35.         self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
  36.         self.conv9 = DoubleConv(128, 64)
  37.         self.conv10 = nn.Conv2d(64,out_ch, 1)

  38.     def forward(self,x):
  39.         c1=self.conv1(x)
  40.         p1=self.pool1(c1)
  41.         c2=self.conv2(p1)
  42.         p2=self.pool2(c2)
  43.         c3=self.conv3(p2)
  44.         p3=self.pool3(c3)
  45.         c4=self.conv4(p3)
  46.         p4=self.pool4(c4)
  47.         c5=self.conv5(p4)
  48.         up_6= self.up6(c5)
  49.         merge6 = torch.cat([up_6, c4], dim=1)
  50.         c6=self.conv6(merge6)
  51.         up_7=self.up7(c6)
  52.         merge7 = torch.cat([up_7, c3], dim=1)
  53.         c7=self.conv7(merge7)
  54.         up_8=self.up8(c7)
  55.         merge8 = torch.cat([up_8, c2], dim=1)
  56.         c8=self.conv8(merge8)
  57.         up_9=self.up9(c8)
  58.         merge9=torch.cat([up_9,c1],dim=1)
  59.         c9=self.conv9(merge9)
  60.         c10=self.conv10(c9)
  61.         out = nn.Sigmoid()(c10)
  62.         return out
复制代码

参考:
【1】读取文件夹内图像--方法1
【2】读取文件夹内图像--方法2

【3】https://github.com/gupta-abhay/pytorch-modelzoo
【4】https://github.com/yt4766269/pytorch_zoo
【5】https://github.com/takahiro-itazuri/model-zoo-pytorch
【6】http://222.195.93.137/gitlab/winston.wen/kaggle-1
【7】3-6之百度网盘下载链接:https://pan.baidu.com/s/17pXf2M3lIAFkJuoMXOZ9pA 提取码:k0ou
















回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

Python|Opencv|MATLAB|Halcom.cn ( 蜀ICP备16027072号 )

GMT+8, 2024-11-22 23:12 , Processed in 0.211447 second(s), 22 queries .

Powered by Discuz! X3.4

Copyright © 2001-2021, Tencent Cloud.

快速回复 返回顶部 返回列表