diff --git a/ppocr/modeling/architectures/rec_model.py b/ppocr/modeling/architectures/rec_model.py
index fe2d4c16dce3882980fe2238ecc16c7c08a89792..f2e24abd9236ec163474361f0a89b775bc57448e 100755
--- a/ppocr/modeling/architectures/rec_model.py
+++ b/ppocr/modeling/architectures/rec_model.py
@@ -136,7 +136,7 @@ class RecModel(object):
         else:
             labels = None
             loader = None
-        if self.char_type == "ch" and self.infer_img:
+        if self.char_type == "ch" and self.infer_img and self.loss_type != "srn":
             image_shape[-1] = -1
             if self.tps != None:
                 logger.info(
@@ -172,16 +172,13 @@ class RecModel(object):
                     self.max_text_length
                 ],
                 dtype="float32")
-            feed_list = [
-                image, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
-                gsrm_slf_attn_bias2
-            ]
             labels = {
                 'encoder_word_pos': encoder_word_pos,
                 'gsrm_word_pos': gsrm_word_pos,
                 'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1,
                 'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2
             }
+
         return image, labels, loader
 
     def __call__(self, mode):
@@ -218,8 +215,13 @@ class RecModel(object):
             if self.loss_type == "ctc":
                 predict = fluid.layers.softmax(predict)
             if self.loss_type == "srn":
-                raise Exception(
-                    "Warning! SRN does not support export model currently")
+                return [
+                    image, labels, {
+                        'decoded_out': decoded_out,
+                        'predicts': predict
+                    }
+                ]
+
             return [image, {'decoded_out': decoded_out, 'predicts': predict}]
         else:
             predict = predicts['predict']
diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py
index 6a379853a4a7d62cbffcbebbf09e2fb3e2207b27..06273e9f9e5b42a9ecc829c435662e9aabcdd224 100755
--- a/tools/infer/predict_rec.py
+++ b/tools/infer/predict_rec.py
@@ -40,6 +40,7 @@ class TextRecognizer(object):
         self.character_type = args.rec_char_type
         self.rec_batch_num = args.rec_batch_num
         self.rec_algorithm = args.rec_algorithm
+        self.text_len = args.max_text_length
         self.use_zero_copy_run = args.use_zero_copy_run
         char_ops_params = {
             "character_type": args.rec_char_type,
@@ -47,12 +48,15 @@ class TextRecognizer(object):
             "use_space_char": args.use_space_char,
             "max_text_length": args.max_text_length
         }
-        if self.rec_algorithm != "RARE":
+        if self.rec_algorithm in ["CRNN", "Rosetta", "STAR-Net"]:
             char_ops_params['loss_type'] = 'ctc'
             self.loss_type = 'ctc'
-        else:
+        elif self.rec_algorithm == "RARE":
             char_ops_params['loss_type'] = 'attention'
             self.loss_type = 'attention'
+        elif self.rec_algorithm == "SRN":
+            char_ops_params['loss_type'] = 'srn'
+            self.loss_type = 'srn'
         self.char_ops = CharacterOps(char_ops_params)
 
     def resize_norm_img(self, img, max_wh_ratio):
@@ -75,6 +79,83 @@ class TextRecognizer(object):
             padding_im[:, :, 0:resized_w] = resized_image
         return padding_im
 
+    def resize_norm_img_srn(self, img, image_shape):
+        imgC, imgH, imgW = image_shape
+
+        img_black = np.zeros((imgH, imgW))
+        im_hei = img.shape[0]
+        im_wid = img.shape[1]
+
+        if im_wid <= im_hei * 1:
+            img_new = cv2.resize(img, (imgH * 1, imgH))
+        elif im_wid <= im_hei * 2:
+            img_new = cv2.resize(img, (imgH * 2, imgH))
+        elif im_wid <= im_hei * 3:
+            img_new = cv2.resize(img, (imgH * 3, imgH))
+        else:
+            img_new = cv2.resize(img, (imgW, imgH))
+
+        img_np = np.asarray(img_new)
+        img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
+        img_black[:, 0:img_np.shape[1]] = img_np
+        img_black = img_black[:, :, np.newaxis]
+
+        row, col, c = img_black.shape
+        c = 1
+
+        return np.reshape(img_black, (c, row, col)).astype(np.float32)
+
+    def srn_other_inputs(self, image_shape, num_heads, max_text_length,
+                         char_num):
+
+        imgC, imgH, imgW = image_shape
+        feature_dim = int((imgH / 8) * (imgW / 8))
+
+        encoder_word_pos = np.array(range(0, feature_dim)).reshape(
+            (feature_dim, 1)).astype('int64')
+        gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
+            (max_text_length, 1)).astype('int64')
+
+        gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
+        gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
+            [-1, 1, max_text_length, max_text_length])
+        gsrm_slf_attn_bias1 = np.tile(
+            gsrm_slf_attn_bias1,
+            [1, num_heads, 1, 1]).astype('float32') * [-1e9]
+
+        gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
+            [-1, 1, max_text_length, max_text_length])
+        gsrm_slf_attn_bias2 = np.tile(
+            gsrm_slf_attn_bias2,
+            [1, num_heads, 1, 1]).astype('float32') * [-1e9]
+
+        encoder_word_pos = encoder_word_pos[np.newaxis, :]
+        gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
+
+        return [
+            encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
+            gsrm_slf_attn_bias2
+        ]
+
+    def process_image_srn(self,
+                          img,
+                          image_shape,
+                          num_heads,
+                          max_text_length,
+                          char_ops=None):
+        norm_img = self.resize_norm_img_srn(img, image_shape)
+        norm_img = norm_img[np.newaxis, :]
+        char_num = char_ops.get_char_num()
+
+        [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
+            self.srn_other_inputs(image_shape, num_heads, max_text_length, char_num)
+
+        gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
+        gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
+
+        return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
+                gsrm_slf_attn_bias2)
+
     def __call__(self, img_list):
         img_num = len(img_list)
         # Calculate the aspect ratio of all text bars
@@ -84,7 +165,7 @@ class TextRecognizer(object):
         # Sorting can speed up the recognition process
         indices = np.argsort(np.array(width_list))
 
-        # rec_res = []
+        #rec_res = []
         rec_res = [['', 0.0]] * img_num
         batch_num = self.rec_batch_num
         predict_time = 0
@@ -98,20 +179,62 @@ class TextRecognizer(object):
                 wh_ratio = w * 1.0 / h
                 max_wh_ratio = max(max_wh_ratio, wh_ratio)
             for ino in range(beg_img_no, end_img_no):
-                # norm_img = self.resize_norm_img(img_list[ino], max_wh_ratio)
-                norm_img = self.resize_norm_img(img_list[indices[ino]],
-                                                max_wh_ratio)
-                norm_img = norm_img[np.newaxis, :]
-                norm_img_batch.append(norm_img)
-            norm_img_batch = np.concatenate(norm_img_batch)
+                if self.loss_type != "srn":
+                    norm_img = self.resize_norm_img(img_list[indices[ino]],
+                                                    max_wh_ratio)
+                    norm_img = norm_img[np.newaxis, :]
+                    norm_img_batch.append(norm_img)
+                else:
+                    norm_img = self.process_image_srn(img_list[indices[ino]],
+                                                      self.rec_image_shape, 8,
+                                                      25, self.char_ops)
+                    encoder_word_pos_list = []
+                    gsrm_word_pos_list = []
+                    gsrm_slf_attn_bias1_list = []
+                    gsrm_slf_attn_bias2_list = []
+                    encoder_word_pos_list.append(norm_img[1])
+                    gsrm_word_pos_list.append(norm_img[2])
+                    gsrm_slf_attn_bias1_list.append(norm_img[3])
+                    gsrm_slf_attn_bias2_list.append(norm_img[4])
+                    norm_img_batch.append(norm_img[0])
+
+            norm_img_batch = np.concatenate(norm_img_batch, axis=0)
             norm_img_batch = norm_img_batch.copy()
-            starttime = time.time()
-            if self.use_zero_copy_run:
-                self.input_tensor.copy_from_cpu(norm_img_batch)
-                self.predictor.zero_copy_run()
-            else:
+
+            if self.loss_type == "srn":
+                starttime = time.time()
+                encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
+                gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
+                gsrm_slf_attn_bias1_list = np.concatenate(
+                    gsrm_slf_attn_bias1_list)
+                gsrm_slf_attn_bias2_list = np.concatenate(
+                    gsrm_slf_attn_bias2_list)
+                starttime = time.time()
+
                 norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
-                self.predictor.run([norm_img_batch])
+                encoder_word_pos_list = fluid.core.PaddleTensor(
+                    encoder_word_pos_list)
+                gsrm_word_pos_list = fluid.core.PaddleTensor(gsrm_word_pos_list)
+                gsrm_slf_attn_bias1_list = fluid.core.PaddleTensor(
+                    gsrm_slf_attn_bias1_list)
+                gsrm_slf_attn_bias2_list = fluid.core.PaddleTensor(
+                    gsrm_slf_attn_bias2_list)
+
+                inputs = [
+                    norm_img_batch, encoder_word_pos_list,
+                    gsrm_slf_attn_bias1_list, gsrm_slf_attn_bias2_list,
+                    gsrm_word_pos_list
+                ]
+
+                self.predictor.run(inputs)
+            else:
+                starttime = time.time()
+                if self.use_zero_copy_run:
+                    self.input_tensor.copy_from_cpu(norm_img_batch)
+                    self.predictor.zero_copy_run()
+                else:
+                    norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
+                    self.predictor.run([norm_img_batch])
 
             if self.loss_type == "ctc":
                 rec_idx_batch = self.output_tensors[0].copy_to_cpu()
@@ -136,6 +259,26 @@ class TextRecognizer(object):
                     score = np.mean(probs[valid_ind, ind[valid_ind]])
                     # rec_res.append([preds_text, score])
                     rec_res[indices[beg_img_no + rno]] = [preds_text, score]
+            elif self.loss_type == 'srn':
+                rec_idx_batch = self.output_tensors[0].copy_to_cpu()
+                probs = self.output_tensors[1].copy_to_cpu()
+                char_num = self.char_ops.get_char_num()
+                preds = rec_idx_batch.reshape(-1)
+                elapse = time.time() - starttime
+                predict_time += elapse
+                total_preds = preds.copy()
+                for ino in range(int(len(rec_idx_batch) / self.text_len)):
+                    preds = total_preds[ino * self.text_len:(ino + 1) *
+                                        self.text_len]
+                    ind = np.argmax(probs, axis=1)
+                    valid_ind = np.where(preds != int(char_num - 1))[0]
+                    if len(valid_ind) == 0:
+                        continue
+                    score = np.mean(probs[valid_ind, ind[valid_ind]])
+                    preds = preds[:valid_ind[-1] + 1]
+                    preds_text = self.char_ops.decode(preds)
+
+                    rec_res[indices[beg_img_no + ino]] = [preds_text, score]
             else:
                 rec_idx_batch = self.output_tensors[0].copy_to_cpu()
                 predict_batch = self.output_tensors[1].copy_to_cpu()
@@ -170,6 +313,7 @@ def main(args):
             continue
         valid_image_file_list.append(image_file)
         img_list.append(img)
+
     try:
         rec_res, predict_time = text_recognizer(img_list)
     except Exception as e:
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index 9d7ce13d37567ac80e194a6500a0f629ede4b1d4..8d66f8671815c53d00476033f772cb58069606c8 100755
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -114,7 +114,8 @@ def create_predictor(args, mode):
 
     predictor = create_paddle_predictor(config)
     input_names = predictor.get_input_names()
-    input_tensor = predictor.get_input_tensor(input_names[0])
+    for name in input_names:
+        input_tensor = predictor.get_input_tensor(name)
     output_names = predictor.get_output_names()
     output_tensors = []
     for output_name in output_names:
diff --git a/tools/infer_rec.py b/tools/infer_rec.py
index fd70cd66dccc2cb755efbd10c4d16c9f7a97146d..29fc5b40a890cd6e8fa3ca7d3f0999835555d9bd 100755
--- a/tools/infer_rec.py
+++ b/tools/infer_rec.py
@@ -145,7 +145,7 @@ def main():
                 preds = preds.reshape(-1)
                 probs = np.array(predict[1])
                 ind = np.argmax(probs, axis=1)
-                valid_ind = np.where(preds != int(char_num-1))[0]
+                valid_ind = np.where(preds != int(char_num - 1))[0]
                 if len(valid_ind) == 0:
                     continue
                 score = np.mean(probs[valid_ind, ind[valid_ind]])
diff --git a/tools/program.py b/tools/program.py
index 6d8b9937bff7e70f018b069467525669e7001aae..56f6b6993022d092095d7c3545f9ea8833900bc6 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -208,10 +208,19 @@ def build_export(config, main_prog, startup_prog):
     with fluid.unique_name.guard():
         func_infor = config['Architecture']['function']
         model = create_module(func_infor)(params=config)
-        image, outputs = model(mode='export')
+        algorithm = config['Global']['algorithm']
+        if algorithm == "SRN":
+            image, others, outputs = model(mode='export')
+        else:
+            image, outputs = model(mode='export')
         fetches_var_name = sorted([name for name in outputs.keys()])
         fetches_var = [outputs[name] for name in fetches_var_name]
-        feeded_var_names = [image.name]
+        if algorithm == "SRN":
+            others_var_names = sorted([name for name in others.keys()])
+            feeded_var_names = [image.name] + others_var_names
+        else:
+            feeded_var_names = [image.name]
+
         target_vars = fetches_var
     return feeded_var_names, target_vars, fetches_var_name
 
@@ -409,7 +418,9 @@ def preprocess():
     check_gpu(use_gpu)
 
     alg = config['Global']['algorithm']
-    assert alg in ['EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']
+    assert alg in [
+        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN'
+    ]
     if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']:
         config['Global']['char_ops'] = CharacterOps(config['Global'])
 
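--
For reviewers: a minimal, standalone sketch (not part of the patch) of the
shape bookkeeping that srn_other_inputs() performs, assuming the default SRN
input shape [1, 64, 256] and the constants hard-coded into __call__ above
(num_heads=8, max_text_length=25). The values follow directly from the added
code; the script itself is illustrative only.

    import numpy as np

    # Default SRN input shape and the constants passed to process_image_srn.
    imgC, imgH, imgW = 1, 64, 256
    num_heads, max_text_length = 8, 25

    # The visual encoder downsamples by 8 in each dimension, so the number
    # of positional slots is (64 / 8) * (256 / 8) = 8 * 32 = 256.
    feature_dim = int((imgH / 8) * (imgW / 8))

    # Position indices for the encoder and the GSRM word stream, with the
    # batch axis prepended exactly as in srn_other_inputs().
    encoder_word_pos = np.arange(feature_dim).reshape(
        (feature_dim, 1)).astype('int64')[np.newaxis, :]      # (1, 256, 1)
    gsrm_word_pos = np.arange(max_text_length).reshape(
        (max_text_length, 1)).astype('int64')[np.newaxis, :]  # (1, 25, 1)

    # Self-attention masks: bias1 zeroes out the upper triangle (future
    # positions), bias2 the lower triangle (past positions); scaling by
    # -1e9 makes the masked logits vanish after softmax.
    bias = np.ones((1, max_text_length, max_text_length))
    gsrm_slf_attn_bias1 = np.tile(
        np.triu(bias, 1).reshape([-1, 1, max_text_length, max_text_length]),
        [1, num_heads, 1, 1]).astype('float32') * -1e9        # (1, 8, 25, 25)
    gsrm_slf_attn_bias2 = np.tile(
        np.tril(bias, -1).reshape([-1, 1, max_text_length, max_text_length]),
        [1, num_heads, 1, 1]).astype('float32') * -1e9        # (1, 8, 25, 25)

    print(encoder_word_pos.shape, gsrm_word_pos.shape,
          gsrm_slf_attn_bias1.shape, gsrm_slf_attn_bias2.shape)
    # -> (1, 256, 1) (1, 25, 1) (1, 8, 25, 25) (1, 8, 25, 25)

These are the four extra feeds that predict_rec.py now concatenates per batch
and passes to the predictor alongside the image tensor, in the order
[norm_img_batch, encoder_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2,
gsrm_word_pos].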