PyTorch图像标签分类之EfficientNet-PyTorch
PyTorch图像标签分类之EfficientNet-PyTorch:
# -*- coding: utf-8 -*-
"""
Created on Sun Feb5 14:56:26 2023
@author: halcom
"""
import sys
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import numpy as np
import torch
import torch.utils.data as data
from torch.utils.data import DataLoader
from torch import autograd, optim
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
from torchvision.transforms import transforms
from functools import partial
from xml.dom.minidom import parse
import cv2
import time
import shutil
import copy
from efficientnet_pytorch import EfficientNet
#
# Convert character strings to integer labels and back; labels are numbered from 1 (1..N).
class strLabelConverter(object):
    """Convert between str and label.

    NOTE:
        Index 0 is reserved for the CTC 'blank' symbol, so the i-th character of
        ``alphabet`` is encoded as ``i + 1``.  ``CrossEntropyLoss`` expects
        labels in 0..N-1, so callers using it must subtract 1.

    Args:
        alphabet (str): set of the possible characters.
        ignore_case (bool, default=False): whether or not to ignore all of the case.
    """

    def __init__(self, alphabet, ignore_case=False):
        self._ignore_case = ignore_case
        if self._ignore_case:
            alphabet = alphabet.lower()
        # The trailing '-' makes ``self.alphabet[0 - 1]`` decode label 0 as blank.
        self.alphabet = alphabet + '-'
        self.dict = {}
        for i, char in enumerate(alphabet):
            # NOTE: 0 is reserved for 'blank' required by wrap_ctc
            self.dict[char] = i + 1

    def encode(self, text):
        """Support batch or single str.

        A single ``str`` (or ``bytes``, decoded as UTF-8) is treated as a
        sequence of one-character items, so ``encode('0134')`` yields four
        labels of length 1 each.  A list/tuple of ``str``/``bytes`` is encoded
        item by item and the codes are concatenated.

        Args:
            text (str, bytes, or list of str/bytes): texts to convert.

        Returns:
            torch.IntTensor: encoded texts (concatenated).
            torch.IntTensor: length of each text.

        Raises:
            KeyError: if a character is not in the alphabet.
        """
        if isinstance(text, bytes):
            text = text.decode('utf-8', 'strict')
        length = []
        result = []
        for item in text:
            if isinstance(item, bytes):
                item = item.decode('utf-8', 'strict')
            if self._ignore_case:
                # The alphabet was lower-cased in __init__, so match that here.
                item = item.lower()
            length.append(len(item))
            result.extend(self.dict[char] for char in item)
        return torch.IntTensor(result), torch.IntTensor(length)

    def decode(self, t, length, raw=False):
        """Decode encoded texts back into strs.

        Args:
            t (torch.IntTensor): encoded texts (concatenated codes).
            length (torch.IntTensor): length of each text; a single element
                means single-text mode.
            raw (bool): if True, map every code directly (0 -> '-'); otherwise
                drop blanks (0) and collapse consecutive repeated codes.

        Raises:
            AssertionError: when the texts and its length does not match.

        Returns:
            str in single-text mode, list of str in batch mode.
        """
        flat = t.reshape(-1)
        if length.numel() == 1:
            expected = int(length.item())
            if flat.numel() != expected:
                # Raised explicitly (not via ``assert``) so it survives ``python -O``.
                raise AssertionError("text with length: {} does not match declared length: {}".format(flat.numel(), expected))
            codes = [int(c) for c in flat]
            if raw:
                return ''.join(self.alphabet[c - 1] for c in codes)
            char_list = []
            for i, c in enumerate(codes):
                # Skip blanks and repeats of the previous code (CTC collapse).
                if c != 0 and not (i > 0 and codes[i - 1] == c):
                    char_list.append(self.alphabet[c - 1])
            return ''.join(char_list)
        # batch mode
        total = int(length.sum())
        if flat.numel() != total:
            raise AssertionError("texts with length: {} does not match declared length: {}".format(flat.numel(), total))
        texts = []
        index = 0
        for l in length.reshape(-1).tolist():
            texts.append(self.decode(flat[index:index + l], torch.IntTensor([l]), raw=raw))
            index += l
        return texts
#
# Build and self-check a strLabelConverter.
def AlphaConverter(AlphaBet=None):
    """Create a ``strLabelConverter`` for ``AlphaBet`` and verify an encode/decode round trip.

    Args:
        AlphaBet (str, optional): the label characters.  When empty or ``None``,
            ``'0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ'`` is used (the previous code
            always used this value and ignored the argument).

    Returns:
        strLabelConverter: converter mapping ``AlphaBet[i]`` <-> ``i + 1``.
        Note that CrossEntropyLoss expects labels 0..N-1, so subtract 1.

    Raises:
        RuntimeError: if decoding the encoded alphabet does not reproduce it.
    """
    if not AlphaBet:
        AlphaBet = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ'
    converter = strLabelConverter(AlphaBet)
    # Each symbol is encoded as its own length-1 label.
    encoded, lengths = converter.encode(AlphaBet)
    decoded = converter.decode(encoded, lengths, raw=False)
    # decode returns a str for one label and a list of str for several.
    if ''.join(decoded) != AlphaBet:
        raise RuntimeError("strLabelConverter round trip failed for alphabet {!r}".format(AlphaBet))
    return converter
#
def make_dataset_train_txt(root, txt_file, train_image_format, encoding=None):
    """Build ``(image_path, mask_path)`` pairs from a list file of sample names.

    Each non-blank line of ``txt_file`` is a sample name ``name``.  It becomes
    the pair ``(root/Image/<name><train_image_format>, root/Mask/<name>.png)``.

    Args:
        root: dataset root directory containing ``Image`` and ``Mask``.
        txt_file: text file listing one sample name per line.
        train_image_format: image extension including the dot, e.g. ``".jpg"``.
        encoding: encoding of ``txt_file``; ``None`` uses the platform default.

    Returns:
        list of ``(image_path, mask_path)`` tuples in file order.
    """
    imgs = []
    with open(txt_file, encoding=encoding) as f:
        for line in f:
            name = line.strip()
            if not name:
                # Skip blank lines (e.g. a trailing empty line); they would
                # otherwise produce paths such as "Image/.jpg".
                continue
            img = os.path.join(root, "Image", name + train_image_format)
            mask = os.path.join(root, "Mask", name + ".png")
            imgs.append((img, mask))
    return imgs
def make_dataset_val_txt(root, txt_file, train_image_format, encoding=None):
    """Build validation ``(image_path, mask_path)`` pairs from a list file.

    Each non-blank line of ``txt_file`` is a sample name ``name``.  It becomes
    the pair ``(root/Image/<name><train_image_format>, root/Mask/<name>.png)``.

    Args:
        root: dataset root directory containing ``Image`` and ``Mask``.
        txt_file: text file listing one sample name per line.
        train_image_format: image extension including the dot, e.g. ``".jpg"``.
        encoding: encoding of ``txt_file``; ``None`` uses the platform default.

    Returns:
        list of ``(image_path, mask_path)`` tuples in file order.
    """
    imgs = []
    with open(txt_file, encoding=encoding) as f:
        for line in f:
            name = line.strip()
            if not name:
                # Blank lines would otherwise yield paths like "Image/.jpg".
                continue
            img = os.path.join(root, "Image", name + train_image_format)
            mask = os.path.join(root, "Mask", name + ".png")
            imgs.append((img, mask))
    return imgs
class train_LiverDataset(data.Dataset):
    """Training dataset of ``(image, mask)`` pairs listed in ``txt_file``.

    Paths come from ``make_dataset_train_txt(root, txt_file, train_image_format)``.

    Args:
        root: dataset root containing ``Image/`` and ``Mask/``.
        txt_file: text file with one sample name per line.
        classNum: number of classes; sets the default ``class_values``.
        transform: optional callable applied to the image array.
        target_transform: optional callable applied to the stacked mask array.
        train_image_format: image file extension including the dot.
        img_channel: 3 loads RGB images, 1 loads grayscale images.
        class_values: mask pixel values to extract, one output channel each;
            defaults to ``1..classNum``.

    Raises:
        ValueError: if ``img_channel`` is not 1 or 3.
    """

    def __init__(self, root, txt_file, classNum=1, transform=None, target_transform=None, train_image_format=".jpg", img_channel=3, class_values=None):
        if img_channel not in (1, 3):
            # Previously __getitem__ silently returned None, which breaks batching.
            raise ValueError("img_channel must be 1 or 3, got {}".format(img_channel))
        self.imgs = make_dataset_train_txt(root, txt_file, train_image_format)
        self.transform = transform
        self.target_transform = target_transform
        self.classNum = classNum
        self.train_image_format = train_image_format
        self.img_channel = img_channel
        # NOTE(review): the original referenced an undefined ``masks`` list.
        # Values 1..classNum follow this file's 1-based label numbering
        # (0 treated as background) — confirm against the actual mask files.
        if class_values is None:
            self.class_values = list(range(1, classNum + 1))
        else:
            self.class_values = list(class_values)

    def __getitem__(self, index):
        """Return ``(image, mask)`` for sample ``index``.

        Before ``target_transform``, the mask is an (H, W, len(class_values))
        float array holding 1.0 where the mask pixel equals that class value.

        Raises:
            FileNotFoundError: if the image or mask cannot be read by OpenCV.
        """
        x_path, y_path = self.imgs[index]
        if self.img_channel == 3:
            image = cv2.imread(x_path, cv2.IMREAD_COLOR)
            if image is None:
                raise FileNotFoundError("cannot read image: {}".format(x_path))
            # OpenCV loads BGR; the model is fed RGB.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        else:
            image = cv2.imread(x_path, cv2.IMREAD_GRAYSCALE)
            if image is None:
                raise FileNotFoundError("cannot read image: {}".format(x_path))
        mask = cv2.imread(y_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError("cannot read mask: {}".format(y_path))
        # extract certain classes from mask: one binary channel per class value
        masks = [(mask == v) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            mask = self.target_transform(mask)
        return image, mask

    def __len__(self):
        return len(self.imgs)
class val_LiverDataset(data.Dataset):
    """Validation dataset of ``(image, mask)`` pairs listed in ``txt_file``.

    Paths come from ``make_dataset_val_txt(root, txt_file, train_image_format)``.

    Args:
        root: dataset root containing ``Image/`` and ``Mask/``.
        txt_file: text file with one sample name per line.
        classNum: number of classes; sets the default ``class_values``.
        transform: optional callable applied to the image array.
        target_transform: optional callable applied to the stacked mask array.
        train_image_format: image file extension including the dot.
        img_channel: 3 loads RGB images, 1 loads grayscale images.
        class_values: mask pixel values to extract, one output channel each;
            defaults to ``1..classNum``.

    Raises:
        ValueError: if ``img_channel`` is not 1 or 3.
    """

    def __init__(self, root, txt_file, classNum=1, transform=None, target_transform=None, train_image_format=".jpg", img_channel=3, class_values=None):
        if img_channel not in (1, 3):
            # Previously __getitem__ silently returned None, which breaks batching.
            raise ValueError("img_channel must be 1 or 3, got {}".format(img_channel))
        self.imgs = make_dataset_val_txt(root, txt_file, train_image_format)
        self.transform = transform
        self.target_transform = target_transform
        self.classNum = classNum
        self.train_image_format = train_image_format
        self.img_channel = img_channel
        # NOTE(review): the original referenced an undefined ``masks`` list.
        # Values 1..classNum follow this file's 1-based label numbering
        # (0 treated as background) — confirm against the actual mask files.
        if class_values is None:
            self.class_values = list(range(1, classNum + 1))
        else:
            self.class_values = list(class_values)

    def __getitem__(self, index):
        """Return ``(image, mask)`` for sample ``index``.

        Before ``target_transform``, the mask is an (H, W, len(class_values))
        float array holding 1.0 where the mask pixel equals that class value.

        Raises:
            FileNotFoundError: if the image or mask cannot be read by OpenCV.
        """
        x_path, y_path = self.imgs[index]
        if self.img_channel == 3:
            image = cv2.imread(x_path, cv2.IMREAD_COLOR)
            if image is None:
                raise FileNotFoundError("cannot read image: {}".format(x_path))
            # OpenCV loads BGR; the model is fed RGB.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        else:
            image = cv2.imread(x_path, cv2.IMREAD_GRAYSCALE)
            if image is None:
                raise FileNotFoundError("cannot read image: {}".format(x_path))
        mask = cv2.imread(y_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError("cannot read mask: {}".format(y_path))
        # extract certain classes from mask: one binary channel per class value
        masks = [(mask == v) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            mask = self.target_transform(mask)
        return image, mask

    def __len__(self):
        return len(self.imgs)
# training model
def _build_x_transforms(img_channel):
    """Return the input-image transform for 1- or 3-channel images, else ``None``.

    NOTE(review): the original normalization constants were lost when the
    source was scraped.  0.5/0.5 (grayscale) and the ImageNet mean/std (RGB,
    matching the ImageNet-pretrained weights) are reconstructions — confirm.
    """
    if img_channel == 1:
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])
    if img_channel == 3:
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
    return None


def _batch_accuracy(outputs, labels):
    """Fraction of samples whose arg-max prediction equals the label.

    ``labels`` may be class indices (N,) or per-class scores (N, C).
    """
    predicted = outputs.argmax(dim=1)
    target = labels.argmax(dim=1) if labels.dim() > 1 else labels
    return (predicted == target).float().mean().item()


def _evaluate(model, dataloader, criterion, device):
    """Return (summed loss, sample accuracy) of ``model`` over ``dataloader`` without gradients."""
    model.eval()
    total_loss = 0.0
    correct = 0.0
    seen = 0
    with torch.no_grad():
        for x, y in dataloader:
            inputs = x.to(device)
            labels = y.to(device)
            outputs = model(inputs)
            total_loss += criterion(outputs, labels).item()
            batch = labels.shape[0]
            correct += _batch_accuracy(outputs, labels) * batch
            seen += batch
    return total_loss, (correct / seen if seen else 0.0)


def _save_traced(model, img_channel, img_size, device, path):
    """Trace ``model`` on a random (1, img_channel, img_size, img_size) input and save TorchScript to ``path``."""
    model.eval()
    # efficientnet_pytorch's memory-efficient Swish is a custom autograd
    # Function that TorchScript cannot serialize, so switch to the plain Swish
    # while tracing.  A model loaded via torch.jit.load has no set_swish.
    swap = hasattr(model, 'set_swish')
    if swap:
        model.set_swish(memory_efficient=False)
    try:
        example = torch.rand(1, img_channel, img_size, img_size, device=device)
        with torch.no_grad():
            pt_model = torch.jit.trace(model, example)
        pt_model.save(path)
    finally:
        if swap:
            model.set_swish(memory_efficient=True)


def trainAI(train_project_path, xmlfile, img_root_path, mask_root_path, train_txt_file_path, train_loss_txt_file, val_loss_txt_file, trainbatch_loss_txt_file, train_image_format, initial_model, GPU_index, img_channel, img_size, batch_size, num_workers, num_epochs, Seg_p, lr_init, moment_init, classNum, ENCODER, ENCODER_WEIGHTS, AlphaBet):
    """Fine-tune EfficientNet-b0 with one output per ``AlphaBet`` symbol and save TorchScript checkpoints.

    Data lists are read from ``train_txt_file_path/Mask_train.txt`` and
    ``train_txt_file_path/Mask_validation.txt``.  Images live under
    ``train_txt_file_path/Image`` and masks under ``train_txt_file_path/Mask``.
    Whenever the summed training loss of an epoch improves, the model is
    traced and saved to ``train_project_path/Model/<epoch>-Loss.pt``.

    Args:
        train_project_path: output project directory (``Model/`` is created inside).
        train_txt_file_path: dataset root holding the list files, ``Image/`` and ``Mask/``.
        train_image_format: image file extension including the dot.
        initial_model: path of a TorchScript checkpoint to resume from, or
            ``"None"``/``None`` to start from ImageNet-pretrained weights.
        GPU_index: CUDA device index; the CPU is used if CUDA is unavailable.
        img_channel: 1 (grayscale) or 3 (RGB); anything else aborts.
        img_size: spatial size of the example input used for tracing.
        batch_size, num_workers, num_epochs: DataLoader / loop settings.
        lr_init: Adam learning rate.
        classNum: overridden by ``len(AlphaBet)``.
        AlphaBet: label symbols; their count sets the number of outputs.

    The parameters xmlfile, img_root_path, mask_root_path, train_loss_txt_file,
    val_loss_txt_file, trainbatch_loss_txt_file, Seg_p, moment_init, ENCODER
    and ENCODER_WEIGHTS are not used; they are kept for interface compatibility.
    """
    device = torch.device("cuda:" + str(GPU_index) if torch.cuda.is_available() else "cpu")
    x_transforms = _build_x_transforms(img_channel)
    if x_transforms is None:
        print("Unsupported img_channel: {} (expected 1 or 3)".format(img_channel))
        return
    y_transforms = transforms.ToTensor()

    classNum = len(AlphaBet)
    if initial_model and initial_model != "None":
        # Resume from a previously saved TorchScript checkpoint.
        model = torch.jit.load(initial_model, map_location=device)
    else:
        model = EfficientNet.from_pretrained('efficientnet-b0', in_channels=img_channel, num_classes=classNum)
    model = model.to(device)

    # NOTE(review): CrossEntropyLoss expects class indices (N,) or class
    # probabilities (N, C), but the datasets return ToTensor()-converted masks.
    # Confirm the label format matches.
    criterion = torch.nn.CrossEntropyLoss().to(device)
    # Build the optimizer only after the final model is chosen, so that it
    # updates the parameters that are actually trained.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr_init)

    train_liver_dataset = train_LiverDataset(train_txt_file_path, train_txt_file_path + "/Mask_train.txt", classNum=classNum, transform=x_transforms, target_transform=y_transforms, train_image_format=train_image_format, img_channel=img_channel)
    train_dataloaders = DataLoader(train_liver_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_liver_dataset = val_LiverDataset(train_txt_file_path, train_txt_file_path + "/Mask_validation.txt", classNum=classNum, transform=x_transforms, target_transform=y_transforms, train_image_format=train_image_format, img_channel=img_channel)
    val_dataloaders = DataLoader(val_liver_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    model_dir = os.path.join(train_project_path, "Model")
    os.makedirs(model_dir, exist_ok=True)

    train_sum_loss = np.inf
    num_batches = len(train_dataloaders)
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        model.train()
        epoch_loss = 0.0
        for step, (x, y) in enumerate(train_dataloaders, start=1):
            inputs = x.to(device)
            labels = y.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_metric = _batch_accuracy(outputs.detach(), labels)
            print("%d/%d,train_loss:%0.3f,train_metric:%0.3f" % (step, num_batches, loss.item(), epoch_metric))

        val_loss, val_metric = _evaluate(model, val_dataloaders, criterion, device)
        print("epoch %d train_loss:%0.3f val_loss:%0.3f val_metric:%0.3f" % (epoch, epoch_loss, val_loss, val_metric))

        # Checkpoint whenever the summed training loss improves.
        if epoch_loss < train_sum_loss:
            train_sum_loss = epoch_loss
            _save_traced(model, img_channel, img_size, device, os.path.join(model_dir, str(epoch) + '-Loss.pt'))
        if device.type == 'cuda':
            torch.cuda.empty_cache()
参考:
【1】https://github.com/lukemelas/EfficientNet-PyTorch
【2】百度网盘链接:https://pan.baidu.com/s/1dVcTded9PD-8Oy270bAGhA 提取码:u60r
【3】Efficientnet_pytorch训练自己的数据集,并对数据进行分类
【4】【图像分类】——来来来,干了这碗EfficientNet实战(Pytorch)
运行环境(conda list 部分输出):
efficientnet 1.1.1 pypi_0 pypi
efficientnet-pytorch 0.6.3 pypi_0 pypi
torch 1.4.0 pypi_0 pypi
torchvision 0.5.0 pypi_0 pypi
python 3.7.6 h60c2a47_2
python-dateutil 2.8.1 py_0
python-jsonrpc-server 0.3.4 py_0
python-language-server 0.31.7 py37_0
python-libarchive-c 2.8 py37_13
opencv-python 4.4.0.46 pypi_0 pypi
openpyxl 3.0.3 py_0
openssl 1.1.1d he774522_4
页:
[1]