未验证 提交 c708041e 编写于 作者: Z zhoujun 提交者: GitHub

CRNN导出 (#1159)

* 识别模型导出

* 识别模型inference
上级 882ad395
...@@ -26,34 +26,27 @@ import time ...@@ -26,34 +26,27 @@ import time
import paddle.fluid as fluid import paddle.fluid as fluid
import tools.infer.utility as utility import tools.infer.utility as utility
from ppocr.utils.utility import initial_logger from ppocr.postprocess import build_post_process
logger = initial_logger() from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.character import CharacterOps
class TextRecognizer(object): class TextRecognizer(object):
def __init__(self, args): def __init__(self, args):
self.predictor, self.input_tensor, self.output_tensors =\
utility.create_predictor(args, mode="rec")
self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")] self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
self.character_type = args.rec_char_type self.character_type = args.rec_char_type
self.rec_batch_num = args.rec_batch_num self.rec_batch_num = args.rec_batch_num
self.rec_algorithm = args.rec_algorithm self.rec_algorithm = args.rec_algorithm
self.use_zero_copy_run = args.use_zero_copy_run self.use_zero_copy_run = args.use_zero_copy_run
char_ops_params = { postprocess_params = {
'name': 'CTCLabelDecode',
"character_type": args.rec_char_type, "character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path, "character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char, "use_space_char": args.use_space_char
"max_text_length": args.max_text_length
} }
if self.rec_algorithm != "RARE": self.postprocess_op = build_post_process(postprocess_params)
char_ops_params['loss_type'] = 'ctc' self.predictor, self.input_tensor, self.output_tensors = \
self.loss_type = 'ctc' utility.create_predictor(args, 'rec', logger)
else:
char_ops_params['loss_type'] = 'attention'
self.loss_type = 'attention'
self.char_ops = CharacterOps(char_ops_params)
def resize_norm_img(self, img, max_wh_ratio): def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape imgC, imgH, imgW = self.rec_image_shape
...@@ -112,48 +105,14 @@ class TextRecognizer(object): ...@@ -112,48 +105,14 @@ class TextRecognizer(object):
else: else:
norm_img_batch = fluid.core.PaddleTensor(norm_img_batch) norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
self.predictor.run([norm_img_batch]) self.predictor.run([norm_img_batch])
outputs = []
if self.loss_type == "ctc": for output_tensor in self.output_tensors:
rec_idx_batch = self.output_tensors[0].copy_to_cpu() output = output_tensor.copy_to_cpu()
rec_idx_lod = self.output_tensors[0].lod()[0] outputs.append(output)
predict_batch = self.output_tensors[1].copy_to_cpu() preds = outputs[0]
predict_lod = self.output_tensors[1].lod()[0] rec_res = self.postprocess_op(preds)
elapse = time.time() - starttime elapse = time.time() - starttime
predict_time += elapse return rec_res, elapse
for rno in range(len(rec_idx_lod) - 1):
beg = rec_idx_lod[rno]
end = rec_idx_lod[rno + 1]
rec_idx_tmp = rec_idx_batch[beg:end, 0]
preds_text = self.char_ops.decode(rec_idx_tmp)
beg = predict_lod[rno]
end = predict_lod[rno + 1]
probs = predict_batch[beg:end, :]
ind = np.argmax(probs, axis=1)
blank = probs.shape[1]
valid_ind = np.where(ind != (blank - 1))[0]
if len(valid_ind) == 0:
continue
score = np.mean(probs[valid_ind, ind[valid_ind]])
# rec_res.append([preds_text, score])
rec_res[indices[beg_img_no + rno]] = [preds_text, score]
else:
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
predict_batch = self.output_tensors[1].copy_to_cpu()
elapse = time.time() - starttime
predict_time += elapse
for rno in range(len(rec_idx_batch)):
end_pos = np.where(rec_idx_batch[rno, :] == 1)[0]
if len(end_pos) <= 1:
preds = rec_idx_batch[rno, 1:]
score = np.mean(predict_batch[rno, 1:])
else:
preds = rec_idx_batch[rno, 1:end_pos[1]]
score = np.mean(predict_batch[rno, 1:end_pos[1]])
preds_text = self.char_ops.decode(preds)
# rec_res.append([preds_text, score])
rec_res[indices[beg_img_no + rno]] = [preds_text, score]
return rec_res, predict_time
def main(args): def main(args):
...@@ -183,9 +142,10 @@ def main(args): ...@@ -183,9 +142,10 @@ def main(args):
exit() exit()
for ino in range(len(img_list)): for ino in range(len(img_list)):
print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino])) print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino]))
print("Total predict time for %d images:%.3f" % print("Total predict time for %d images, cost: %.3f" %
(len(img_list), predict_time)) (len(img_list), predict_time))
if __name__ == "__main__": if __name__ == "__main__":
logger = get_logger()
main(utility.parse_args()) main(utility.parse_args())
...@@ -323,6 +323,20 @@ def eval(model, valid_dataloader, post_process_class, eval_class): ...@@ -323,6 +323,20 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
return metirc return metirc
def save_inference_mode(model, config, logger):
model.eval()
save_path = '{}/infer/{}'.format(config['Global']['save_model_dir'],
config['Architecture']['model_type'])
if config['Architecture']['model_type'] == 'rec':
input_shape = [None, 3, 32, None]
jit_model = paddle.jit.to_static(
model, input_spec=[paddle.static.InputSpec(input_shape)])
paddle.jit.save(jit_model, save_path)
logger.info('inference model save to {}'.format(save_path))
model.train()
def preprocess(): def preprocess():
FLAGS = ArgsParser().parse_args() FLAGS = ArgsParser().parse_args()
config = load_config(FLAGS.config) config = load_config(FLAGS.config)
...@@ -334,7 +348,7 @@ def preprocess(): ...@@ -334,7 +348,7 @@ def preprocess():
alg = config['Architecture']['algorithm'] alg = config['Architecture']['algorithm']
assert alg in [ assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN' 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS'
] ]
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
......
...@@ -89,6 +89,7 @@ def main(config, device, logger, vdl_writer): ...@@ -89,6 +89,7 @@ def main(config, device, logger, vdl_writer):
program.train(config, train_dataloader, valid_dataloader, device, model, program.train(config, train_dataloader, valid_dataloader, device, model,
loss_class, optimizer, lr_scheduler, post_process_class, loss_class, optimizer, lr_scheduler, post_process_class,
eval_class, pre_best_model_dict, logger, vdl_writer) eval_class, pre_best_model_dict, logger, vdl_writer)
program.save_inference_mode(model, config, logger)
def test_reader(config, device, logger): def test_reader(config, device, logger):
...@@ -102,8 +103,8 @@ def test_reader(config, device, logger): ...@@ -102,8 +103,8 @@ def test_reader(config, device, logger):
if count % 1 == 0: if count % 1 == 0:
batch_time = time.time() - starttime batch_time = time.time() - starttime
starttime = time.time() starttime = time.time()
logger.info("reader: {}, {}, {}".format(count, logger.info("reader: {}, {}, {}".format(
len(data), batch_time)) count, len(data[0]), batch_time))
except Exception as e: except Exception as e:
logger.info(e) logger.info(e)
logger.info("finish reader: {}, Success!".format(count)) logger.info("finish reader: {}, Success!".format(count))
...@@ -112,4 +113,4 @@ def test_reader(config, device, logger): ...@@ -112,4 +113,4 @@ def test_reader(config, device, logger):
if __name__ == '__main__': if __name__ == '__main__':
config, device, logger, vdl_writer = program.preprocess() config, device, logger, vdl_writer = program.preprocess()
main(config, device, logger, vdl_writer) main(config, device, logger, vdl_writer)
# test_reader(config, device, logger) # test_reader(config, device, logger)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册