Hello Mat

 找回密码
 立即注册
查看: 5404|回复: 0

OCR深度学习模型CRNN

[复制链接]

1323

主题

1551

帖子

0

金钱

管理员

Rank: 9Rank: 9Rank: 9

积分
22647
发表于 2021-5-30 16:52:30 | 显示全部楼层 |阅读模式
OCR深度学习模型CRNN:CNN + RNN + CTCLoss
  1. import torch
  2. import torch.nn as nn
  3. from torch.autograd import Variable
  4. import torchvision.models as models
  5. import string
  6. import numpy as np

  7. class CRNN(nn.Module):
  8.     def __init__(self,
  9.                  abc=string.digits,
  10.                  backend='resnet18',
  11.                  rnn_hidden_size=128,
  12.                  rnn_num_layers=2,
  13.                  rnn_dropout=0,
  14.                  seq_proj=[0, 0]):
  15.         super(CRNN, self).__init__()

  16.         self.abc = abc
  17.         self.num_classes = len(self.abc)

  18.         self.feature_extractor = getattr(models, backend)(pretrained=True)
  19.         self.cnn = nn.Sequential(
  20.             self.feature_extractor.conv1,
  21.             self.feature_extractor.bn1,
  22.             self.feature_extractor.relu,
  23.             self.feature_extractor.maxpool,
  24.             self.feature_extractor.layer1,
  25.             self.feature_extractor.layer2,
  26.             self.feature_extractor.layer3,
  27.             self.feature_extractor.layer4
  28.         )

  29.         self.fully_conv = seq_proj[0] == 0
  30.         if not self.fully_conv:
  31.             self.proj = nn.Conv2d(seq_proj[0], seq_proj[1], kernel_size=1)

  32.         self.rnn_hidden_size = rnn_hidden_size
  33.         self.rnn_num_layers = rnn_num_layers
  34.         self.rnn = nn.GRU(self.get_block_size(self.cnn),
  35.                           rnn_hidden_size, rnn_num_layers,
  36.                           batch_first=False,
  37.                           dropout=rnn_dropout, bidirectional=True)
  38.         self.linear = nn.Linear(rnn_hidden_size * 2, self.num_classes + 1)
  39.         self.softmax = nn.Softmax(dim=2)

  40.     def forward(self, x, decode=False):
  41.         hidden = self.init_hidden(x.size(0), next(self.parameters()).is_cuda)
  42.         features = self.cnn(x)
  43.         features = self.features_to_sequence(features)
  44.         seq, hidden = self.rnn(features, hidden)
  45.         seq = self.linear(seq)
  46.         if not self.training:
  47.             seq = self.softmax(seq)
  48.             if decode:
  49.                 seq = self.decode(seq)
  50.         return seq

  51.     def init_hidden(self, batch_size, gpu=False):
  52.         h0 = Variable(torch.zeros( self.rnn_num_layers * 2,
  53.                                    batch_size,
  54.                                    self.rnn_hidden_size))
  55.         if gpu:
  56.             h0 = h0.cuda()
  57.         return h0

  58.     def features_to_sequence(self, features):
  59.         b, c, h, w = features.size()
  60.         assert h == 1, "the height of out must be 1"
  61.         if not self.fully_conv:
  62.             features = features.permute(0, 3, 2, 1)
  63.             features = self.proj(features)
  64.             features = features.permute(1, 0, 2, 3)
  65.         else:
  66.             features = features.permute(3, 0, 2, 1)
  67.         features = features.squeeze(2)
  68.         return features

  69.     def get_block_size(self, layer):
  70.         return layer[-1][-1].bn2.weight.size()[0]

  71.     def pred_to_string(self, pred):
  72.         seq = []
  73.         for i in range(pred.shape[0]):
  74.             label = np.argmax(pred[i])
  75.             seq.append(label - 1)
  76.         out = []
  77.         for i in range(len(seq)):
  78.             if len(out) == 0:
  79.                 if seq[i] != -1:
  80.                     out.append(seq[i])
  81.             else:
  82.                 if seq[i] != -1 and seq[i] != seq[i - 1]:
  83.                     out.append(seq[i])
  84.         out = ''.join(self.abc[i] for i in out)
  85.         return out

  86.     def decode(self, pred):
  87.         pred = pred.permute(1, 0, 2).cpu().data.numpy()
  88.         seq = []
  89.         for i in range(pred.shape[0]):
  90.             seq.append(self.pred_to_string(pred[i]))
  91.         return seq
复制代码




算法QQ  3283892722
群智能算法链接http://halcom.cn/forum.php?mod=forumdisplay&fid=73
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

Python|Opencv|MATLAB|Halcom.cn ( 蜀ICP备16027072号 )

GMT+8, 2024-11-22 23:50 , Processed in 0.222770 second(s), 24 queries .

Powered by Discuz! X3.4

Copyright © 2001-2021, Tencent Cloud.

快速回复 返回顶部 返回列表