Hello Mat

 找回密码
 立即注册
查看: 6564|回复: 1

Vgg16预训练下的Unet

[复制链接]

84

主题

115

帖子

731

金钱

管理员

Rank: 9Rank: 9Rank: 9

积分
1467
发表于 2019-9-8 22:39:18 | 显示全部楼层 |阅读模式
Vgg16预训练下的Unet
  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Tue Aug 27 22:06:49 2019

  4. @author: Solem
  5. """
  6. import sys
  7. sys.path.append(r"D:\2-LearningCode\999-AI-Pytorch\3_AI_nets\u_net_liver-master\Vgg16_Unet")
  8. import os
  9. os.environ['CUDA_VISIBLE_DEVICES'] = '0'
  10. import numpy as np
  11. import torch
  12. import torch.utils.data as data
  13. from torch import autograd, optim
  14. from torchvision.transforms import transforms
  15. from torch.autograd import Variable
  16. # 配置文件
  17. from configs import configXML
  18. from pytorch_zoo import unet, resnet38unet
  19. # begin
  20. net = unet.Vgg16bn(num_classes=configXML.num_classes)
  21. net.train()
  22. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  23. #device = torch.device("cpu")
  24. #print(device)
  25. net = net.cuda()
  26. criterion1 =  torch.nn.BCELoss()
  27. if configXML.initial_checkpoints is not None:
  28.     net = torch.load(configXML.initial_checkpoints)
  29. best_val_loss = np.inf
  30. num_epoch_has_trained = 0
  31. for epoch in range(num_epoch_has_trained, configXML.num_epochs):
  32.     iter_count = 0;
  33.     train_losses = 0.0;
  34.     optimizer = optim.SGD(net.parameters(), lr = configXML.lr_init, momentum = configXML.moment_init)
  35.     torch.cuda.empty_cache()
  36.     net.train()
  37.     for i,(train_image,train_mask) in enumerate(configXML.train_loader):
  38.         iter_count = iter_count+1
  39.         train_image = Variable(train_image.cuda())
  40.         train_mask = Variable(train_mask.cuda())
  41.         train_logits = net(train_image)
  42.         train_loss = criterion1(train_logits, train_mask)
  43.         train_losses = train_losses + train_loss.item()
  44.         optimizer.zero_grad()
  45.         train_loss.backward()
  46.         optimizer.step()
  47.     torch.cuda.empty_cache()   
  48.     torch.save(net.state_dict(), './checkPoints/Vgg16_Val_normal.pth')
复制代码

configXML.py:
  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Sun Sep  8 20:47:37 2019

  4. @author: Solem
  5. """
  6. from XMLdataSet import dataset_improved, dataset_package
  7. import torch
  8. from torchvision.transforms import transforms
  9. from torch.utils.data import DataLoader

  10. Train_file=r'D:\2-LearningCode\992-DataEnhanced\Samples\EnhancedImages\Images_train.txt'
  11. Val_file=r'D:\2-LearningCode\992-DataEnhanced\Samples\EnhancedImages\Images_val.txt'
  12. Imagepath=r"D:\2-LearningCode\992-DataEnhanced\Samples\EnhancedImages\Images"
  13. Images_file=r'D:\2-LearningCode\992-DataEnhanced\Samples\EnhancedImages\Images.txt'
  14. Maskspath=r"D:\2-LearningCode\992-DataEnhanced\Samples\EnhancedImages\Masks"
  15. Masks_file=r'D:\2-LearningCode\992-DataEnhanced\Samples\EnhancedImages\Masks.txt'
  16. num_classes = 1
  17. img_size = 256
  18. batch_size = 4
  19. num_workers = 0

  20. initial_checkpoints = None
  21. #initial_checkpoints = r'D:\2-LearningCode\999-AI-Pytorch\3_AI_nets\u_net_liver-master\Vgg16_Unet\checkPoints\Vgg16_Val_normal.pth'
  22. num_epochs = 1000
  23. lr_init = 0.01
  24. moment_init = 0.9

  25. # 均值0.5,方差0.5
  26. x_transforms = transforms.Compose([
  27.     transforms.ToTensor(),
  28.     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
  29. ])

  30. # mask只需要转换为tensor
  31. y_transforms = transforms.ToTensor()

  32. #train_liver_dataset = dataset_improved.LiverDataset(imagepath, Train_file, transform=x_transforms, target_transform=y_transforms)
  33. #train_loader = DataLoader(train_liver_dataset, batch_size=batch_size, shuffle=True, num_workers = num_workers)

  34. train_dataset = dataset_package.opencvDateset(img_root_path=Imagepath, mask_root_path=Maskspath, txt_file=Train_file)
  35. train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

  36. #val_liver_dataset = dataset_improved.LiverDataset(imagepath, Val_file, transform=x_transforms, target_transform=y_transforms)
  37. #val_loader = DataLoader(val_liver_dataset, batch_size=batch_size, shuffle=True, num_workers = num_workers)

  38. val_dataset = dataset_package.opencvDateset(img_root_path=Imagepath, mask_root_path=Maskspath, txt_file=Val_file)
  39. val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
复制代码
dataset_package.py如下:
  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Sun May 19 22:28:08 2019

  4. @author: Solem
  5. """
  6. import os
  7. import PIL.Image as Image
  8. import numpy as np
  9. import cv2
  10. import torch
  11. import torch.utils.data as data
  12. from torch import autograd, optim
  13. from torchvision.transforms import transforms
  14. from torch.autograd import Variable
  15. from configs import configXML

  16. def make_dataset(root):
  17.     imgs=[]
  18. #    n=len(os.listdir(root))//2
  19. #    for i in range(100):
  20.     for i in range(100):
  21.         for j in range(16):
  22.             img=os.path.join(root,"%03d_%d.png"%(i,j+1))
  23.             mask=os.path.join(root,"%03d_%d_mask.png"%(i,j+1))
  24.             imgs.append((img,mask))
  25.     return imgs

  26. class LiverDataset(data.Dataset):
  27.     def __init__(self, root, transform=None, target_transform=None):
  28.         imgs = make_dataset(root)
  29.         self.imgs = imgs
  30.         self.transform = transform
  31.         self.target_transform = target_transform

  32.     def __getitem__(self, index):
  33.         x_path, y_path = self.imgs[index]
  34.         img_x = Image.open(x_path)
  35.         img_y = Image.open(y_path)
  36.         if self.transform is not None:
  37.             img_x = self.transform(img_x)
  38.         if self.target_transform is not None:
  39.             img_y = self.target_transform(img_y)
  40.         return img_x, img_y

  41.     def __len__(self):
  42.         return len(self.imgs)

  43. class opencvDateset(data.Dataset):
  44.     def __init__(self,img_root_path,mask_root_path, txt_file, transform=None):
  45.         self.transform = transform
  46.         self.img_root_path = img_root_path
  47.         self.mask_root_path = mask_root_path
  48.         with open(txt_file) as f:
  49.             self.indexs = f.readlines()
  50.     def __getitem__(self, idx):      
  51.         index = self.indexs[idx].split('.')[0]
  52.         image_path = os.path.join(self.img_root_path, index+'.jpg')
  53.         mask_path = os.path.join(self.mask_root_path, index+'_label.png')
  54.         img = cv2.imread(image_path)
  55.         img = cv2.resize(img, (configXML.img_size, configXML.img_size))
  56.         img = img.astype(np.float32)/255.0
  57.         mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
  58. #        print(np.shape(mask))
  59.         mask = cv2.resize(mask, (configXML.img_size, configXML.img_size))
  60.         print(np.shape(mask))
  61.         if self.transform is not None:
  62.             for trans in self.transform:
  63.                 img, mask = trans(img, mask)
  64.         img = torch.FloatTensor(img)
  65.         mask = torch.FloatTensor(mask)
  66.         return img, mask
  67.    
  68.     def __len__(self):
  69.         return len(self.indexs)
复制代码

参考:
【1】Pytorch_Unet图像分割
【2】http://222.195.93.137/gitlab/winston.wen/kaggle-1














回复

使用道具 举报

1323

主题

1551

帖子

0

金钱

管理员

Rank: 9Rank: 9Rank: 9

积分
22647
发表于 2019-9-9 22:22:24 | 显示全部楼层
dataset_package.py的第70行和第71行,应该可以这样修改,不然容易报错:
  1. def image_to_tensor(image, mean = 0.0, std = 1.0):
  2.         image = (image-mean)/std
  3.         image = image.transpose(2,0,1)
  4.         image = torch.from_numpy(image)
  5.         return image

  6. def mask_to_tensor(mask):
  7.         mask = (mask>128).astype(np.float32)
  8.         mask = torch.from_numpy(mask)
  9.         return mask
复制代码
第61行和第65行是否显得多余？

针对主程序而言:
from configs import configXML
这样import时，需要通过configXML.initial_checkpoints、configXML.num_epochs的形式来访问configXML中的变量（configXML是一个模块，这些是模块属性，而不是类的成员变量）。
可以改为
from configs.configXML import *
这样这些变量会被导入为当前模块的全局变量，可以直接使用initial_checkpoints、num_epochs。不过通配符导入容易造成命名冲突，也可以显式写成 from configs.configXML import initial_checkpoints, num_epochs。
记得每个文件夹放一个__init__.py文件。



















算法QQ  3283892722
群智能算法链接http://halcom.cn/forum.php?mod=forumdisplay&fid=73
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

Python|Opencv|MATLAB|Halcom.cn ( 蜀ICP备16027072号 )

GMT+8, 2024-11-26 06:09 , Processed in 0.213003 second(s), 21 queries .

Powered by Discuz! X3.4

Copyright © 2001-2021, Tencent Cloud.

快速回复 返回顶部 返回列表