请选择 进入手机版 | 继续访问电脑版

Hello Mat

 找回密码
 立即注册
查看: 4714|回复: 4

Pytorch图像标签分类之EfficientNet-PyTorch

[复制链接]

84

主题

115

帖子

731

金钱

管理员

Rank: 9Rank: 9Rank: 9

积分
1467
发表于 2020-11-18 21:54:47 | 显示全部楼层 |阅读模式
Pytorch图像标签分类之EfficientNet-PyTorch:
  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Sun Feb  5 14:56:26 2023
  4. @author: halcom
  5. """
  6. import sys
  7. import os
  8. # os.environ['CUDA_VISIBLE_DEVICES'] = '0'
  9. import numpy as np
  10. import torch
  11. import torch.utils.data as data
  12. from torch.utils.data import DataLoader
  13. from torch import autograd, optim
  14. import torch.nn as nn
  15. import torch.nn.functional as F
  16. from torch.nn.modules.loss import _Loss
  17. from torchvision.transforms import transforms
  18. from functools import partial
  19. from xml.dom.minidom import parse
  20. import cv2
  21. import time
  22. import shutil
  23. import copy
  24. from efficientnet_pytorch import EfficientNet
  25. #
  26. # 将字符串转化为数字标签,序号从1开始,1-N
  27. class strLabelConverter(object):
  28.     """Convert between str and label.
  29.     NOTE:
  30.         Insert `blank` to the alphabet for CTC.
  31.     Args:
  32.         alphabet (str): set of the possible characters.
  33.         ignore_case (bool, default=True): whether or not to ignore all of the case.
  34.     """

  35.     def __init__(self, alphabet, ignore_case=False):
  36.         self._ignore_case = ignore_case
  37.         if self._ignore_case:
  38.             alphabet = alphabet.lower()
  39.         self.alphabet = alphabet + '-'  # for `-1` index

  40.         self.dict = {}
  41.         for i, char in enumerate(alphabet):
  42.             # NOTE: 0 is reserved for 'blank' required by wrap_ctc
  43.             self.dict[char] = i + 1

  44.     def encode(self, text):
  45.         """Support batch or single str.
  46.         Args:
  47.             text (str or list of str): texts to convert.
  48.         Returns:
  49.             torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
  50.             torch.IntTensor [n]: length of each text.
  51.         """

  52.         length = []
  53.         result = []
  54.         decode_flag = True if type(text[0])==bytes else False

  55.         for item in text:

  56.             if decode_flag:
  57.                 item = item.decode('utf-8','strict')
  58.             length.append(len(item))
  59.             for char in item:
  60.                 index = self.dict[char]
  61.                 result.append(index)
  62.         text = result
  63.         return (torch.IntTensor(text), torch.IntTensor(length))

  64.     def decode(self, t, length, raw=False):
  65.         """Decode encoded texts back into strs.
  66.         Args:
  67.             torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
  68.             torch.IntTensor [n]: length of each text.
  69.         Raises:
  70.             AssertionError: when the texts and its length does not match.
  71.         Returns:
  72.             text (str or list of str): texts to convert.
  73.         """
  74.         if length.numel() == 1:
  75.             length = length[0]
  76.             assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(), length)
  77.             if raw:
  78.                 return ''.join([self.alphabet[i - 1] for i in t])
  79.             else:
  80.                 char_list = []
  81.                 for i in range(length):
  82.                     if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
  83.                         char_list.append(self.alphabet[t[i] - 1])
  84.                 return ''.join(char_list)
  85.         else:
  86.             # batch mode
  87.             assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(t.numel(), length.sum())
  88.             texts = []
  89.             index = 0
  90.             for i in range(length.numel()):
  91.                 l = length[i]
  92.                 texts.append(
  93.                     self.decode(
  94.                         t[index:index + l], torch.IntTensor([l]), raw=raw))
  95.                 index += l
  96.             return texts
  97. #
  98. # 测试函数strLabelConverter
  99. def AlphaConverter(AlphaBet):
  100.     AlphaBet = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ'
  101.     converter = strLabelConverter(AlphaBet)
  102.     labels = '0134ABC' # 注意 CrossEntropyLoss, label需要从0到N-1
  103.     text_train, length_train = converter.encode(labels)
  104.     # calc correct ratio
  105.     preds = torch.tensor([1, 3, 4, 11])
  106.     preds_size = torch.IntTensor([preds.size(0)])
  107.     sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
  108.     # sim_preds对应的就是AlphaBet
  109.     return converter
  110. #
  111. def make_dataset_train_txt(root, txt_file, train_image_format):
  112.     imgs=[]
  113.     with open(txt_file) as f:
  114.         indexs=f.readlines()
  115.     for i in range( len(indexs) ):
  116. #        index = str( indexs[i] ).split('.')[0]
  117.         index = str( indexs[i] ).split('\n')[0]
  118.         # img=os.path.join(root, "train", index+".png")
  119.         img=os.path.join(root, "Image", index + train_image_format)
  120.         mask=os.path.join(root, "Mask", index + ".png")
  121.         imgs.append((img,mask))
  122.     return imgs
  123. def make_dataset_val_txt(root, txt_file, train_image_format):
  124.     imgs=[]
  125.     with open(txt_file) as f:
  126.         indexs=f.readlines()
  127.     for i in range( len(indexs) ):
  128. #        index = str( indexs[i] ).split('.')[0]
  129.         index = str( indexs[i] ).split('\n')[0]
  130.         # img=os.path.join(root, "val", index+".png")
  131.         img=os.path.join(root, "Image", index + train_image_format)
  132.         mask=os.path.join(root, "Mask", index + ".png")
  133.         imgs.append((img,mask))
  134.     return imgs
  135. class train_LiverDataset(data.Dataset):
  136.     def __init__(self, root, txt_file, classNum = 1, transform = None, target_transform = None, train_image_format = ".jpg", img_channel = 3):
  137.         imgs = make_dataset_train_txt(root, txt_file, train_image_format)
  138.         self.imgs = imgs
  139.         self.transform = transform
  140.         self.target_transform = target_transform
  141.         self.classNum = classNum
  142.         self.train_image_format = train_image_format
  143.         self.img_channel = img_channel
  144.     def __getitem__(self, index):
  145.         x_path, y_path = self.imgs[index]
  146.         if(self.img_channel == 3):
  147.             image = cv2.imread(x_path, 1)
  148.             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  149.         elif(self.img_channel == 1):
  150.             image = cv2.imread(x_path, 0)
  151.         else:
  152.             return
  153.         mask = cv2.imread(y_path, 0)
  154.         # extract certain classes from mask

  155.         mask = np.stack(masks, axis=-1).astype('float')
  156.         
  157.         if self.transform is not None:
  158.             image = self.transform(image)
  159.         if self.target_transform is not None:
  160.             mask = self.target_transform(mask)
  161.         return image, mask
  162.     def __len__(self):
  163.         return len(self.imgs)
  164. class val_LiverDataset(data.Dataset):
  165.     def __init__(self, root, txt_file, classNum = 1, transform = None, target_transform = None, train_image_format = ".jpg", img_channel = 3):
  166.         imgs = make_dataset_val_txt(root, txt_file, train_image_format)
  167.         self.imgs = imgs
  168.         self.transform = transform
  169.         self.target_transform = target_transform
  170.         self.classNum = classNum
  171.         self.train_image_format = train_image_format
  172.         self.img_channel = img_channel
  173.     def __getitem__(self, index):
  174.         x_path, y_path = self.imgs[index]
  175.         if(self.img_channel == 3):
  176.             image = cv2.imread(x_path, 1)
  177.             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  178.         elif(self.img_channel == 1):
  179.             image = cv2.imread(x_path, 0)
  180.         else:
  181.             return
  182.         mask = cv2.imread(y_path, 0)
  183.         # extract certain classes from mask

  184.         mask = np.stack(masks, axis=-1).astype('float')
  185.         
  186.         if self.transform is not None:
  187.             image = self.transform(image)
  188.         if self.target_transform is not None:
  189.             mask = self.target_transform(mask)
  190.         return image, mask
  191.     def __len__(self):
  192.         return len(self.imgs)
  193. # training model
  194. def trainAI(train_project_path, xmlfile, img_root_path, mask_root_path, train_txt_file_path, train_loss_txt_file, val_loss_txt_file, trainbatch_loss_txt_file, train_image_format, initial_model, GPU_index, img_channel, img_size, batch_size, num_workers, num_epochs, Seg_p, lr_init, moment_init, classNum, ENCODER, ENCODER_WEIGHTS, AlphaBet):
  195.     # 是否使用cuda
  196.     device = torch.device("cuda:" + str(GPU_index) if torch.cuda.is_available() else "cpu")
  197.     classNum = len(AlphaBet)
  198.     model = EfficientNet.from_pretrained('efficientnet-b0', in_channels=img_channel, num_classes= len(AlphaBet)).to(device)
  199.     criterion = torch.nn.CrossEntropyLoss().to(device)
  200.     optimizer = torch.optim.Adam([dict(params=model.parameters(), lr = lr_init)])
  201.     if img_channel == 1:
  202.         x_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]) ])
  203.     elif img_channel == 3:
  204.         x_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) ])
  205.         # x_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ])
  206.     else:
  207.         return
  208.     y_transforms = transforms.ToTensor()
  209.     train_liver_dataset = train_LiverDataset(train_txt_file_path, train_txt_file_path + "/Mask_train.txt", classNum = classNum, transform = x_transforms, target_transform = y_transforms, train_image_format = train_image_format, img_channel = img_channel)
  210.     train_dataloaders = DataLoader(train_liver_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers)
  211.     val_liver_dataset = val_LiverDataset(train_txt_file_path, train_txt_file_path + "/Mask_validation.txt", classNum = classNum, transform = x_transforms, target_transform = y_transforms, train_image_format = train_image_format, img_channel = img_channel)
  212.     val_dataloaders = DataLoader(val_liver_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers)
  213.     if initial_model != "None":
  214.         model = torch.jit.load(initial_model)
  215.     model.cuda(GPU_index)
  216.     model.train()
  217.     train_sum_loss = np.inf
  218.     val_sum_loss = np.inf
  219.     for epoch in range(num_epochs):
  220.         print('Epoch {}/{}'.format(epoch, num_epochs - 1))
  221.         dt_size = len(train_dataloaders.dataset)
  222.         step = 0
  223.         epoch_loss = 0
  224.         model.train()
  225.         for x, y in train_dataloaders:
  226.             step += 1
  227.             inputs = x.to(device)
  228.             labels = y.to(device)
  229.             optimizer.zero_grad()
  230.             outputs = model(inputs)
  231.             loss_dict = criterion(outputs, labels)
  232.             loss.backward()
  233.             optimizer.step()
  234.             epoch_loss += loss.item()

  235.             print("%d/%d,train_loss:%0.3f,train_metric:%0.3f" % (step, (dt_size - 1) // train_dataloaders.batch_size + 1, loss.item(), epoch_metric))
  236.             del x, y, inputs, labels, outputs
  237.             torch.cuda.empty_cache()
  238.         # print("epoch %d loss:%0.3f" % (epoch, epoch_loss))
  239.         if(epoch_loss < train_sum_loss):
  240.             train_sum_loss = epoch_loss
  241.             model.eval()
  242.             example = torch.rand(1, img_channel, img_size, img_size).type(torch.FloatTensor).cuda(GPU_index)
  243.             pt_model = torch.jit.trace(model, example)
  244.             pt_model.save( os.path.join(train_project_path, "Model", str(epoch) +'-Loss.pt') )
复制代码

参考:
【1】https://github.com/lukemelas/EfficientNet-PyTorch
【2】百度网盘链接:https://pan.baidu.com/s/1dVcTded9PD-8Oy270bAGhA 提取码:u60r
【3】Efficientnet_pytorch训练自己的数据集,并对数据进行分类
【4】【图像分类】——来来来,干了这碗EfficientNet实战(Pytorch)


回复

使用道具 举报

1294

主题

1520

帖子

110

金钱

管理员

Rank: 9Rank: 9Rank: 9

积分
22633
发表于 2023-2-5 20:05:44 | 显示全部楼层

  1. efficientnet              1.1.1                    pypi_0    pypi
  2. efficientnet-pytorch      0.6.3                    pypi_0    pypi

  3. torch                     1.4.0                    pypi_0    pypi
  4. torchvision               0.5.0                    pypi_0    pypi

  5. python                    3.7.6                h60c2a47_2
  6. python-dateutil           2.8.1                      py_0
  7. python-jsonrpc-server     0.3.4                      py_0
  8. python-language-server    0.31.7                   py37_0
  9. python-libarchive-c       2.8                     py37_13

  10. opencv-python             4.4.0.46                 pypi_0    pypi
  11. openpyxl                  3.0.3                      py_0
  12. openssl                   1.1.1d               he774522_4
复制代码

算法QQ  3283892722
群智能算法链接http://halcom.cn/forum.php?mod=forumdisplay&fid=73
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

Python|Opencv|MATLAB|Halcom.cn ( 蜀ICP备16027072号 )

GMT+8, 2024-4-19 05:46 , Processed in 0.230903 second(s), 24 queries .

Powered by Discuz! X3.4

Copyright © 2001-2021, Tencent Cloud.

快速回复 返回顶部 返回列表