Halcom 发表于 2021-5-30 16:52:30

OCR深度学习模型CRNN

OCR深度学习模型CRNN:CNN + RNN + CTCLoss
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.models as models
import string
import numpy as np

class CRNN(nn.Module):
    def __init__(self,
               abc=string.digits,
               backend='resnet18',
               rnn_hidden_size=128,
               rnn_num_layers=2,
               rnn_dropout=0,
               seq_proj=):
      super(CRNN, self).__init__()

      self.abc = abc
      self.num_classes = len(self.abc)

      self.feature_extractor = getattr(models, backend)(pretrained=True)
      self.cnn = nn.Sequential(
            self.feature_extractor.conv1,
            self.feature_extractor.bn1,
            self.feature_extractor.relu,
            self.feature_extractor.maxpool,
            self.feature_extractor.layer1,
            self.feature_extractor.layer2,
            self.feature_extractor.layer3,
            self.feature_extractor.layer4
      )

      self.fully_conv = seq_proj == 0
      if not self.fully_conv:
            self.proj = nn.Conv2d(seq_proj, seq_proj, kernel_size=1)

      self.rnn_hidden_size = rnn_hidden_size
      self.rnn_num_layers = rnn_num_layers
      self.rnn = nn.GRU(self.get_block_size(self.cnn),
                        rnn_hidden_size, rnn_num_layers,
                        batch_first=False,
                        dropout=rnn_dropout, bidirectional=True)
      self.linear = nn.Linear(rnn_hidden_size * 2, self.num_classes + 1)
      self.softmax = nn.Softmax(dim=2)

    def forward(self, x, decode=False):
      hidden = self.init_hidden(x.size(0), next(self.parameters()).is_cuda)
      features = self.cnn(x)
      features = self.features_to_sequence(features)
      seq, hidden = self.rnn(features, hidden)
      seq = self.linear(seq)
      if not self.training:
            seq = self.softmax(seq)
            if decode:
                seq = self.decode(seq)
      return seq

    def init_hidden(self, batch_size, gpu=False):
      h0 = Variable(torch.zeros( self.rnn_num_layers * 2,
                                 batch_size,
                                 self.rnn_hidden_size))
      if gpu:
            h0 = h0.cuda()
      return h0

    def features_to_sequence(self, features):
      b, c, h, w = features.size()
      assert h == 1, "the height of out must be 1"
      if not self.fully_conv:
            features = features.permute(0, 3, 2, 1)
            features = self.proj(features)
            features = features.permute(1, 0, 2, 3)
      else:
            features = features.permute(3, 0, 2, 1)
      features = features.squeeze(2)
      return features

    def get_block_size(self, layer):
      return layer[-1][-1].bn2.weight.size()

    def pred_to_string(self, pred):
      seq = []
      for i in range(pred.shape):
            label = np.argmax(pred)
            seq.append(label - 1)
      out = []
      for i in range(len(seq)):
            if len(out) == 0:
                if seq != -1:
                  out.append(seq)
            else:
                if seq != -1 and seq != seq:
                  out.append(seq)
      out = ''.join(self.abc for i in out)
      return out

    def decode(self, pred):
      pred = pred.permute(1, 0, 2).cpu().data.numpy()
      seq = []
      for i in range(pred.shape):
            seq.append(self.pred_to_string(pred))
      return seq




页: [1]
查看完整版本: OCR深度学习模型CRNN