# OCR深度学习模型CRNN
# OCR深度学习模型CRNN:CNN + RNN + CTCLoss
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.models as models
import string
import numpy as np
class CRNN(nn.Module):
    """Convolutional-recurrent network for reading text from images.

    The model chains a torchvision ResNet-style backbone (``conv1`` through
    ``layer4``), an optional 1x1 convolution applied along the width axis,
    a bidirectional GRU and a linear classifier. The classifier emits
    ``len(abc) + 1`` scores per time step; class index 0 is treated as the
    blank symbol by :meth:`pred_to_string`, and class ``k`` (k >= 1) maps to
    ``abc[k - 1]``.
    """

    def __init__(self,
                 abc=string.digits,
                 backend='resnet18',
                 rnn_hidden_size=128,
                 rnn_num_layers=2,
                 rnn_dropout=0,
                 seq_proj=(0, 0)):
        """Build the backbone, the recurrent head and the classifier.

        Args:
            abc: Alphabet of recognisable characters.
            backend: Name of a constructor in ``torchvision.models``; it is
                called with ``pretrained=True``.
            rnn_hidden_size: Hidden size of each GRU direction.
            rnn_num_layers: Number of stacked GRU layers.
            rnn_dropout: Dropout passed to ``nn.GRU``.
            seq_proj: ``(in_width, out_width)`` of the optional 1x1
                convolution over the width axis. A first element of 0
                disables the projection. A single int ``n`` is accepted as
                shorthand for ``(n, n)``.
        """
        super(CRNN, self).__init__()
        self.abc = abc
        self.num_classes = len(self.abc)

        self.feature_extractor = getattr(models, backend)(pretrained=True)
        self.cnn = nn.Sequential(
            self.feature_extractor.conv1,
            self.feature_extractor.bn1,
            self.feature_extractor.relu,
            self.feature_extractor.maxpool,
            self.feature_extractor.layer1,
            self.feature_extractor.layer2,
            self.feature_extractor.layer3,
            self.feature_extractor.layer4
        )

        # The pasted source lost its subscripts here; a pair is restored and
        # a bare int is still accepted for convenience.
        if isinstance(seq_proj, int):
            seq_proj = (seq_proj, seq_proj)
        proj_in, proj_out = seq_proj
        self.fully_conv = proj_in == 0
        if not self.fully_conv:
            self.proj = nn.Conv2d(proj_in, proj_out, kernel_size=1)

        self.rnn_hidden_size = rnn_hidden_size
        self.rnn_num_layers = rnn_num_layers
        self.rnn = nn.GRU(self.get_block_size(self.cnn),
                          rnn_hidden_size, rnn_num_layers,
                          batch_first=False,
                          dropout=rnn_dropout, bidirectional=True)
        # Both GRU directions are concatenated, hence hidden * 2 inputs; the
        # extra output class is the blank symbol.
        self.linear = nn.Linear(rnn_hidden_size * 2, self.num_classes + 1)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x, decode=False):
        """Run the network on a batch of images.

        Args:
            x: Input batch; ``x.size(0)`` is used as the batch size.
            decode: When true (and the module is in eval mode) return
                decoded strings instead of scores.

        Returns:
            In training mode, raw scores of shape
            ``(width, batch, num_classes + 1)``. In eval mode, the softmax
            of those scores, or a list of decoded strings if ``decode``.
        """
        hidden = self.init_hidden(x.size(0), next(self.parameters()).is_cuda)
        features = self.cnn(x)
        features = self.features_to_sequence(features)
        seq, _ = self.rnn(features, hidden)
        seq = self.linear(seq)
        if not self.training:
            seq = self.softmax(seq)
            # NOTE(review): the source lost its indentation; decoding is
            # assumed to apply only in eval mode -- confirm against callers.
            if decode:
                seq = self.decode(seq)
        return seq

    def init_hidden(self, batch_size, gpu=False):
        """Return a zero initial hidden state for the bidirectional GRU.

        Args:
            batch_size: Number of sequences in the batch.
            gpu: Move the state to the current CUDA device when true.

        Returns:
            Tensor of shape ``(rnn_num_layers * 2, batch_size,
            rnn_hidden_size)``.
        """
        h0 = torch.zeros(self.rnn_num_layers * 2,
                         batch_size,
                         self.rnn_hidden_size)
        if gpu:
            h0 = h0.cuda()
        return h0

    def features_to_sequence(self, features):
        """Turn a ``(b, c, 1, w)`` feature map into a ``(w, b, c)`` sequence.

        Raises:
            AssertionError: If the feature map height is not 1.
        """
        b, c, h, w = features.size()
        assert h == 1, "the height of out must be 1"
        if not self.fully_conv:
            # Put width on the channel axis so the 1x1 conv mixes columns.
            features = features.permute(0, 3, 2, 1)
            features = self.proj(features)
            features = features.permute(1, 0, 2, 3)
        else:
            features = features.permute(3, 0, 2, 1)
        features = features.squeeze(2)
        return features

    def get_block_size(self, layer):
        """Return the channel count produced by the last block of ``layer``.

        Reads the weight size of the block's final batch norm: ``bn3`` when
        the block has one, otherwise ``bn2``.
        """
        block = layer[-1][-1]
        last_bn = getattr(block, 'bn3', None)
        if last_bn is None:
            last_bn = block.bn2
        return last_bn.weight.size()[0]

    def pred_to_string(self, pred):
        """Greedy-decode one sample's per-step scores into a string.

        Takes the arg-max class at every time step, drops blanks (class 0)
        and collapses consecutive repeats.

        Args:
            pred: Array-like of shape ``(time_steps, num_classes + 1)``.

        Returns:
            The decoded string built from ``self.abc``.
        """
        pred = np.asarray(pred)
        if pred.size == 0:
            return ''
        # Shift by one so the blank class becomes -1 and the remaining
        # labels index directly into self.abc.
        labels = pred.argmax(axis=1) - 1
        chars = []
        previous = -1
        for label in labels:
            if label != -1 and label != previous:
                chars.append(self.abc[label])
            previous = label
        return ''.join(chars)

    def decode(self, pred):
        """Decode a ``(time_steps, batch, classes)`` tensor into strings.

        Returns:
            A list with one decoded string per batch element.
        """
        pred = pred.permute(1, 0, 2).detach().cpu().numpy()
        return [self.pred_to_string(sample) for sample in pred]
# 页: [1]