|
PyTorch图像标签分类之EfficientNet-PyTorch:
- # -*- coding: utf-8 -*-
- """
- Created on Sun Feb 5 14:56:26 2023
- @author: halcom
- """
- import sys
- import os
- # os.environ['CUDA_VISIBLE_DEVICES'] = '0'
- import numpy as np
- import torch
- import torch.utils.data as data
- from torch.utils.data import DataLoader
- from torch import autograd, optim
- import torch.nn as nn
- import torch.nn.functional as F
- from torch.nn.modules.loss import _Loss
- from torchvision.transforms import transforms
- from functools import partial
- from xml.dom.minidom import parse
- import cv2
- import time
- import shutil
- import copy
- from efficientnet_pytorch import EfficientNet
- #
# Convert strings to integer labels. Characters are numbered from 1 (1..N);
# label 0 is reserved for the CTC blank.
class strLabelConverter(object):
    """Convert between str and label.

    NOTE:
        Insert `blank` to the alphabet for CTC: label 0 is the blank and the
        characters of ``alphabet`` map to 1..len(alphabet) in order.

    Args:
        alphabet (str): set of the possible characters.
        ignore_case (bool, default=False): whether or not to ignore all of the
            case. When True, both the alphabet and every text passed to
            ``encode`` are lower-cased before lookup.
    """

    def __init__(self, alphabet, ignore_case=False):
        self._ignore_case = ignore_case
        if self._ignore_case:
            alphabet = alphabet.lower()
        # The trailing '-' makes label 0 decode to '-' in raw mode: decode()
        # reads self.alphabet[label - 1], and index -1 is the last character.
        self.alphabet = alphabet + '-'
        self.dict = {}
        for i, char in enumerate(alphabet):
            # NOTE: 0 is reserved for 'blank' required by wrap_ctc
            self.dict[char] = i + 1

    def encode(self, text):
        """Support batch or single str.

        Args:
            text (str or list of str/bytes): texts to convert. A plain str is
                iterated character by character, so each character becomes its
                own item of length 1. ``bytes`` items are decoded as UTF-8.

        Returns:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.

        Raises:
            KeyError: if a character is not part of the alphabet.
        """
        length = []
        result = []
        for item in text:
            # Decide per item; inspecting only text[0] crashed on empty input
            # and mishandled mixed str/bytes batches.
            if isinstance(item, bytes):
                item = item.decode('utf-8', 'strict')
            if self._ignore_case:
                # The alphabet was lower-cased in __init__, so the text must be too.
                item = item.lower()
            length.append(len(item))
            for char in item:
                result.append(self.dict[char])
        return (torch.IntTensor(result), torch.IntTensor(length))

    def decode(self, t, length, raw=False):
        """Decode encoded texts back into strs.

        Args:
            t (torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]): encoded texts.
            length (torch.IntTensor [n]): length of each text.
            raw (bool): if True, map every label directly (label 0 -> '-');
                otherwise drop blanks (0) and merge consecutive repeated labels.

        Raises:
            AssertionError: when the texts and its length does not match. It is
                raised explicitly so the check also runs under ``python -O``.

        Returns:
            str when ``length`` has one element, otherwise a list of str.
        """
        if length.numel() == 1:
            length = int(length[0])
            if t.numel() != length:
                raise AssertionError("text with length: {} does not match declared length: {}".format(t.numel(), length))
            labels = t.reshape(-1).tolist()
            if raw:
                return ''.join([self.alphabet[i - 1] for i in labels])
            char_list = []
            for i in range(length):
                # CTC collapse: skip blanks and repeats of the previous label.
                if labels[i] != 0 and (not (i > 0 and labels[i - 1] == labels[i])):
                    char_list.append(self.alphabet[labels[i] - 1])
            return ''.join(char_list)
        # batch mode: decode each slice of `t` on its own.
        total = int(length.sum())
        if t.numel() != total:
            raise AssertionError("texts with length: {} does not match declared length: {}".format(t.numel(), total))
        texts = []
        index = 0
        for i in range(length.numel()):
            l = int(length[i])
            texts.append(
                self.decode(
                    t[index:index + l], torch.IntTensor([l]), raw=raw))
            index += l
        return texts
#
# Build a strLabelConverter for an alphabet (originally a smoke test of strLabelConverter).
def AlphaConverter(AlphaBet=None):
    """Return a ``strLabelConverter`` for ``AlphaBet``.

    Characters are mapped to labels 1..len(AlphaBet); label 0 is the CTC blank.
    NOTE: ``CrossEntropyLoss`` expects class targets in 0..N-1, so encoded
    labels are offset by one relative to such targets.

    Args:
        AlphaBet (str or None): the alphabet to use. The previous version
            ignored this argument and always used the default below; ``None``
            now selects that default explicitly.

    Returns:
        strLabelConverter: converter built from the alphabet.
    """
    if AlphaBet is None:
        AlphaBet = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ'
    return strLabelConverter(AlphaBet)
#
def make_dataset_train_txt(root, txt_file, train_image_format):
    """Build ``(image_path, mask_path)`` pairs from a list file of sample names.

    Each non-blank line of ``txt_file`` is a sample name without extension.
    The image path is ``root/Image/<name><train_image_format>`` and the mask
    path is ``root/Mask/<name>.png``.

    Args:
        root (str): dataset root directory.
        txt_file (str): path of the text file listing sample names.
        train_image_format (str): image file extension, e.g. ``".jpg"``.

    Returns:
        list of (str, str): image/mask path pairs in file order.
    """
    imgs = []
    with open(txt_file) as f:
        for line in f:
            # Keep everything before the newline, as split('\n')[0] did before.
            name = line.split('\n')[0]
            if not name.strip():
                # Blank lines would otherwise yield bogus paths like "Image/.jpg".
                continue
            img = os.path.join(root, "Image", name + train_image_format)
            mask = os.path.join(root, "Mask", name + ".png")
            imgs.append((img, mask))
    return imgs
def make_dataset_val_txt(root, txt_file, train_image_format):
    """Build validation ``(image_path, mask_path)`` pairs from a list file.

    Each non-blank line of ``txt_file`` is a sample name without extension.
    The image path is ``root/Image/<name><train_image_format>`` and the mask
    path is ``root/Mask/<name>.png``.

    Args:
        root (str): dataset root directory.
        txt_file (str): path of the text file listing sample names.
        train_image_format (str): image file extension, e.g. ``".jpg"``.

    Returns:
        list of (str, str): image/mask path pairs in file order.
    """
    imgs = []
    with open(txt_file) as f:
        for line in f:
            # Keep everything before the newline, as split('\n')[0] did before.
            name = line.split('\n')[0]
            if not name.strip():
                # Blank lines would otherwise yield bogus paths like "Image/.jpg".
                continue
            img = os.path.join(root, "Image", name + train_image_format)
            mask = os.path.join(root, "Mask", name + ".png")
            imgs.append((img, mask))
    return imgs
class train_LiverDataset(data.Dataset):
    """Training dataset that yields ``(image, mask)`` pairs.

    Sample paths come from ``make_dataset_train_txt``: images are read from
    ``root/Image/<name><train_image_format>`` and masks from
    ``root/Mask/<name>.png``.

    Args:
        root (str): dataset root containing the ``Image`` and ``Mask`` folders.
        txt_file (str): text file with one sample name per line.
        classNum (int): number of classes; sets the default ``class_values``.
        transform (callable, optional): applied to the image array.
        target_transform (callable, optional): applied to the mask array.
        train_image_format (str): image file extension, e.g. ``".jpg"``.
        img_channel (int): 3 reads the image as RGB, 1 as grayscale.
        class_values (iterable of int, optional): mask pixel values that each
            become one binary channel. Defaults to ``range(classNum)``.
            NOTE(review): the original code stacked an undefined ``masks``
            list; this default is an assumption about how classes are encoded
            in the mask files — confirm against the data.
    """
    def __init__(self, root, txt_file, classNum = 1, transform = None, target_transform = None, train_image_format = ".jpg", img_channel = 3, class_values = None):
        self.imgs = make_dataset_train_txt(root, txt_file, train_image_format)
        self.transform = transform
        self.target_transform = target_transform
        self.classNum = classNum
        self.train_image_format = train_image_format
        self.img_channel = img_channel
        self.class_values = list(range(classNum)) if class_values is None else list(class_values)

    def __getitem__(self, index):
        x_path, y_path = self.imgs[index]
        if self.img_channel == 3:
            image = cv2.imread(x_path, 1)
        elif self.img_channel == 1:
            image = cv2.imread(x_path, 0)
        else:
            # Previously returned None, which only failed later inside the DataLoader.
            raise ValueError("img_channel must be 1 or 3, got {}".format(self.img_channel))
        if image is None:
            # cv2.imread returns None instead of raising on missing/unreadable files.
            raise FileNotFoundError("cannot read image: {}".format(x_path))
        if self.img_channel == 3:
            # OpenCV loads BGR; convert to RGB.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(y_path, 0)
        if mask is None:
            raise FileNotFoundError("cannot read mask: {}".format(y_path))
        # extract certain classes from mask: one binary channel per class value,
        # stacked on the last axis -> (H, W, len(class_values)) float array.
        masks = [(mask == v) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            mask = self.target_transform(mask)
        return image, mask

    def __len__(self):
        return len(self.imgs)
class val_LiverDataset(data.Dataset):
    """Validation dataset that yields ``(image, mask)`` pairs.

    Sample paths come from ``make_dataset_val_txt``: images are read from
    ``root/Image/<name><train_image_format>`` and masks from
    ``root/Mask/<name>.png``.

    Args:
        root (str): dataset root containing the ``Image`` and ``Mask`` folders.
        txt_file (str): text file with one sample name per line.
        classNum (int): number of classes; sets the default ``class_values``.
        transform (callable, optional): applied to the image array.
        target_transform (callable, optional): applied to the mask array.
        train_image_format (str): image file extension, e.g. ``".jpg"``.
        img_channel (int): 3 reads the image as RGB, 1 as grayscale.
        class_values (iterable of int, optional): mask pixel values that each
            become one binary channel. Defaults to ``range(classNum)``.
            NOTE(review): the original code stacked an undefined ``masks``
            list; this default is an assumption about how classes are encoded
            in the mask files — confirm against the data.
    """
    def __init__(self, root, txt_file, classNum = 1, transform = None, target_transform = None, train_image_format = ".jpg", img_channel = 3, class_values = None):
        self.imgs = make_dataset_val_txt(root, txt_file, train_image_format)
        self.transform = transform
        self.target_transform = target_transform
        self.classNum = classNum
        self.train_image_format = train_image_format
        self.img_channel = img_channel
        self.class_values = list(range(classNum)) if class_values is None else list(class_values)

    def __getitem__(self, index):
        x_path, y_path = self.imgs[index]
        if self.img_channel == 3:
            image = cv2.imread(x_path, 1)
        elif self.img_channel == 1:
            image = cv2.imread(x_path, 0)
        else:
            # Previously returned None, which only failed later inside the DataLoader.
            raise ValueError("img_channel must be 1 or 3, got {}".format(self.img_channel))
        if image is None:
            # cv2.imread returns None instead of raising on missing/unreadable files.
            raise FileNotFoundError("cannot read image: {}".format(x_path))
        if self.img_channel == 3:
            # OpenCV loads BGR; convert to RGB.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(y_path, 0)
        if mask is None:
            raise FileNotFoundError("cannot read mask: {}".format(y_path))
        # extract certain classes from mask: one binary channel per class value,
        # stacked on the last axis -> (H, W, len(class_values)) float array.
        masks = [(mask == v) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            mask = self.target_transform(mask)
        return image, mask

    def __len__(self):
        return len(self.imgs)
# training model
def trainAI(train_project_path, xmlfile, img_root_path, mask_root_path, train_txt_file_path, train_loss_txt_file, val_loss_txt_file, trainbatch_loss_txt_file, train_image_format, initial_model, GPU_index, img_channel, img_size, batch_size, num_workers, num_epochs, Seg_p, lr_init, moment_init, classNum, ENCODER, ENCODER_WEIGHTS, AlphaBet):
    """Train an EfficientNet-b0 classifier and export TorchScript checkpoints.

    Data is listed in ``<train_txt_file_path>/Mask_train.txt`` and
    ``<train_txt_file_path>/Mask_validation.txt``. Whenever the summed
    training loss of an epoch improves, the model is traced and saved to
    ``<train_project_path>/Model/<epoch>-Loss.pt``.

    Args:
        train_project_path (str): project directory; checkpoints go to its ``Model`` subfolder.
        train_txt_file_path (str): dataset root and location of the list files.
        train_image_format (str): image file extension, e.g. ``".jpg"``.
        initial_model (str): path of a TorchScript model to resume from, or the
            string ``"None"`` to start from ImageNet-pretrained weights.
        GPU_index (int): CUDA device index, used only when CUDA is available.
        img_channel (int): 1 (grayscale) or 3 (RGB).
        img_size (int): side length of the dummy input used for tracing.
        batch_size, num_workers (int): DataLoader settings.
        num_epochs (int): number of training epochs.
        lr_init (float): Adam learning rate.
        classNum: ignored; the class count is ``len(AlphaBet)``.
        AlphaBet (str): one class per character.

    The remaining parameters (xmlfile, img_root_path, mask_root_path,
    train_loss_txt_file, val_loss_txt_file, trainbatch_loss_txt_file, Seg_p,
    moment_init, ENCODER, ENCODER_WEIGHTS) are unused and are kept only for
    interface compatibility.

    Raises:
        ValueError: if ``img_channel`` is not 1 or 3.
    """
    # Validate early, before any model download or dataset scan.
    if img_channel == 1:
        x_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]) ])
    elif img_channel == 3:
        x_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) ])
    else:
        raise ValueError("img_channel must be 1 or 3, got {}".format(img_channel))
    y_transforms = transforms.ToTensor()

    # Use the requested GPU if CUDA is available, otherwise fall back to the CPU.
    device = torch.device("cuda:" + str(GPU_index) if torch.cuda.is_available() else "cpu")
    classNum = len(AlphaBet)
    if initial_model != "None":
        # Resume from a previously exported TorchScript model.
        model = torch.jit.load(initial_model, map_location=device)
    else:
        model = EfficientNet.from_pretrained('efficientnet-b0', in_channels=img_channel, num_classes=classNum)
    # .to(device) instead of .cuda(): keeps the CPU fallback working.
    model = model.to(device)
    criterion = torch.nn.CrossEntropyLoss().to(device)
    # Build the optimizer only after the final model is chosen. Creating it
    # earlier bound it to a discarded model, so a resumed model never updated.
    optimizer = torch.optim.Adam([dict(params=model.parameters(), lr=lr_init)])

    train_liver_dataset = train_LiverDataset(train_txt_file_path, train_txt_file_path + "/Mask_train.txt", classNum = classNum, transform = x_transforms, target_transform = y_transforms, train_image_format = train_image_format, img_channel = img_channel)
    train_dataloaders = DataLoader(train_liver_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers)
    # NOTE(review): the validation loader is built but never evaluated.
    val_liver_dataset = val_LiverDataset(train_txt_file_path, train_txt_file_path + "/Mask_validation.txt", classNum = classNum, transform = x_transforms, target_transform = y_transforms, train_image_format = train_image_format, img_channel = img_channel)
    val_dataloaders = DataLoader(val_liver_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers)

    model_dir = os.path.join(train_project_path, "Model")
    os.makedirs(model_dir, exist_ok=True)
    num_batches = len(train_dataloaders)
    train_sum_loss = np.inf
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        step = 0
        epoch_loss = 0
        model.train()
        for x, y in train_dataloaders:
            step += 1
            inputs = x.to(device)
            # NOTE(review): CrossEntropyLoss expects class-index targets of
            # shape (batch,), but the datasets return mask tensors — confirm
            # the label format before relying on this loss.
            labels = y.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print("%d/%d,train_loss:%0.3f" % (step, num_batches, loss.item()))
            del x, y, inputs, labels, outputs, loss
            torch.cuda.empty_cache()
        if epoch_loss < train_sum_loss:
            train_sum_loss = epoch_loss
            model.eval()
            # EfficientNet-PyTorch's default memory-efficient Swish is a Python
            # autograd Function that TorchScript cannot serialize; switch to the
            # plain Swish for export, then restore it. A resumed TorchScript
            # model has no set_swish, hence the hasattr check.
            has_swish = hasattr(model, "set_swish")
            if has_swish:
                model.set_swish(memory_efficient=False)
            example = torch.rand(1, img_channel, img_size, img_size).to(device)
            with torch.no_grad():
                pt_model = torch.jit.trace(model, example)
            pt_model.save(os.path.join(model_dir, str(epoch) + '-Loss.pt'))
            if has_swish:
                model.set_swish(memory_efficient=True)
复制代码
参考:
【1】https://github.com/lukemelas/EfficientNet-PyTorch
【2】百度网盘链接:https://pan.baidu.com/s/1dVcTded9PD-8Oy270bAGhA 提取码:u60r
【3】Efficientnet_pytorch训练自己的数据集,并对数据进行分类
【4】【图像分类】——来来来,干了这碗EfficientNet实战(Pytorch)
|
|