Commit 60383ca1, authored by lywen

Add dense OCR

Parent f449fe61
@@ -8,19 +8,21 @@ from crnn import util
 from crnn import dataset
 from crnn.models import crnn as crnn
 from crnn import keys
 #from conf import crnnModelPath
 #from conf import GPU
 GPU=False
 from collections import OrderedDict
-from config import ocrModel
+from config import ocrModel,LSTMFLAG,GPU
 from config import chinsesModel
 def crnnSource():
     alphabet = keys.alphabet
     if chinsesModel:
         alphabet = keys.alphabetChinese
     else:
         alphabet = keys.alphabetEnglish
     converter = util.strLabelConverter(alphabet)
     if torch.cuda.is_available() and GPU:
-        model = crnn.CRNN(32, 1, len(alphabet)+1, 256, 1).cuda()
+        model = crnn.CRNN(32, 1, len(alphabet)+1, 256, 1,lstmFlag=LSTMFLAG).cuda()  ## LSTMFLAG=True: CRNN (BiLSTM head); otherwise dense OCR
     else:
-        model = crnn.CRNN(32, 1, len(alphabet)+1, 256, 1).cpu()
+        model = crnn.CRNN(32, 1, len(alphabet)+1, 256, 1,lstmFlag=LSTMFLAG).cpu()
     state_dict = torch.load(ocrModel,map_location=lambda storage, loc: storage)
     new_state_dict = OrderedDict()
     for k, v in state_dict.items():
This diff is collapsed.
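Note: the import change above assumes that LSTMFLAG, GPU, ocrModel and chinsesModel are now defined in a top-level config module. A minimal sketch of what such a config.py might contain is given below; the variable names come from the imports in this diff, while the concrete values and the weights path are illustrative placeholders, not the repository's actual settings.

# config.py -- hypothetical minimal sketch; values are placeholders
GPU = False                        # use CUDA only if torch.cuda.is_available() and GPU are both true
LSTMFLAG = True                    # True: CRNN with the BiLSTM head; False: dense OCR (linear head only)
chinsesModel = True                # True: Chinese alphabet (keys.alphabetChinese); False: English
ocrModel = "models/ocr-model.pth"  # placeholder path to the trained weights loaded by torch.load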
 import torch.nn as nn
 from . import utils
 class BidirectionalLSTM(nn.Module):
-    def __init__(self, nIn, nHidden, nOut, ngpu):
+    def __init__(self, nIn, nHidden, nOut):
         super(BidirectionalLSTM, self).__init__()
-        self.ngpu = ngpu
         self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
         self.embedding = nn.Linear(nHidden * 2, nOut)
     def forward(self, input):
-        recurrent, _ = utils.data_parallel(
-            self.rnn, input, self.ngpu)  # [T, b, h * 2]
+        recurrent, _ = self.rnn(input)
         T, b, h = recurrent.size()
         t_rec = recurrent.view(T * b, h)
-        output = utils.data_parallel(
-            self.embedding, t_rec, self.ngpu)  # [T * b, nOut]
+        output = self.embedding(t_rec)  # [T * b, nOut]
         output = output.view(T, b, -1)
         return output
 class CRNN(nn.Module):
-    def __init__(self, imgH, nc, nclass, nh, ngpu, n_rnn=2, leakyRelu=False):
+    def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False,lstmFlag=True):
+        """
+        lstmFlag: whether to include the LSTM feature layers
+        """
         super(CRNN, self).__init__()
-        self.ngpu = ngpu
         assert imgH % 16 == 0, 'imgH has to be a multiple of 16'
         ks = [3, 3, 3, 3, 3, 3, 2]
         ps = [1, 1, 1, 1, 1, 1, 0]
         ss = [1, 1, 1, 1, 1, 1, 1]
         nm = [64, 128, 256, 256, 512, 512, 512]
+        self.lstmFlag = lstmFlag
         cnn = nn.Sequential()
@@ -57,31 +55,41 @@ class CRNN(nn.Module):
         cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32
         convRelu(2, True)
         convRelu(3)
-        cnn.add_module('pooling{0}'.format(2), nn.MaxPool2d((2, 2),
-                                                            (2, 1),
-                                                            (0, 1)))  # 256x4x16
+        cnn.add_module('pooling{0}'.format(2),
+                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x16
         convRelu(4, True)
         convRelu(5)
-        cnn.add_module('pooling{0}'.format(3), nn.MaxPool2d((2, 2),
-                                                            (2, 1),
-                                                            (0, 1)))  # 512x2x16
+        cnn.add_module('pooling{0}'.format(3),
+                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x16
         convRelu(6, True)  # 512x1x16
         self.cnn = cnn
-        self.rnn = nn.Sequential(
-            BidirectionalLSTM(512, nh, nh, ngpu),
-            BidirectionalLSTM(nh, nh, nclass, ngpu)
-        )
+        if self.lstmFlag:
+            self.rnn = nn.Sequential(
+                BidirectionalLSTM(512, nh, nh),
+                BidirectionalLSTM(nh, nh, nclass))
+        else:
+            self.linear = nn.Linear(nh*2, nclass)
     def forward(self, input):
         # conv features
-        conv = utils.data_parallel(self.cnn, input, self.ngpu)
+        conv = self.cnn(input)
         b, c, h, w = conv.size()
         assert h == 1, "the height of conv must be 1"
         conv = conv.squeeze(2)
         conv = conv.permute(2, 0, 1)  # [w, b, c]
-        # rnn features
-        output = utils.data_parallel(self.rnn, conv, self.ngpu)
+        if self.lstmFlag:
+            # rnn features
+            output = self.rnn(conv)
+        else:
+            T, b, h = conv.size()
+            t_rec = conv.contiguous().view(T * b, h)
+            output = self.linear(t_rec)  # [T * b, nOut]
+            output = output.view(T, b, -1)
         return output
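With this change, lstmFlag selects between the original CRNN head (two stacked BidirectionalLSTM modules) and the new dense-OCR head (a single nn.Linear(nh*2, nclass) applied to the CNN feature sequence); both paths return a [T, b, nclass] sequence suitable for CTC decoding. Below is a hedged usage sketch, assuming the package layout implied by the first file (crnn.models.crnn) and the standard settings used there (imgH=32, grayscale input, nh=256); the class count and image width are placeholders.

import torch
from crnn.models import crnn as crnn

nclass = 5000 + 1                   # placeholder: alphabet size + 1 for the CTC blank
x = torch.randn(2, 1, 32, 280)      # two grayscale text-line images of height 32

# Original head: CNN features -> two BidirectionalLSTM layers (LSTMFLAG=True)
lstm_model = crnn.CRNN(32, 1, nclass, 256, 1, lstmFlag=True)
# Dense-OCR head added by this commit: CNN features -> nn.Linear(nh*2, nclass)
dense_model = crnn.CRNN(32, 1, nclass, 256, 1, lstmFlag=False)

with torch.no_grad():
    y_lstm = lstm_model(x)
    y_dense = dense_model(x)

# Both variants emit per-timestep class scores of shape [T, b, nclass],
# where T is the width of the final CNN feature map.
print(y_lstm.shape, y_dense.shape)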