Hello Mat

 找回密码
 立即注册
查看: 5412|回复: 0

OCR深度学习模型CRNN+BiLSTM 模型2

[复制链接]

1294

主题

1520

帖子

110

金钱

管理员

Rank: 9Rank: 9Rank: 9

积分
22633
发表于 2021-5-30 16:55:25 | 显示全部楼层 |阅读模式
OCR深度学习模型CRNN+BiLSTM
  1. import torch
  2. import torch.nn as nn

  3. def conv3x3(in_planes, out_planes, stride=1):
  4.     """3x3 convolution with padding"""
  5.     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
  6.                    padding=1, bias=False)

  7. def conv1x1(in_planes, out_planes, stride=1):
  8.     """1x1 convolution"""
  9.     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

  10. class AsterBlock(nn.Module):
  11.     def __init__(self, inplanes, planes, stride=1, downsample=None):
  12.         super(AsterBlock, self).__init__()
  13.         self.conv1 = conv1x1(inplanes, planes, stride)
  14.         self.bn1 = nn.BatchNorm2d(planes)
  15.         self.relu = nn.ReLU(inplace=True)
  16.         self.conv2 = conv3x3(planes, planes)
  17.         self.bn2 = nn.BatchNorm2d(planes)
  18.         self.downsample = downsample
  19.         self.stride = stride

  20.     def forward(self, x):
  21.         residual = x
  22.         out = self.conv1(x)
  23.         out = self.bn1(out)
  24.         out = self.relu(out)
  25.         out = self.conv2(out)
  26.         out = self.bn2(out)

  27.         if self.downsample is not None:
  28.             residual = self.downsample(x)
  29.         out += residual
  30.         out = self.relu(out)
  31.         return out

  32. class ResNet_ASTER(nn.Module):
  33.     """For aster or crnn"""

  34.     def __init__(self, num_class, with_lstm=False):
  35.         super(ResNet_ASTER, self).__init__()
  36.         self.with_lstm = with_lstm

  37.         in_channels = 1
  38.         self.layer0 = nn.Sequential(
  39.             nn.Conv2d(in_channels, 32, kernel_size=(3, 3), stride=1, padding=1, bias=False),
  40.             nn.BatchNorm2d(32),
  41.             nn.ReLU(inplace=True))

  42.         self.inplanes = 32
  43.         self.layer1 = self._make_layer(32,  3, [2, 2]) # [16]
  44.         self.layer2 = self._make_layer(64,  4, [2, 2]) # [8]
  45.         self.layer3 = self._make_layer(128, 6, [2, 1]) # [4]
  46.         self.layer4 = self._make_layer(256, 6, [2, 1]) # [2]
  47.         self.layer5 = self._make_layer(512, 3, [2, 1]) # [1]

  48.         self.output_layer = nn.Linear(512,num_class)

  49.         if with_lstm:
  50.             self.rnn = nn.LSTM(512, 256, bidirectional=True, num_layers=2)
  51.             self.out_planes = 2 * 256
  52.         else:
  53.             self.out_planes = 512

  54.         for m in self.modules():
  55.             if isinstance(m, nn.Conv2d):
  56.                 nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
  57.             elif isinstance(m, nn.BatchNorm2d):
  58.                 nn.init.constant_(m.weight, 1)
  59.                 nn.init.constant_(m.bias, 0)
  60.         nn.init.normal_(self.output_layer.weight,std=0.01)
  61.         nn.init.constant_(self.output_layer.bias,0)

  62.     def _make_layer(self, planes, blocks, stride):
  63.         downsample = None
  64.         if stride != [1, 1] or self.inplanes != planes:
  65.             downsample = nn.Sequential(
  66.                 conv1x1(self.inplanes, planes, stride),
  67.                 nn.BatchNorm2d(planes))

  68.         layers = []
  69.         layers.append(AsterBlock(self.inplanes, planes, stride, downsample))
  70.         self.inplanes = planes
  71.         for _ in range(1, blocks):
  72.             layers.append(AsterBlock(self.inplanes, planes))
  73.         return nn.Sequential(*layers)

  74.     def forward(self, x):
  75.         x0 = self.layer0(x)
  76.         x1 = self.layer1(x0)
  77.         x2 = self.layer2(x1)
  78.         x3 = self.layer3(x2)
  79.         x4 = self.layer4(x3)
  80.         x5 = self.layer5(x4)

  81.         cnn_feat = x5.squeeze(2) # [N, c, w]
  82.         cnn_feat = cnn_feat.permute(2,0,1) #[T, b, input_size]
  83.         if self.with_lstm:
  84.             rnn_feat, _ = self.rnn(cnn_feat)
  85.             T,b,h = rnn_feat.size()
  86.             output = rnn_feat.view(T*b,h)
  87.             output = self.output_layer(output)
  88.             output = output.view(T,b,-1)
  89.             output = nn.functional.log_softmax(output,dim=2)
  90.             return output
  91.         else:
  92.             return cnn_feat


  93. def get_crnn(config):
  94.     assert config.MODEL.IMAGE_SIZE.H == 32, 'imgH has to be a multiple of 32'
  95.     return ResNet_ASTER(config.MODEL.NUM_CLASSES + 1, True)

  96. if __name__ == "__main__":
  97.     from torchsummary import summary

  98.     model = ResNet_ASTER(35, True)
  99.     model.eval()

  100.     summary(model, (3, 32, 288))
复制代码



算法QQ  3283892722
群智能算法链接http://halcom.cn/forum.php?mod=forumdisplay&fid=73
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

Python|Opencv|MATLAB|Halcom.cn ( 蜀ICP备16027072号 )

GMT+8, 2024-4-19 19:58 , Processed in 0.238121 second(s), 25 queries .

Powered by Discuz! X3.4

Copyright © 2001-2021, Tencent Cloud.

快速回复 返回顶部 返回列表