#coding:utf-8
import sys
sys.path.insert(1, "./crnn")
import torch
import torch.utils.data
from torch.autograd import Variable 
from crnn import util
from crnn import dataset
from crnn.models import crnn as crnn
from crnn import keys
from collections import OrderedDict
from config import ocrModel, LSTMFLAG, GPU
from config import chinsesModel
def crnnSource():
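    """Build the CRNN model and its label converter from the config settings."""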
    # pick the character set according to the config
    if chinsesModel:
        alphabet = keys.alphabetChinese
    else:
        alphabet = keys.alphabetEnglish

    converter = util.strLabelConverter(alphabet)
    if torch.cuda.is_available() and GPU:
        # LSTMFLAG=True -> CRNN (with LSTM); otherwise the dense (no-LSTM) OCR variant.
        # len(alphabet) + 1 accounts for the CTC blank label.
        model = crnn.CRNN(32, 1, len(alphabet) + 1, 256, 1, lstmFlag=LSTMFLAG).cuda()
    else:
        model = crnn.CRNN(32, 1, len(alphabet) + 1, 256, 1, lstmFlag=LSTMFLAG).cpu()
    
    # load the checkpoint onto CPU storage regardless of where it was saved
    state_dict = torch.load(ocrModel, map_location=lambda storage, loc: storage)
    # checkpoints saved from nn.DataParallel prefix every key with 'module.';
    # strip it so the keys match a plain (single-device) model
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k.replace('module.', '')  # remove the `module.` prefix
        new_state_dict[name] = v
    # load params
   
    model.load_state_dict(new_state_dict)
    model.eval()
    
    return model,converter

## load the model and label converter once, at import time
model, converter = crnnSource()

def crnnOcr(image):
    """
    Run CRNN OCR on a single text-line image.

    image: PIL image of one horizontal text line
    returns: the decoded text string
    """
    # resize so the height becomes 32 px while keeping the aspect ratio
    scale = image.size[1] * 1.0 / 32
    w = int(image.size[0] / scale)
    # print("im size:{},{}".format(image.size, w))
    transformer = dataset.resizeNormalize((w, 32))
    if torch.cuda.is_available() and GPU:
        image = transformer(image).cuda()
    else:
        image = transformer(image).cpu()

    image = image.view(1, *image.size())  # add the batch dimension
    image = Variable(image)
    model.eval()
    preds = model(image)
    _, preds = preds.max(2)  # most probable class at every time step
    preds = preds.transpose(1, 0).contiguous().view(-1)
    preds_size = Variable(torch.IntTensor([preds.size(0)]))
    # greedy CTC decoding: collapse repeats and drop blanks
    sim_pred = converter.decode(preds.data, preds_size.data, raw=False)

    return sim_pred
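
# Minimal usage sketch: crnnOcr expects a PIL image of a single text line, and
# the model above was built with one input channel, so convert to grayscale
# ('L') first. The file name below is a hypothetical example, not part of the repo.
if __name__ == "__main__":
    from PIL import Image

    img = Image.open("sample_line.png").convert("L")  # hypothetical sample image
    print(crnnOcr(img))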