|
基于 Vgg16 预训练的 Unet

训练脚本如下:
# -*- coding: utf-8 -*-
"""
Train a VGG16-bn based U-Net on the loaders defined in ``configs.configXML``.

Created on Tue Aug 27 22:06:49 2019
@author: Solem

Every epoch runs one pass over ``configXML.train_loader``, prints the mean
training loss and overwrites ``./checkPoints/Vgg16_Val_normal.pth`` with the
model's ``state_dict``.
"""
import sys
# Make the project packages (configs, pytorch_zoo, ...) importable.
sys.path.append(r"D:\2-LearningCode\999-AI-Pytorch\3_AI_nets\u_net_liver-master\Vgg16_Unet")
import os
# Only takes effect if set before CUDA is initialised.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import numpy as np
import torch
import torch.utils.data as data
from torch import autograd, optim
from torchvision.transforms import transforms
from torch.autograd import Variable
# Project configuration
from configs import configXML
from pytorch_zoo import unet, resnet38unet

CHECKPOINT_DIR = './checkPoints'
CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, 'Vgg16_Val_normal.pth')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net = unet.Vgg16bn(num_classes=configXML.num_classes)
if configXML.initial_checkpoints is not None:
    # Checkpoints are saved below via ``net.state_dict()``, so they must be
    # restored with ``load_state_dict``; assigning ``torch.load(...)`` to
    # ``net`` directly would replace the model with a plain dict.
    net.load_state_dict(torch.load(configXML.initial_checkpoints, map_location=device))
# Respect the detected device instead of calling .cuda() unconditionally.
net = net.to(device)

# NOTE(review): BCELoss needs probabilities in [0, 1]; this assumes
# Vgg16bn ends in a sigmoid — TODO confirm.
criterion1 = torch.nn.BCELoss()
# Build the optimiser once so SGD momentum buffers persist across epochs
# (re-creating it each epoch silently resets momentum to zero).
optimizer = optim.SGD(net.parameters(), lr=configXML.lr_init, momentum=configXML.moment_init)

best_val_loss = np.inf
num_epoch_has_trained = 0
for epoch in range(num_epoch_has_trained, configXML.num_epochs):
    net.train()
    iter_count = 0
    train_losses = 0.0
    for train_image, train_mask in configXML.train_loader:
        iter_count += 1
        train_image = train_image.to(device)
        train_mask = train_mask.to(device)
        train_logits = net(train_image)
        # BCELoss requires identical shapes; a (B, H, W) mask is given the
        # channel axis of a (B, 1, H, W) prediction.
        if train_mask.dim() == train_logits.dim() - 1:
            train_mask = train_mask.unsqueeze(1)
        train_loss = criterion1(train_logits, train_mask)
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()
        train_losses += train_loss.item()
    mean_loss = train_losses / max(iter_count, 1)
    print('epoch %d: mean train loss %.6f' % (epoch, mean_loss))
    # torch.save fails if the target directory does not exist.
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    torch.save(net.state_dict(), CHECKPOINT_PATH)
    if device.type == 'cuda':
        torch.cuda.empty_cache()
configXML.py 如下:
# -*- coding: utf-8 -*-
"""
Shared configuration: data paths, hyper-parameters and DataLoaders.

Created on Sun Sep 8 20:47:37 2019
@author: Solem

NOTE: XMLdataSet.dataset_package imports this module back
(``from configs import configXML``), so the two modules form an import
cycle; dataset_package must only read configXML attributes lazily
(inside methods), never at its own import time.
"""
from XMLdataSet import dataset_improved, dataset_package
import torch
from torchvision.transforms import transforms
from torch.utils.data import DataLoader

# ---- Data locations ----
Train_file = r'D:\2-LearningCode\992-DataEnhanced\Samples\EnhancedImages\Images_train.txt'
Val_file = r'D:\2-LearningCode\992-DataEnhanced\Samples\EnhancedImages\Images_val.txt'
Imagepath = r"D:\2-LearningCode\992-DataEnhanced\Samples\EnhancedImages\Images"
Images_file = r'D:\2-LearningCode\992-DataEnhanced\Samples\EnhancedImages\Images.txt'
Maskspath = r"D:\2-LearningCode\992-DataEnhanced\Samples\EnhancedImages\Masks"
Masks_file = r'D:\2-LearningCode\992-DataEnhanced\Samples\EnhancedImages\Masks.txt'

# ---- Model / training hyper-parameters ----
num_classes = 1
img_size = 256  # side length images and masks are resized to by opencvDateset
batch_size = 4
num_workers = 0
# Optional checkpoint to start training from; None trains from scratch.
initial_checkpoints = None
#initial_checkpoints = r'D:\2-LearningCode\999-AI-Pytorch\3_AI_nets\u_net_liver-master\Vgg16_Unet\checkPoints\Vgg16_Val_normal.pth'
num_epochs = 1000
lr_init = 0.01
moment_init = 0.9  # SGD momentum

# Per-channel normalisation with mean 0.5 and std 0.5 (Normalize takes a
# standard deviation, not a variance): maps ToTensor's [0, 1] to [-1, 1].
x_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
# Masks only need converting to a tensor.
y_transforms = transforms.ToTensor()
# NOTE(review): x_transforms / y_transforms are only used by the commented-out
# LiverDataset loaders below; opencvDateset's `transform` argument expects
# callables of the form f(img, mask) -> (img, mask) instead.

#train_liver_dataset = dataset_improved.LiverDataset(imagepath, Train_file, transform=x_transforms, target_transform=y_transforms)
#train_loader = DataLoader(train_liver_dataset, batch_size=batch_size, shuffle=True, num_workers = num_workers)
train_dataset = dataset_package.opencvDateset(img_root_path=Imagepath, mask_root_path=Maskspath, txt_file=Train_file)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

#val_liver_dataset = dataset_improved.LiverDataset(imagepath, Val_file, transform=x_transforms, target_transform=y_transforms)
#val_loader = DataLoader(val_liver_dataset, batch_size=batch_size, shuffle=True, num_workers = num_workers)
val_dataset = dataset_package.opencvDateset(img_root_path=Imagepath, mask_root_path=Maskspath, txt_file=Val_file)
# Validation needs no shuffling; a fixed order keeps per-batch results reproducible.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
dataset_package.py 如下:
- # -*- coding: utf-8 -*-
- """
- Created on Sun May 19 22:28:08 2019
- @author: Solem
- """
- import os
- import PIL.Image as Image
- import numpy as np
- import cv2
- import torch
- import torch.utils.data as data
- from torch import autograd, optim
- from torchvision.transforms import transforms
- from torch.autograd import Variable
- from configs import configXML
def make_dataset(root, num_groups=100, per_group=16):
    """Build (image_path, mask_path) pairs from a numbered file-naming scheme.

    Image files are named ``"%03d_%d.png" % (i, j)`` and masks
    ``"%03d_%d_mask.png" % (i, j)``, with ``i`` in ``range(num_groups)`` and
    ``j`` in ``1..per_group`` (presumably case / slice numbering — TODO
    confirm). Files are not checked for existence.

    Args:
        root: directory containing the images and masks.
        num_groups: number of values for the first index (default 100, the
            previously hard-coded value).
        per_group: number of values for the second, 1-based index
            (default 16, the previously hard-coded value).

    Returns:
        List of ``(image_path, mask_path)`` tuples, ordered by ``i`` then ``j``.
    """
    return [
        (os.path.join(root, "%03d_%d.png" % (i, j)),
         os.path.join(root, "%03d_%d_mask.png" % (i, j)))
        for i in range(num_groups)
        for j in range(1, per_group + 1)
    ]
class LiverDataset(data.Dataset):
    """Image/mask dataset backed by PIL, over the path pairs from ``make_dataset``.

    Args:
        root: directory handed to ``make_dataset`` to enumerate the pairs.
        transform: optional callable applied to each loaded image.
        target_transform: optional callable applied to each loaded mask.
    """

    def __init__(self, root, transform=None, target_transform=None):
        self.transform = transform
        self.target_transform = target_transform
        self.imgs = make_dataset(root)

    def __getitem__(self, index):
        """Return the (image, mask) pair at ``index``, transformed if configured."""
        image_file, mask_file = self.imgs[index]
        # NOTE: PIL's Image.open is lazy; without a transform the returned
        # images keep their files open until loaded or garbage-collected.
        image = Image.open(image_file)
        mask = Image.open(mask_file)
        image = image if self.transform is None else self.transform(image)
        mask = mask if self.target_transform is None else self.target_transform(mask)
        return image, mask

    def __len__(self):
        """Number of (image, mask) pairs."""
        return len(self.imgs)
class opencvDateset(data.Dataset):
    """Segmentation dataset that loads images and masks with OpenCV from a list file.

    Each non-blank line of ``txt_file`` names one sample. Its extension is
    dropped, and the stem ``s`` resolves to ``<img_root_path>/s.jpg`` (image)
    and ``<mask_root_path>/s_label.png`` (grayscale mask).

    Args:
        img_root_path: directory holding the ``.jpg`` images.
        mask_root_path: directory holding the ``_label.png`` masks.
        txt_file: text file listing one sample name per line.
        transform: optional iterable of callables ``f(img, mask) -> (img, mask)``
            applied in order to the resized numpy arrays. ``img`` is float32
            HxWxC in [0, 1] and ``mask`` is the resized grayscale array.
        img_size: side length both arrays are resized to. ``None`` (default)
            falls back to ``configXML.img_size``.

    ``__getitem__`` returns float32 tensors ``(img, mask)``. ``img`` is
    channel-first (C, H, W). ``mask`` is scaled to [0, 1] and given a leading
    channel axis (1, H, W) when it is 2-D.
    """

    def __init__(self, img_root_path, mask_root_path, txt_file, transform=None, img_size=None):
        self.transform = transform
        self.img_root_path = img_root_path
        self.mask_root_path = mask_root_path
        self.img_size = img_size
        with open(txt_file) as f:
            # Strip newlines and drop blank lines (e.g. a trailing empty line);
            # otherwise they turn into bogus sample names.
            self.indexs = [line.strip() for line in f if line.strip()]

    def _size(self):
        """Return the target side length, defaulting to configXML.img_size."""
        if self.img_size is not None:
            return self.img_size
        # Read lazily: configXML imports this module, so it may be only
        # partially initialised while this module is being imported.
        return configXML.img_size

    def __getitem__(self, idx):
        stem = os.path.splitext(self.indexs[idx])[0]
        image_path = os.path.join(self.img_root_path, stem + '.jpg')
        mask_path = os.path.join(self.mask_root_path, stem + '_label.png')
        size = self._size()

        img = cv2.imread(image_path)
        if img is None:  # cv2.imread returns None instead of raising
            raise FileNotFoundError('cannot read image: %s' % image_path)
        # NOTE(review): OpenCV loads BGR; ImageNet-pretrained VGG weights expect
        # RGB. Left unchanged to stay compatible with existing checkpoints.
        img = cv2.resize(img, (size, size)).astype(np.float32) / 255.0

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError('cannot read mask: %s' % mask_path)
        # Nearest-neighbour keeps the mask's original label values instead of
        # blending them along object edges.
        mask = cv2.resize(mask, (size, size), interpolation=cv2.INTER_NEAREST)

        if self.transform is not None:
            for trans in self.transform:
                img, mask = trans(img, mask)

        img = np.asarray(img, dtype=np.float32)
        if img.ndim == 3:
            # HWC (OpenCV layout) -> CHW (layout torch conv layers expect).
            img = np.transpose(img, (2, 0, 1))
        mask = np.asarray(mask, dtype=np.float32)
        if mask.size and mask.max() > 1.0:
            # BCELoss targets must lie in [0, 1]; masks stored as 0/255 are rescaled.
            mask = mask / 255.0
        if mask.ndim == 2:
            # (H, W) -> (1, H, W); assumes a single-channel network output — TODO confirm.
            mask = mask[np.newaxis]
        img = torch.from_numpy(np.ascontiguousarray(img))
        mask = torch.from_numpy(np.ascontiguousarray(mask))
        return img, mask

    def __len__(self):
        return len(self.indexs)
参考:
【1】Pytorch_Unet图像分割
【2】http://222.195.93.137/gitlab/winston.wen/kaggle-1
|
|