GTX_AI posted on 2020-11-18 21:54:47

PyTorch image label classification with EfficientNet-PyTorch

PyTorch image label classification with EfficientNet-PyTorch:
# -*- coding: utf-8 -*-
"""
Created on Sun Feb 5 14:56:26 2023
@author: halcom
"""
import sys
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import numpy as np
import torch
import torch.utils.data as data
from torch.utils.data import DataLoader
from torch import autograd, optim
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
from torchvision.transforms import transforms
from functools import partial
from xml.dom.minidom import parse
import cv2
import time
import shutil
import copy
from efficientnet_pytorch import EfficientNet
#
# Convert strings to numeric labels; indices run from 1 to N
class strLabelConverter(object):
    """Convert between str and label.
    NOTE:
      Insert `blank` to the alphabet for CTC.
    Args:
      alphabet (str): set of the possible characters.
      ignore_case (bool, default=True): whether or not to ignore all of the case.
    """

    def __init__(self, alphabet, ignore_case=False):
      self._ignore_case = ignore_case
      if self._ignore_case:
            alphabet = alphabet.lower()
      self.alphabet = alphabet + '-'  # for `-1` index

      self.dict = {}
      for i, char in enumerate(alphabet):
            # NOTE: 0 is reserved for 'blank' required by wrap_ctc
            self.dict[char] = i + 1

    def encode(self, text):
      """Support batch or single str.
      Args:
            text (str or list of str): texts to convert.
      Returns:
            torch.IntTensor : encoded texts.
            torch.IntTensor : length of each text.
      """

      length = []
      result = []
      decode_flag = True if type(text)==bytes else False

      for item in text:

            if decode_flag:
                item = item.decode('utf-8','strict')
            length.append(len(item))
            for char in item:
                index = self.dict[char]
                result.append(index)
      text = result
      return (torch.IntTensor(text), torch.IntTensor(length))

    def decode(self, t, length, raw=False):
      """Decode encoded texts back into strs.
      Args:
            torch.IntTensor : encoded texts.
            torch.IntTensor : length of each text.
      Raises:
            AssertionError: when the texts and its length does not match.
      Returns:
            text (str or list of str): texts to convert.
      """
      if length.numel() == 1:
            length = length[0]
            assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(), length)
            if raw:
                return ''.join([self.alphabet[i - 1] for i in t])
            else:
                char_list = []
                for i in range(length):
                  if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
                        char_list.append(self.alphabet[t[i] - 1])
                return ''.join(char_list)
      else:
            # batch mode
            assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(t.numel(), length.sum())
            texts = []
            index = 0
            for i in range(length.numel()):
                l = length[i]
                texts.append(
                  self.decode(
                        t[index:index + l], torch.IntTensor([l]), raw=raw))
                index += l
            return texts
#
# test helper for strLabelConverter
def AlphaConverter(AlphaBet):
    AlphaBet = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ'
    converter = strLabelConverter(AlphaBet)
    labels = '0134ABC' # NOTE: for CrossEntropyLoss the labels must run from 0 to N-1
    text_train, length_train = converter.encode(labels)
    # calc correct ratio
    # assumed example values: predictions 1..N, so that decode() reproduces AlphaBet
    preds = torch.tensor([i + 1 for i in range(len(AlphaBet))])
    preds_size = torch.IntTensor([len(AlphaBet)])
    sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
    # sim_preds should reproduce AlphaBet itself
    return converter
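# A minimal round-trip sketch for strLabelConverter (the helper name below is illustrative,
# not from the original post). Note that encode() expects a list of strings; passing a bare
# str makes it treat every character as a separate item with length 1.
def _converter_roundtrip_demo():
    converter = strLabelConverter('0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ')
    text, length = converter.encode(['0134ABC'])
    # text   -> IntTensor([1, 2, 4, 5, 11, 12, 13]): '0' maps to 1, 'A' to 11, ...
    # length -> IntTensor([7])
    decoded = converter.decode(text, length, raw=False)
    assert decoded == '0134ABC'
    return decoded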
#
def make_dataset_train_txt(root, txt_file, train_image_format):
    imgs=[]
    with open(txt_file) as f:
      indexs=f.readlines()
    for i in range( len(indexs) ):
#      index = str( indexs[i] ).split('.')[0]
      index = str( indexs[i] ).split('\n')[0]
      # img=os.path.join(root, "train", index+".png")
      img=os.path.join(root, "Image", index + train_image_format)
      mask=os.path.join(root, "Mask", index + ".png")
      imgs.append((img,mask))
    return imgs
def make_dataset_val_txt(root, txt_file, train_image_format):
    imgs=[]
    with open(txt_file) as f:
      indexs=f.readlines()
    for i in range( len(indexs) ):
#      index = str( indexs[i] ).split('.')[0]
      index = str( indexs[i] ).split('\n')[0]
      # img=os.path.join(root, "val", index+".png")
      img=os.path.join(root, "Image", index + train_image_format)
      mask=os.path.join(root, "Mask", index + ".png")
      imgs.append((img,mask))
    return imgs
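# Assumed data layout (not stated explicitly in the post): Mask_train.txt and
# Mask_validation.txt list one image basename per line; each name is resolved to
# <root>/Image/<name><train_image_format> with a matching <root>/Mask/<name>.png.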
class train_LiverDataset(data.Dataset):
    def __init__(self, root, txt_file, classNum = 1, transform = None, target_transform = None, train_image_format = ".jpg", img_channel = 3):
      imgs = make_dataset_train_txt(root, txt_file, train_image_format)
      self.imgs = imgs
      self.transform = transform
      self.target_transform = target_transform
      self.classNum = classNum
      self.train_image_format = train_image_format
      self.img_channel = img_channel
    def __getitem__(self, index):
      x_path, y_path = self.imgs[index]
      if(self.img_channel == 3):
            image = cv2.imread(x_path, 1)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
      elif(self.img_channel == 1):
            image = cv2.imread(x_path, 0)
      else:
            return
      mask = cv2.imread(y_path, 0)
      # extract certain classes from mask; the line building `masks` was lost, so a
      # per-class binary split over class values 0..classNum-1 is assumed here
      masks = [(mask == v) for v in range(self.classNum)]
      mask = np.stack(masks, axis=-1).astype('float')
      
      if self.transform is not None:
            image = self.transform(image)
      if self.target_transform is not None:
            mask = self.target_transform(mask)
      return image, mask
    def __len__(self):
      return len(self.imgs)
class val_LiverDataset(data.Dataset):
    def __init__(self, root, txt_file, classNum = 1, transform = None, target_transform = None, train_image_format = ".jpg", img_channel = 3):
      imgs = make_dataset_val_txt(root, txt_file, train_image_format)
      self.imgs = imgs
      self.transform = transform
      self.target_transform = target_transform
      self.classNum = classNum
      self.train_image_format = train_image_format
      self.img_channel = img_channel
    def __getitem__(self, index):
      x_path, y_path = self.imgs[index]
      if(self.img_channel == 3):
            image = cv2.imread(x_path, 1)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
      elif(self.img_channel == 1):
            image = cv2.imread(x_path, 0)
      else:
            return
      mask = cv2.imread(y_path, 0)
      # extract certain classes from mask; the line building `masks` was lost, so a
      # per-class binary split over class values 0..classNum-1 is assumed here
      masks = [(mask == v) for v in range(self.classNum)]
      mask = np.stack(masks, axis=-1).astype('float')
      
      if self.transform is not None:
            image = self.transform(image)
      if self.target_transform is not None:
            mask = self.target_transform(mask)
      return image, mask
    def __len__(self):
      return len(self.imgs)
# training model
def trainAI(train_project_path, xmlfile, img_root_path, mask_root_path, train_txt_file_path, train_loss_txt_file, val_loss_txt_file, trainbatch_loss_txt_file, train_image_format, initial_model, GPU_index, img_channel, img_size, batch_size, num_workers, num_epochs, Seg_p, lr_init, moment_init, classNum, ENCODER, ENCODER_WEIGHTS, AlphaBet):
    # use CUDA if it is available
    device = torch.device("cuda:" + str(GPU_index) if torch.cuda.is_available() else "cpu")
    classNum = len(AlphaBet)
    model = EfficientNet.from_pretrained('efficientnet-b0', in_channels=img_channel, num_classes= len(AlphaBet)).to(device)
    criterion = torch.nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr_init)
    if img_channel == 1:
      # ToTensor + Normalize; the per-channel mean/std of 0.5 are assumed values
      x_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
    elif img_channel == 3:
      x_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
      # alternative (assumed ImageNet stats): transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    else:
      return
    y_transforms = transforms.ToTensor()
    train_liver_dataset = train_LiverDataset(train_txt_file_path, train_txt_file_path + "/Mask_train.txt", classNum = classNum, transform = x_transforms, target_transform = y_transforms, train_image_format = train_image_format, img_channel = img_channel)
    train_dataloaders = DataLoader(train_liver_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers)
    val_liver_dataset = val_LiverDataset(train_txt_file_path, train_txt_file_path + "/Mask_validation.txt", classNum = classNum, transform = x_transforms, target_transform = y_transforms, train_image_format = train_image_format, img_channel = img_channel)
    val_dataloaders = DataLoader(val_liver_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers)
    if initial_model != "None":
      model = torch.jit.load(initial_model)
    model.cuda(GPU_index)
    model.train()
    train_sum_loss = np.inf
    val_sum_loss = np.inf
    for epoch in range(num_epochs):
      print('Epoch {}/{}'.format(epoch, num_epochs - 1))
      dt_size = len(train_dataloaders.dataset)
      step = 0
      epoch_loss = 0
      model.train()
      for x, y in train_dataloaders:
            step += 1
            inputs = x.to(device)
            labels = y.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            # epoch_metric: assumed to be top-1 batch accuracy (the original script printed it without defining it)
            epoch_metric = (outputs.argmax(dim=1) == labels).float().mean().item()
            print("%d/%d,train_loss:%0.3f,train_metric:%0.3f" % (step, (dt_size - 1) // train_dataloaders.batch_size + 1, loss.item(), epoch_metric))
            del x, y, inputs, labels, outputs
            torch.cuda.empty_cache()
      # print("epoch %d loss:%0.3f" % (epoch, epoch_loss))
      if(epoch_loss < train_sum_loss):
            train_sum_loss = epoch_loss
            model.eval()
            example = torch.rand(1, img_channel, img_size, img_size).type(torch.FloatTensor).cuda(GPU_index)
            pt_model = torch.jit.trace(model, example)
            pt_model.save( os.path.join(train_project_path, "Model", str(epoch) +'-Loss.pt') )
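Once trainAI has saved a traced model, it can be loaded back with torch.jit.load for inference. Below is a minimal sketch; the helper name, the resize step, the file names in the example call, and the 0.5 normalization (mirroring the training transforms above) are assumptions, not part of the original script:

import cv2
import torch
from torchvision.transforms import transforms

def predict_one(pt_path, image_path, alphabet, img_size=224, gpu_index=0):
    device = torch.device("cuda:" + str(gpu_index) if torch.cuda.is_available() else "cpu")
    model = torch.jit.load(pt_path, map_location=device)
    model.eval()
    # preprocessing mirrors the training pipeline: BGR->RGB, then ToTensor + Normalize(0.5);
    # the resize is an added assumption so arbitrary test images fit the traced input size
    image = cv2.imread(image_path, 1)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (img_size, img_size))
    preprocess = transforms.Compose([transforms.ToTensor(),
                                     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
    x = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(x)
    class_index = int(logits.argmax(dim=1).item())   # CrossEntropyLoss labels run 0..N-1
    return alphabet[class_index]                     # map the index back to its character

# example call (file names are placeholders):
# predict_one("Model/10-Loss.pt", "test.jpg", '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ')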
References:
【1】https://github.com/lukemelas/EfficientNet-PyTorch
【2】Baidu Netdisk link: https://pan.baidu.com/s/1dVcTded9PD-8Oy270bAGhA (extraction code: u60r)
【3】Efficientnet_pytorch: training on your own dataset and classifying the data
【4】[Image Classification] A hands-on EfficientNet walkthrough (PyTorch)


Halcom posted on 2023-2-5 20:05:44


efficientnet            1.1.1                  pypi_0    pypi
efficientnet-pytorch      0.6.3                  pypi_0    pypi

torch                     1.4.0                  pypi_0    pypi
torchvision               0.5.0                  pypi_0    pypi

python                  3.7.6                h60c2a47_2
python-dateutil         2.8.1                      py_0
python-jsonrpc-server   0.3.4                      py_0
python-language-server    0.31.7                   py37_0
python-libarchive-c       2.8                     py37_13

opencv-python             4.4.0.46               pypi_0    pypi
openpyxl                  3.0.3                      py_0
openssl                   1.1.1d               he774522_4
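If it helps to confirm that a local setup matches these versions, a quick check along these lines should work (package names as imported in the script):

# quick sanity check against the versions listed above
import torch, torchvision, cv2
import pkg_resources

print(torch.__version__)         # expected: 1.4.0
print(torchvision.__version__)   # expected: 0.5.0
print(cv2.__version__)           # expected: 4.4.0
print(pkg_resources.get_distribution("efficientnet-pytorch").version)   # expected: 0.6.3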