Commit 897ca5b2 authored by X xiaoqianyu

CRNN keras model: support LSTM

Parent a017ba21
@@ -2,6 +2,7 @@
from crnn.utils import strLabelConverter,resizeNormalize
from crnn.network_keras import keras_crnn as CRNN
from config import LSTMFLAG
import tensorflow as tf
graph = tf.get_default_graph()##work around web.py-related errors
@@ -11,7 +12,7 @@ import numpy as np
def crnnSource():
    alphabet = keys.alphabetChinese##Chinese/English model
    converter = strLabelConverter(alphabet)
    model = CRNN(32, 1, len(alphabet)+1, 256, 1,lstmFlag=False)
    model = CRNN(32, 1, len(alphabet)+1, 256, 1,lstmFlag=LSTMFLAG)
    model.load_weights(ocrModelKeras)
    return model,converter
@@ -37,6 +38,3 @@ def crnnOcr(image):
    preds = np.argmax(preds,axis=2).reshape((-1,))
    sim_pred = converter.decode(preds)
    return sim_pred
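A minimal usage sketch of the OCR path above (the module path crnn.model and the sample file name are assumptions; this assumes crnnOcr accepts a single text-line PIL image, in line with the resizeNormalize preprocessing imported above):

from PIL import Image
from crnn.model import crnnOcr  # assumed import path for the file shown above

img = Image.open('sample_line.png').convert('L')  # grayscale text-line image, hypothetical file
print(crnnOcr(img))  # decoded string from the CRNN output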
from keras.layers import Conv2D,BatchNormalization,MaxPool2D,Input,Permute,Reshape,Dense,LeakyReLU,Activation
from keras.layers import (Conv2D,BatchNormalization,MaxPool2D,Input,Permute,Reshape,Dense,LeakyReLU,Activation, Bidirectional, LSTM, TimeDistributed)
from keras.models import Model
from keras.layers import ZeroPadding2D
from keras.activations import relu
@@ -68,7 +68,17 @@ def keras_crnn(imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False,lstmFlag=True):
    x = Permute((3, 2, 1))(x)
    x = Reshape((-1,512))(x)
    out = Dense(nclass,name='linear')(x)
    out = None
    if lstmFlag:
        x = Bidirectional(LSTM(nh, return_sequences=True, use_bias=True,
                               recurrent_activation='sigmoid'))(x)
        x = TimeDistributed(Dense(nh))(x)
        x = Bidirectional(LSTM(nh, return_sequences=True, use_bias=True,
                               recurrent_activation='sigmoid'))(x)
        out = TimeDistributed(Dense(nclass))(x)
    else:
        out = Dense(nclass,name='linear')(x)
    out = Reshape((-1, 1, nclass),name='out')(out)
    return Model(imgInput,out)
\ No newline at end of file
    return Model(imgInput,out)
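A minimal construction sketch of the LSTM-enabled model (alphabetChinese and the keras_crnn signature come from this diff; building it standalone like this is only illustrative):

from crnn.keys import alphabetChinese
from crnn.network_keras import keras_crnn

# 32-pixel-high grayscale input, len(alphabet)+1 output classes (the extra class is the CTC blank),
# 256 hidden units per LSTM direction; lstmFlag=True selects the bidirectional LSTM head above.
model = keras_crnn(32, 1, len(alphabetChinese) + 1, 256, 1, lstmFlag=True)
model.summary()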
@@ -12,6 +12,8 @@ def parser():
    parser = argparse.ArgumentParser(description="pytorch dense ocr to keras ocr")
    parser.add_argument('-weights_path',help='models/ocr-dense.pth')
    parser.add_argument('-output_path', help='models/ocr-dense-keras.h5')
    parser.add_argument('-lstm', default=False,
                        action='store_true', help='translate lstm layer')
    return parser.parse_args()
def set_cnn_weight(name,keramodel,torchmodelDict):
@@ -68,7 +70,62 @@ def set_dense_weight(name,keramodel,torchmodelDict):
    if weight is not None and bias is not None:
        weight = np.transpose(weight)
        keramodel.get_layer(name).set_weights([weight,bias])
def set_lstm_weight(name, kerasmodel, torchmodelDict):
    # RNN
    weight_ih_l0 = None
    weight_hh_l0 = None
    bias_ih_l0 = None
    bias_hh_l0 = None
    weight_ih_l0_reverse = None
    weight_hh_l0_reverse = None
    bias_ih_l0_reverse = None
    bias_hh_l0_reverse = None
    # TimeDistributed
    embedding_weight = None
    embedding_bias = None
    for key in torchmodelDict:
        if name in key:
            if key.endswith('rnn.weight_ih_l0'):
                weight_ih_l0 = torchmodelDict[key]
            elif key.endswith('rnn.weight_hh_l0'):
                weight_hh_l0 = torchmodelDict[key]
            elif key.endswith('rnn.bias_ih_l0'):
                bias_ih_l0 = torchmodelDict[key]
            elif key.endswith('rnn.bias_hh_l0'):
                bias_hh_l0 = torchmodelDict[key]
            elif key.endswith('rnn.weight_ih_l0_reverse'):
                weight_ih_l0_reverse = torchmodelDict[key]
            elif key.endswith('rnn.weight_hh_l0_reverse'):
                weight_hh_l0_reverse = torchmodelDict[key]
            elif key.endswith('rnn.bias_ih_l0_reverse'):
                bias_ih_l0_reverse = torchmodelDict[key]
            elif key.endswith('rnn.bias_hh_l0_reverse'):
                bias_hh_l0_reverse = torchmodelDict[key]
            elif key.endswith('embedding.weight'):
                embedding_weight = torchmodelDict[key]
            elif key.endswith('embedding.bias'):
                embedding_bias = torchmodelDict[key]
    rnn_weights = [
        weight_ih_l0.transpose(1, 0),
        weight_hh_l0.transpose(1, 0),
        (bias_ih_l0 + bias_hh_l0),
        weight_ih_l0_reverse.transpose(1, 0),
        weight_hh_l0_reverse.transpose(1, 0),
        (bias_ih_l0_reverse + bias_hh_l0_reverse)
    ]
    linear_weights = [
        embedding_weight.transpose(1, 0).numpy(),
        embedding_bias.numpy(),
    ]
    if name == 'rnn.0':
        kerasmodel.get_layer('bidirectional_1').set_weights(rnn_weights)
        kerasmodel.get_layer('time_distributed_1').set_weights(linear_weights)
    else:
        kerasmodel.get_layer('bidirectional_2').set_weights(rnn_weights)
        kerasmodel.get_layer('time_distributed_2').set_weights(linear_weights)
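The six-element rnn_weights list matches the order Keras Bidirectional(LSTM).set_weights expects: forward kernel, forward recurrent kernel, forward bias, then the same three for the backward direction. A shape-only sketch of the conversion this relies on, with illustrative sizes for the first LSTM (assuming PyTorch and Keras both order the gates i, f, g/c, o, so only a transpose and a bias sum are needed):

import numpy as np

nh, in_features = 256, 512                 # hidden size and CNN feature depth used by this CRNN
w_ih = np.zeros((4 * nh, in_features))     # PyTorch nn.LSTM weight_ih_l0 layout
b_ih = np.zeros(4 * nh)                    # PyTorch bias_ih_l0
b_hh = np.zeros(4 * nh)                    # PyTorch bias_hh_l0
kernel = w_ih.transpose(1, 0)              # Keras LSTM kernel layout: (in_features, 4*nh)
bias = b_ih + b_hh                         # Keras keeps a single bias vector per direction
assert kernel.shape == (in_features, 4 * nh) and bias.shape == (4 * nh,)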
if __name__=='__main__':
    import os
    import sys
@@ -81,10 +138,11 @@ if __name__=='__main__':
    from collections import OrderedDict
    from crnn.keys import alphabetChinese
    from crnn.network_keras import keras_crnn
    ##ocrModel='models/ocr-dense.pth'##currently only dense ocr is supported
    ##ocrModel='models/ocr-dense.pth' #dense ocr
    ##ocrModel='models/ocr-lstm.pth' #lstm ocr
    ocrModel = args.weights_path##torch model weights
    output_path =args.output_path##keras model weights output path
    kerasModel = keras_crnn(32, 1, len(alphabetChinese)+1, 256, 1,lstmFlag=False)
    kerasModel = keras_crnn(32, 1, len(alphabetChinese)+1, 256, 1,lstmFlag=args.lstm)
    state_dict = torch.load(ocrModel,map_location=lambda storage, loc: storage)
    new_state_dict = OrderedDict()
@@ -96,17 +154,19 @@ if __name__=='__main__':
    cnn = ['cnn.conv0','cnn.conv1','cnn.conv2','cnn.conv3','cnn.conv4','cnn.conv5','cnn.conv6']
    BN =['cnn.batchnorm2','cnn.batchnorm4','cnn.batchnorm6']
    linear = ['linear']
    lstm = ['rnn.0', 'rnn.1']
    ##CNN layers
    for cn in cnn:
        set_cnn_weight(cn,kerasModel,new_state_dict)
    ##BN layers
    for bn in BN:
        set_bn_weight(bn,kerasModel,new_state_dict)
    ## linear layer
    for lr in linear:
        set_dense_weight(lr,kerasModel,new_state_dict)
    if args.lstm:
        for l in lstm:
            set_lstm_weight(l,kerasModel,new_state_dict)
    else:
        ## linear layer
        for lr in linear:
            set_dense_weight(lr,kerasModel,new_state_dict)
    kerasModel.save_weights(output_path)##save keras weights
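With this change the script can convert both checkpoint types; for an LSTM model the invocation would look roughly like `python <convert_script>.py -weights_path models/ocr-lstm.pth -output_path models/ocr-lstm-keras.h5 -lstm` (the script and output filenames are illustrative; the flags come from parser() above, and omitting -lstm keeps the original dense/linear conversion path).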