From bc1ad2070114e74234db4b300f6660c77511ca76 Mon Sep 17 00:00:00 2001
From: tink2123
Date: Thu, 3 Sep 2020 15:51:50 +0800
Subject: [PATCH] support srn inference

---
 ppocr/modeling/architectures/rec_model.py |  16 ++-
 tools/infer/predict_rec.py                | 163 ++++++++++++++++++++--
 tools/infer/utility.py                    |  13 +-
 tools/infer_rec.py                        |   7 +-
 tools/program.py                          |  17 ++-
 5 files changed, 187 insertions(+), 29 deletions(-)

diff --git a/ppocr/modeling/architectures/rec_model.py b/ppocr/modeling/architectures/rec_model.py
index fe2d4c16..f2e24abd 100755
--- a/ppocr/modeling/architectures/rec_model.py
+++ b/ppocr/modeling/architectures/rec_model.py
@@ -136,7 +136,7 @@ class RecModel(object):
         else:
             labels = None
             loader = None
-        if self.char_type == "ch" and self.infer_img:
+        if self.char_type == "ch" and self.infer_img and self.loss_type != "srn":
             image_shape[-1] = -1
             if self.tps != None:
                 logger.info(
@@ -172,16 +172,13 @@ class RecModel(object):
                     self.max_text_length
                 ],
                 dtype="float32")
-            feed_list = [
-                image, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
-                gsrm_slf_attn_bias2
-            ]
             labels = {
                 'encoder_word_pos': encoder_word_pos,
                 'gsrm_word_pos': gsrm_word_pos,
                 'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1,
                 'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2
             }
+
         return image, labels, loader
 
     def __call__(self, mode):
@@ -218,8 +215,13 @@ class RecModel(object):
             if self.loss_type == "ctc":
                 predict = fluid.layers.softmax(predict)
             if self.loss_type == "srn":
-                raise Exception(
-                    "Warning! SRN does not support export model currently")
+                return [
+                    image, labels, {
+                        'decoded_out': decoded_out,
+                        'predicts': predict
+                    }
+                ]
+            return [image, {'decoded_out': decoded_out, 'predicts': predict}]
         else:
             predict = predicts['predict']
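Note on the extra SRN feeds moved into `labels` above: their shapes are fixed by the image size, head count, and maximum text length alone. A minimal numpy sketch (not part of the patch) of those per-sample shapes for the defaults this patch uses elsewhere, assuming the 8x feature-map downsampling that srn_other_inputs in tools/infer/predict_rec.py below also assumes:

import numpy as np

# Sketch only: shapes of the four extra SRN feeds under the default
# config (64x256 input, 8 heads, max_text_length 25).
imgH, imgW = 64, 256
num_heads, max_text_length = 8, 25
feature_dim = (imgH // 8) * (imgW // 8)  # 8 * 32 = 256 visual positions

encoder_word_pos = np.arange(feature_dim).reshape((feature_dim, 1)).astype('int64')
gsrm_word_pos = np.arange(max_text_length).reshape((max_text_length, 1)).astype('int64')
attn_bias = np.zeros((1, num_heads, max_text_length, max_text_length), dtype='float32')

print(encoder_word_pos.shape)   # (256, 1)
print(gsrm_word_pos.shape)      # (25, 1)
print(attn_bias.shape)          # (1, 8, 25, 25), for both bias1 and bias2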
diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py
index c81b4eb2..8f676294 100755
--- a/tools/infer/predict_rec.py
+++ b/tools/infer/predict_rec.py
@@ -26,6 +26,7 @@ import copy
 import numpy as np
 import math
 import time
+import paddle.fluid as fluid
 
 from ppocr.utils.character import CharacterOps
 
@@ -37,18 +38,22 @@ class TextRecognizer(object):
         self.character_type = args.rec_char_type
         self.rec_batch_num = args.rec_batch_num
         self.rec_algorithm = args.rec_algorithm
+        self.text_len = args.max_text_length
         char_ops_params = {
             "character_type": args.rec_char_type,
             "character_dict_path": args.rec_char_dict_path,
             "use_space_char": args.use_space_char,
             "max_text_length": args.max_text_length
         }
-        if self.rec_algorithm != "RARE":
+        if self.rec_algorithm in ["CRNN", "Rosetta", "STAR-Net"]:
             char_ops_params['loss_type'] = 'ctc'
             self.loss_type = 'ctc'
-        else:
+        elif self.rec_algorithm == "RARE":
             char_ops_params['loss_type'] = 'attention'
             self.loss_type = 'attention'
+        elif self.rec_algorithm == "SRN":
+            char_ops_params['loss_type'] = 'srn'
+            self.loss_type = 'srn'
         self.char_ops = CharacterOps(char_ops_params)
 
     def resize_norm_img(self, img, max_wh_ratio):
@@ -71,6 +76,83 @@ class TextRecognizer(object):
             padding_im[:, :, 0:resized_w] = resized_image
         return padding_im
 
+    def resize_norm_img_srn(self, img, image_shape):
+        imgC, imgH, imgW = image_shape
+
+        img_black = np.zeros((imgH, imgW))
+        im_hei = img.shape[0]
+        im_wid = img.shape[1]
+
+        if im_wid <= im_hei * 1:
+            img_new = cv2.resize(img, (imgH * 1, imgH))
+        elif im_wid <= im_hei * 2:
+            img_new = cv2.resize(img, (imgH * 2, imgH))
+        elif im_wid <= im_hei * 3:
+            img_new = cv2.resize(img, (imgH * 3, imgH))
+        else:
+            img_new = cv2.resize(img, (imgW, imgH))
+
+        img_np = np.asarray(img_new)
+        img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
+        img_black[:, 0:img_np.shape[1]] = img_np
+        img_black = img_black[:, :, np.newaxis]
+
+        row, col, c = img_black.shape
+        c = 1
+
+        return np.reshape(img_black, (c, row, col)).astype(np.float32)
+
+    def srn_other_inputs(self, image_shape, num_heads, max_text_length,
+                         char_num):
+
+        imgC, imgH, imgW = image_shape
+        feature_dim = int((imgH / 8) * (imgW / 8))
+
+        encoder_word_pos = np.array(range(0, feature_dim)).reshape(
+            (feature_dim, 1)).astype('int64')
+        gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
+            (max_text_length, 1)).astype('int64')
+
+        gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
+        gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
+            [-1, 1, max_text_length, max_text_length])
+        gsrm_slf_attn_bias1 = np.tile(
+            gsrm_slf_attn_bias1,
+            [1, num_heads, 1, 1]).astype('float32') * [-1e9]
+
+        gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
+            [-1, 1, max_text_length, max_text_length])
+        gsrm_slf_attn_bias2 = np.tile(
+            gsrm_slf_attn_bias2,
+            [1, num_heads, 1, 1]).astype('float32') * [-1e9]
+
+        encoder_word_pos = encoder_word_pos[np.newaxis, :]
+        gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
+
+        return [
+            encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
+            gsrm_slf_attn_bias2
+        ]
+
+    def process_image_srn(self,
+                          img,
+                          image_shape,
+                          num_heads,
+                          max_text_length,
+                          char_ops=None):
+        norm_img = self.resize_norm_img_srn(img, image_shape)
+        norm_img = norm_img[np.newaxis, :]
+        char_num = char_ops.get_char_num()
+
+        [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
+            self.srn_other_inputs(image_shape, num_heads, max_text_length, char_num)
+
+        gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
+        gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
+
+        return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
+                gsrm_slf_attn_bias2)
+
     def __call__(self, img_list):
         img_num = len(img_list)
         # Calculate the aspect ratio of all text bars
@@ -80,7 +162,7 @@ class TextRecognizer(object):
         # Sorting can speed up the recognition process
         indices = np.argsort(np.array(width_list))
 
-        # rec_res = []
+        #rec_res = []
         rec_res = [['', 0.0]] * img_num
         batch_num = self.rec_batch_num
         predict_time = 0
@@ -94,16 +176,52 @@ class TextRecognizer(object):
                 wh_ratio = w * 1.0 / h
                 max_wh_ratio = max(max_wh_ratio, wh_ratio)
             for ino in range(beg_img_no, end_img_no):
-                # norm_img = self.resize_norm_img(img_list[ino], max_wh_ratio)
-                norm_img = self.resize_norm_img(img_list[indices[ino]],
-                                                max_wh_ratio)
-                norm_img = norm_img[np.newaxis, :]
-                norm_img_batch.append(norm_img)
-            norm_img_batch = np.concatenate(norm_img_batch)
-            norm_img_batch = norm_img_batch.copy()
+                if self.loss_type != "srn":
+                    norm_img = self.resize_norm_img(img_list[indices[ino]],
+                                                    max_wh_ratio)
+                    norm_img = norm_img[np.newaxis, :]
+                    norm_img_batch.append(norm_img)
+                else:
+                    norm_img = self.process_image_srn(img_list[indices[ino]],
+                                                      self.rec_image_shape, 8,
+                                                      25, self.char_ops)
+                    encoder_word_pos_list = []
+                    gsrm_word_pos_list = []
+                    gsrm_slf_attn_bias1_list = []
+                    gsrm_slf_attn_bias2_list = []
+                    encoder_word_pos_list.append(norm_img[1])
+                    gsrm_word_pos_list.append(norm_img[2])
+                    gsrm_slf_attn_bias1_list.append(norm_img[3])
+                    gsrm_slf_attn_bias2_list.append(norm_img[4])
+                    norm_img_batch.append(norm_img[0])
+
+            norm_img_batch = np.concatenate(norm_img_batch, axis=0)
+
+            encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
+
+            gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
+
+            gsrm_slf_attn_bias1_list = np.concatenate(gsrm_slf_attn_bias1_list)
+
+            gsrm_slf_attn_bias2_list = np.concatenate(gsrm_slf_attn_bias2_list)
+
             starttime = time.time()
-            self.input_tensor.copy_from_cpu(norm_img_batch)
-            self.predictor.zero_copy_run()
+
+            norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
+            encoder_word_pos_list = fluid.core.PaddleTensor(
+                encoder_word_pos_list)
+            gsrm_word_pos_list = fluid.core.PaddleTensor(gsrm_word_pos_list)
+            gsrm_slf_attn_bias1_list = fluid.core.PaddleTensor(
+                gsrm_slf_attn_bias1_list)
+            gsrm_slf_attn_bias2_list = fluid.core.PaddleTensor(
+                gsrm_slf_attn_bias2_list)
+
+            inputs = [
+                norm_img_batch, encoder_word_pos_list, gsrm_slf_attn_bias1_list,
+                gsrm_slf_attn_bias2_list, gsrm_word_pos_list
+            ]
+
+            self.predictor.run(inputs)
 
             if self.loss_type == "ctc":
                 rec_idx_batch = self.output_tensors[0].copy_to_cpu()
@@ -128,6 +246,26 @@ class TextRecognizer(object):
                     score = np.mean(probs[valid_ind, ind[valid_ind]])
                     # rec_res.append([preds_text, score])
                     rec_res[indices[beg_img_no + rno]] = [preds_text, score]
+            elif self.loss_type == 'srn':
+                rec_idx_batch = self.output_tensors[0].copy_to_cpu()
+                probs = self.output_tensors[1].copy_to_cpu()
+                char_num = self.char_ops.get_char_num()
+                preds = rec_idx_batch.reshape(-1)
+                elapse = time.time() - starttime
+                predict_time += elapse
+                total_preds = preds.copy()
+                for ino in range(int(len(rec_idx_batch) / self.text_len)):
+                    preds = total_preds[ino * self.text_len:(ino + 1) *
+                                        self.text_len]
+                    ind = np.argmax(probs, axis=1)
+                    valid_ind = np.where(preds != int(char_num - 1))[0]
+                    if len(valid_ind) == 0:
+                        continue
+                    score = np.mean(probs[valid_ind, ind[valid_ind]])
+                    preds = preds[:valid_ind[-1] + 1]
+                    preds_text = self.char_ops.decode(preds)
+
+                    rec_res[indices[beg_img_no + ino]] = [preds_text, score]
             else:
                 rec_idx_batch = self.output_tensors[0].copy_to_cpu()
                 predict_batch = self.output_tensors[1].copy_to_cpu()
@@ -162,6 +300,7 @@ def main(args):
             continue
         valid_image_file_list.append(image_file)
         img_list.append(img)
+
     try:
         rec_res, predict_time = text_recognizer(img_list)
     except Exception as e:
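The two biases built by srn_other_inputs above are complementary triangular masks: gsrm_slf_attn_bias1 places -1e9 on every position after the current one and gsrm_slf_attn_bias2 on every position before it, so once the biases are added to the attention logits, each GSRM stream sees tokens from one direction only. A self-contained demonstration (not part of the patch), shrunk to max_text_length=4 and num_heads=2 for readability:

import numpy as np

max_text_length, num_heads = 4, 2
bias_data = np.ones((1, max_text_length, max_text_length))

# Strict upper triangle -> masks "future" positions.
bias1 = np.triu(bias_data, 1).reshape([-1, 1, max_text_length, max_text_length])
bias1 = np.tile(bias1, [1, num_heads, 1, 1]).astype('float32') * [-1e9]

# Strict lower triangle -> masks "past" positions.
bias2 = np.tril(bias_data, -1).reshape([-1, 1, max_text_length, max_text_length])
bias2 = np.tile(bias2, [1, num_heads, 1, 1]).astype('float32') * [-1e9]

print(bias1[0, 0])
# [[ 0.e+00 -1.e+09 -1.e+09 -1.e+09]
#  [ 0.e+00  0.e+00 -1.e+09 -1.e+09]
#  [ 0.e+00  0.e+00  0.e+00 -1.e+09]
#  [ 0.e+00  0.e+00  0.e+00  0.e+00]]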
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index 392bc4df..e4e2a9fc 100755
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -59,10 +59,10 @@ def parse_args():
     parser.add_argument("--det_sast_polygon", type=bool, default=False)
 
     #params for text recognizer
-    parser.add_argument("--rec_algorithm", type=str, default='CRNN')
+    parser.add_argument("--rec_algorithm", type=str, default='SRN')
     parser.add_argument("--rec_model_dir", type=str)
-    parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
-    parser.add_argument("--rec_char_type", type=str, default='ch')
+    parser.add_argument("--rec_image_shape", type=str, default="3, 64, 256")
+    parser.add_argument("--rec_char_type", type=str, default='en')
     parser.add_argument("--rec_batch_num", type=int, default=30)
     parser.add_argument("--max_text_length", type=int, default=25)
     parser.add_argument(
@@ -107,10 +107,13 @@ def create_predictor(args, mode):
 
         # use zero copy
         config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
-        config.switch_use_feed_fetch_ops(False)
+        #config.switch_use_feed_fetch_ops(False)
+        config.switch_use_feed_fetch_ops(True)
     predictor = create_paddle_predictor(config)
     input_names = predictor.get_input_names()
-    input_tensor = predictor.get_input_tensor(input_names[0])
+    print(input_names)
+    for name in input_names:
+        input_tensor = predictor.get_input_tensor(name)
     output_names = predictor.get_output_names()
     output_tensors = []
     for output_name in output_names:
diff --git a/tools/infer_rec.py b/tools/infer_rec.py
index fd70cd66..e4b75e86 100755
--- a/tools/infer_rec.py
+++ b/tools/infer_rec.py
@@ -145,7 +145,7 @@ def main():
                 preds = preds.reshape(-1)
                 probs = np.array(predict[1])
                 ind = np.argmax(probs, axis=1)
-                valid_ind = np.where(preds != int(char_num-1))[0]
+                valid_ind = np.where(preds != int(char_num - 1))[0]
                 if len(valid_ind) == 0:
                     continue
                 score = np.mean(probs[valid_ind, ind[valid_ind]])
@@ -162,7 +162,10 @@ def main():
 
         fluid.io.save_inference_model(
             "./output/",
-            feeded_var_names=['image'],
+            feeded_var_names=[
+                'image', 'encoder_word_pos', 'gsrm_slf_attn_bias1',
+                'gsrm_slf_attn_bias2', 'gsrm_word_pos'
+            ],
             target_vars=target_var,
             executor=exe,
             main_program=eval_prog,
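Because create_predictor above now keeps feed/fetch ops enabled, predict_rec.py drives the predictor with predictor.run(inputs), where the five PaddleTensors are matched to feeds positionally; their order therefore has to agree with the feeded_var_names list passed to save_inference_model above ('image', 'encoder_word_pos', 'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2', 'gsrm_word_pos'). A hypothetical sanity check (not part of the patch) that could sit right after the predictor is created, assuming get_input_names() returns the feeds in export order:

# Hypothetical check, e.g. after utility.create_predictor(args, 'rec'):
expected_feeds = [
    'image', 'encoder_word_pos', 'gsrm_slf_attn_bias1',
    'gsrm_slf_attn_bias2', 'gsrm_word_pos'
]
assert predictor.get_input_names() == expected_feeds, \
    "unexpected SRN feed order: %s" % predictor.get_input_names()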
diff --git a/tools/program.py b/tools/program.py
index 6d8b9937..09552f41 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -208,10 +208,19 @@ def build_export(config, main_prog, startup_prog):
     with fluid.unique_name.guard():
         func_infor = config['Architecture']['function']
         model = create_module(func_infor)(params=config)
-        image, outputs = model(mode='export')
+        loss_type = config['Global']['loss_type']
+        if loss_type == "srn":
+            image, others, outputs = model(mode='export')
+        else:
+            image, outputs = model(mode='export')
         fetches_var_name = sorted([name for name in outputs.keys()])
         fetches_var = [outputs[name] for name in fetches_var_name]
-    feeded_var_names = [image.name]
+    if loss_type == "srn":
+        others_var_names = sorted([name for name in others.keys()])
+        feeded_var_names = [image.name] + others_var_names
+    else:
+        feeded_var_names = [image.name]
+
     target_vars = fetches_var
     return feeded_var_names, target_vars, fetches_var_name
@@ -409,7 +418,9 @@ def preprocess():
     check_gpu(use_gpu)
 
     alg = config['Global']['algorithm']
-    assert alg in ['EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']
+    assert alg in [
+        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN'
+    ]
     if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']:
         config['Global']['char_ops'] = CharacterOps(config['Global'])
-- 
GitLab
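With all five files patched, SRN recognition can be exercised end to end: export an inference model with tools/infer_rec.py, then point tools/infer/predict_rec.py at it. An illustrative invocation (the image and model paths are placeholders; --rec_algorithm, --rec_image_shape and --rec_char_type already default to the SRN settings set in utility.py above):

python3 tools/infer/predict_rec.py \
    --image_dir="./doc/imgs_words_en/word_336.png" \
    --rec_model_dir="./inference/srn/"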