Hello Mat

 找回密码
 立即注册
查看: 6081|回复: 0

OCR深度学习模型CRNN+BiLSTM 模型1

[复制链接]

1323

主题

1551

帖子

0

金钱

管理员

Rank: 9

积分
22647
发表于 2021-5-30 16:54:40 | 显示全部楼层 |阅读模式
OCR深度学习模型CRNN+BiLSTM:
  1. import torch.nn as nn
  2. import torch.nn.functional as F

  class BidirectionalLSTM(nn.Module):
      """Bidirectional LSTM followed by a linear projection at every time step.

      Args:
          nIn: feature size of each input time step.
          nHidden: hidden size of each LSTM direction.
          nOut: feature size produced for each time step.
      """

      def __init__(self, nIn, nHidden, nOut):
          super().__init__()
          self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
          # Forward and backward states are concatenated, hence 2 * nHidden.
          self.embedding = nn.Linear(2 * nHidden, nOut)

      def forward(self, input):
          """Map a time-major (T, b, nIn) sequence to (T, b, nOut).

          The LSTM is built without ``batch_first``, so time is axis 0.
          """
          seq, _state = self.rnn(input)
          steps, batch, feat = seq.size()
          # Fold time and batch together so the projection sees a 2-D matrix.
          projected = self.embedding(seq.view(steps * batch, feat))
          return projected.view(steps, batch, -1)

  class CRNN(nn.Module):
      """CRNN text recognizer: a VGG-style conv stack followed by stacked BiLSTMs.

      For an input batch of shape (b, nc, H, W), the conv stack produces a
      (b, 512, H', W // 4 + 1) feature map. Along the height axis the network
      applies four floor-halving max-pools and then a 2x2 conv without
      padding. As a result, H' == 1 only for 32 <= H <= 47, and ``forward``
      rejects any other height. Each of the W // 4 + 1 feature columns
      becomes one time step for the recurrent part.

      Args:
          imgH: nominal input image height. It is only checked to be a
              multiple of 16 here, but 32 is the only multiple of 16 that
              passes the height check in ``forward``.
          nc: number of input image channels.
          nclass: number of output classes per time step.
          nh: hidden size of every BiLSTM (per direction). It is also the
              feature size passed between stacked BiLSTMs.
          n_rnn: number of stacked BiLSTM layers, must be >= 1. The default
              of 2 gives the classic two-layer layout
              (512 -> nh -> nclass).
          leakyRelu: use LeakyReLU(0.2) instead of ReLU after each conv.

      Raises:
          ValueError: if ``n_rnn`` < 1.
      """

      def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):
          super().__init__()
          # NOTE(review): this accepts 16, 48, 64, ... although only imgH == 32
          # yields a height-1 feature map in forward().
          assert imgH % 16 == 0, 'imgH has to be a multiple of 16'
          if n_rnn < 1:
              raise ValueError('n_rnn must be at least 1')

          ks = [3, 3, 3, 3, 3, 3, 2]  # conv kernel sizes
          ps = [1, 1, 1, 1, 1, 1, 0]  # conv paddings
          ss = [1, 1, 1, 1, 1, 1, 1]  # conv strides
          nm = [64, 128, 256, 256, 512, 512, 512]  # conv output channels

          cnn = nn.Sequential()

          def convRelu(i, batchNormalization=False):
              """Append conv{i}, optional batchnorm{i}, and relu{i} to ``cnn``."""
              nIn = nc if i == 0 else nm[i - 1]
              nOut = nm[i]
              cnn.add_module(f'conv{i}',
                             nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
              if batchNormalization:
                  cnn.add_module(f'batchnorm{i}', nn.BatchNorm2d(nOut))
              if leakyRelu:
                  cnn.add_module(f'relu{i}', nn.LeakyReLU(0.2, inplace=True))
              else:
                  cnn.add_module(f'relu{i}', nn.ReLU(True))

          # Shape comments are C x H x W for a 32 x W input (floor division).
          convRelu(0)
          cnn.add_module('pooling0', nn.MaxPool2d(2, 2))  # 64 x 16 x W//2
          convRelu(1)
          cnn.add_module('pooling1', nn.MaxPool2d(2, 2))  # 128 x 8 x W//4
          convRelu(2, True)
          convRelu(3)
          # Height stride 2, width stride 1 with width padding 1. This halves
          # the height while growing the width by 1, which keeps the
          # sequence long.
          cnn.add_module('pooling2',
                         nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256 x 4 x W//4+1
          convRelu(4, True)
          convRelu(5)
          cnn.add_module('pooling3',
                         nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512 x 2 x W//4+2
          convRelu(6, True)  # 2x2 conv, no padding -> 512 x 1 x W//4+1
          self.cnn = cnn

          # n_rnn BiLSTMs: nm[-1] -> nh -> ... -> nh -> nclass. With n_rnn=2
          # this matches the two-layer stack (same modules, same keys).
          rnn_layers = []
          for k in range(n_rnn):
              n_in = nm[-1] if k == 0 else nh
              n_out = nclass if k == n_rnn - 1 else nh
              rnn_layers.append(BidirectionalLSTM(n_in, nh, n_out))
          self.rnn = nn.Sequential(*rnn_layers)

      def forward(self, input):
          """Return per-time-step class log-probabilities.

          Args:
              input: image batch of shape (b, nc, H, W). See the class
                  docstring for which heights H are accepted.

          Returns:
              Tensor of shape (W // 4 + 1, b, nclass) holding ``log_softmax``
              over the class axis.
          """
          conv = self.cnn(input)
          h = conv.size(2)
          assert h == 1, "the height of conv must be 1"
          conv = conv.squeeze(2)  # (b, 512, T)
          conv = conv.permute(2, 0, 1)  # (T, b, 512): time-major for nn.LSTM
          return F.log_softmax(self.rnn(conv), dim=2)

  def weights_init(m):
      """Re-initialize conv and batch-norm layers in place (use with ``Module.apply``).

      * Class name contains ``'Conv'``: weight ~ N(0, 0.02). The bias is left
        as is.
      * Class name contains ``'BatchNorm'``: weight ~ N(1, 0.02), bias = 0.
      * Any other module is left untouched.

      Parameters that are ``None`` are skipped rather than raising
      AttributeError. An example is a BatchNorm built with
      ``affine=False``, whose ``weight`` and ``bias`` are None.
      """
      classname = type(m).__name__
      if 'Conv' in classname:
          if getattr(m, 'weight', None) is not None:
              nn.init.normal_(m.weight, 0.0, 0.02)
      elif 'BatchNorm' in classname:
          if getattr(m, 'weight', None) is not None:
              nn.init.normal_(m.weight, 1.0, 0.02)
          if getattr(m, 'bias', None) is not None:
              nn.init.zeros_(m.bias)

  def get_crnn(config):
      """Build a single-channel CRNN from ``config`` and apply ``weights_init``.

      Reads ``config.MODEL.IMAGE_SIZE.H``, ``config.MODEL.NUM_CLASSES`` and
      ``config.MODEL.NUM_HIDDEN``. The network is created with
      ``NUM_CLASSES + 1`` output classes. The extra class is presumably the
      CTC blank; confirm this against the training code.
      """
      model_cfg = config.MODEL
      net = CRNN(model_cfg.IMAGE_SIZE.H, 1, model_cfg.NUM_CLASSES + 1,
                 model_cfg.NUM_HIDDEN)
      net.apply(weights_init)
      return net
复制代码



算法QQ  3283892722
群智能算法链接http://halcom.cn/forum.php?mod=forumdisplay&fid=73
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

Python|Opencv|MATLAB|Halcom.cn ( 蜀ICP备16027072号 )

GMT+8, 2024-11-22 23:14 , Processed in 0.232720 second(s), 23 queries .

Powered by Discuz! X3.4

Copyright © 2001-2021, Tencent Cloud.

快速回复 返回顶部 返回列表