|
Pytorch_Unet图像分割:
- # -*- coding: utf-8 -*-
- import sys
- sys.path.append(r"D:\2-LearningCode\999-AI-Pytorch\3_AI_nets\u_net_liver-master\u_net_liver-master")
- import numpy as np
- import torch
- import argparse
- from torch.utils.data import DataLoader
- from torch import autograd, optim
- from torchvision.transforms import transforms
- from unet import Unet
- from dataset import LiverDataset
- import cv2
- import random
# Use CUDA when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Input images: convert to tensor and normalize each RGB channel to [-1, 1].
x_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

# Masks only need to become tensors (no normalization).
y_transforms = transforms.ToTensor()

# NOTE(review): the original created a throwaway argparse.ArgumentParser()
# here; it was immediately shadowed by the parser built under __main__, so
# the dead assignment has been removed.
def train_model(model, criterion, optimizer, dataload, num_epochs=20):
    """Train ``model`` for ``num_epochs`` epochs over ``dataload``.

    Parameters
    ----------
    model : nn.Module
        Network to optimize. Callers are expected to have already placed it
        on the desired device with ``.to(device)``; the training loop sends
        each batch to the model's own device.
    criterion : callable
        Loss function taking ``(outputs, labels)``.
    optimizer : torch.optim.Optimizer
        Optimizer built over ``model.parameters()``.
    dataload : torch.utils.data.DataLoader
        Yields ``(x, y)`` batches.
    num_epochs : int
        Number of full passes over the data (default 20).

    Returns
    -------
    nn.Module
        The same model object, after training. A checkpoint
        ``weights_<epoch>.pth`` is written after every epoch.
    """
    # Derive the target device from the model itself instead of the original
    # hard-coded ``model.cuda(0)``, which crashed on CPU-only machines.
    device = next(model.parameters()).device
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        dt_size = len(dataload.dataset)
        epoch_loss = 0
        step = 0
        for x, y in dataload:
            step += 1
            inputs = x.to(device)
            labels = y.to(device)
            del x, y
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward / backward / update
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            print("%d/%d,train_loss:%0.3f" % (step, (dt_size - 1) // dataload.batch_size + 1, loss.item()))

            # Drop references to the batch tensors and release cached GPU
            # memory so long runs do not hit CUDA out-of-memory errors.
            del inputs, labels, outputs
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        print("epoch %d loss:%0.3f" % (epoch, epoch_loss))
        # Checkpoint after every epoch.
        torch.save(model.state_dict(), 'weights_%d.pth' % epoch)
    return model
- #训练模型
def train():
    """Build a fresh U-Net and train it on the ``data/train`` liver set."""
    model = Unet(3, 1).to(device)
    criterion = torch.nn.BCELoss()
    optimizer = optim.Adam(model.parameters())
    dataset = LiverDataset("data/train", transform=x_transforms,
                           target_transform=y_transforms)
    loader = DataLoader(dataset, batch_size=args.batch_size,
                        shuffle=True, num_workers=4)
    train_model(model, criterion, optimizer, loader)
- #显示模型的输出结果
def test_gpu():
    """Run the trained U-Net over ``data/val``, displaying and saving masks.

    Loads the checkpoint named by ``args.ckp`` onto whatever device is
    available; the original hard-coded ``map_location='cuda'`` plus
    ``model.cuda(0)`` raised an error on CUDA-less machines.
    """
    model = Unet(3, 1).to(device)
    # map_location=device works on both GPU and CPU-only hosts.
    model.load_state_dict(torch.load(args.ckp, map_location=device))
    liver_dataset = LiverDataset("data/val", transform=x_transforms,
                                 target_transform=y_transforms)
    dataloaders = DataLoader(liver_dataset, batch_size=1)
    model.eval()
    import os
    import matplotlib.pyplot as plt
    # cv2.imwrite fails silently when the target directory is missing.
    os.makedirs('./resultsImages', exist_ok=True)
    plt.ion()
    with torch.no_grad():
        for idx, (x, _) in enumerate(dataloaders):
            y = model(x.to(device))
            img_y = torch.squeeze(y).cpu().numpy()
            plt.imshow(img_y)
            plt.pause(0.01)

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Deterministic, collision-free filenames (the original used
            # random.random(), which could overwrite results on collision).
            cv2.imwrite('./resultsImages/%d.tiff' % idx, img_y)
        plt.show()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
-
def test_cpu():
    """Run the trained U-Net over ``data/val`` on the CPU, displaying masks.

    Loads the checkpoint named by ``args.ckp`` with ``map_location='cpu'``
    so it works regardless of the device it was trained on.
    """
    model = Unet(3, 1)
    model.load_state_dict(torch.load(args.ckp, map_location='cpu'))
    liver_dataset = LiverDataset("data/val", transform=x_transforms,
                                 target_transform=y_transforms)
    dataloaders = DataLoader(liver_dataset, batch_size=1)
    model.eval()
    import matplotlib.pyplot as plt
    plt.ion()
    with torch.no_grad():
        for x, _ in dataloaders:
            y = model(x)
            img_y = torch.squeeze(y).numpy()
            plt.imshow(img_y)
            plt.pause(0.01)
        plt.show()
    # NOTE(review): the original called torch.cuda.empty_cache() here, a
    # no-op on a pure-CPU run; it has been removed.
-
if __name__ == '__main__':
    parse = argparse.ArgumentParser()
    parse.add_argument("--action", type=str, help="train or test", default="train")
    parse.add_argument("--batch_size", type=int, default=4)
    parse.add_argument("--ckp", type=str, help="the path of model weight file",
                       default="weights_19.pth")
    args = parse.parse_args()
    print(args.action)
    # The original also called test_gpu() unconditionally after this
    # if/elif, so "--action test" ran inference twice and "--action train"
    # ran inference it was never asked for; that stray call is removed.
    if args.action == "train":
        train()
    elif args.action == "test":
        test_gpu()
    else:
        parse.error("unknown action %r, expected 'train' or 'test'" % args.action)
-
复制代码 尤其要注意的是:
del inputs,labels,outputs
torch.cuda.empty_cache()
先用 del 删除对本批张量的引用，再调用 torch.cuda.empty_cache() 释放缓存的GPU显存——单独调用 empty_cache() 并不能回收仍被引用的张量；否则长时间训练可能报 GPU 内存不足（CUDA out of memory）错误。
unet.py如下:
- import torch.nn as nn
- import torch
- from torch import autograd
class DoubleConv(nn.Module):
    """Two stacked (Conv3x3 -> BatchNorm -> ReLU) layers.

    Both convolutions use padding=1, so spatial size is preserved; only the
    channel count changes, from ``in_ch`` to ``out_ch``.
    """

    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # Parameter renamed from ``input`` to avoid shadowing the builtin.
        return self.conv(x)
class Unet(nn.Module):
    """Classic U-Net: 4 down-sampling stages, 4 up-sampling stages with
    skip connections, sigmoid-activated 1x1 output convolution.

    Input spatial dimensions must be divisible by 16 (four 2x2 max-pools),
    otherwise the skip-connection ``torch.cat`` calls fail on shape mismatch.
    Output has ``out_ch`` channels with values in (0, 1).
    """

    def __init__(self, in_ch, out_ch):
        super(Unet, self).__init__()
        # Encoder: channels double at each stage, spatial size halves.
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)
        # Decoder: transpose-convs double the spatial size; each DoubleConv
        # takes 2x channels because of the concatenated skip connection.
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        self.conv10 = nn.Conv2d(64, out_ch, 1)

    def forward(self, x):
        c1 = self.conv1(x)
        p1 = self.pool1(c1)
        c2 = self.conv2(p1)
        p2 = self.pool2(c2)
        c3 = self.conv3(p2)
        p3 = self.pool3(c3)
        c4 = self.conv4(p3)
        p4 = self.pool4(c4)
        c5 = self.conv5(p4)
        up_6 = self.up6(c5)
        merge6 = torch.cat([up_6, c4], dim=1)
        c6 = self.conv6(merge6)
        up_7 = self.up7(c6)
        merge7 = torch.cat([up_7, c3], dim=1)
        c7 = self.conv7(merge7)
        up_8 = self.up8(c7)
        merge8 = torch.cat([up_8, c2], dim=1)
        c8 = self.conv8(merge8)
        up_9 = self.up9(c8)
        merge9 = torch.cat([up_9, c1], dim=1)
        c9 = self.conv9(merge9)
        c10 = self.conv10(c9)
        # torch.sigmoid instead of the original nn.Sigmoid()(c10), which
        # constructed a fresh module object on every forward pass.
        return torch.sigmoid(c10)
复制代码
参考:
【1】读取文件夹内图像--方法1
【2】读取文件夹内图像--方法2
【3】https://github.com/gupta-abhay/pytorch-modelzoo
【4】https://github.com/yt4766269/pytorch_zoo
【5】https://github.com/takahiro-itazuri/model-zoo-pytorch
【6】http://222.195.93.137/gitlab/winston.wen/kaggle-1
【7】3-6之百度网盘下载链接:https://pan.baidu.com/s/17pXf2M3lIAFkJuoMXOZ9pA 提取码:k0ou
|
|