From be3a164424d51c9be84578a52560f2c615ff1d98 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Wed, 3 Jun 2020 15:49:18 +0800 Subject: [PATCH] fix inference in tps --- configs/rec/rec_mv3_tps_bilstm_attn.yml | 4 +++- configs/rec/rec_mv3_tps_bilstm_ctc.yml | 1 + ppocr/data/rec/dataset_traversal.py | 5 ++++- ppocr/data/rec/img_tools.py | 9 +++++++-- ppocr/modeling/architectures/rec_model.py | 14 +++++++++++++- tools/infer/predict_rec.py | 4 ++-- 6 files changed, 30 insertions(+), 7 deletions(-) diff --git a/configs/rec/rec_mv3_tps_bilstm_attn.yml b/configs/rec/rec_mv3_tps_bilstm_attn.yml index 792757b3..1787abc1 100755 --- a/configs/rec/rec_mv3_tps_bilstm_attn.yml +++ b/configs/rec/rec_mv3_tps_bilstm_attn.yml @@ -12,8 +12,10 @@ Global: test_batch_size_per_card: 256 image_shape: [3, 32, 100] max_text_length: 25 - character_type: en + character_type: ch + character_dict_path: ./ppocr/utils/ppocr_keys_v1.txt loss_type: attention + tps: true reader_yml: ./configs/rec/rec_benchmark_reader.yml pretrain_weights: checkpoints: diff --git a/configs/rec/rec_mv3_tps_bilstm_ctc.yml b/configs/rec/rec_mv3_tps_bilstm_ctc.yml index f15ddab7..2892452b 100755 --- a/configs/rec/rec_mv3_tps_bilstm_ctc.yml +++ b/configs/rec/rec_mv3_tps_bilstm_ctc.yml @@ -14,6 +14,7 @@ Global: max_text_length: 25 character_type: en loss_type: ctc + tps: true reader_yml: ./configs/rec/rec_benchmark_reader.yml pretrain_weights: checkpoints: diff --git a/ppocr/data/rec/dataset_traversal.py b/ppocr/data/rec/dataset_traversal.py index 2429dd20..8d6b4d1e 100755 --- a/ppocr/data/rec/dataset_traversal.py +++ b/ppocr/data/rec/dataset_traversal.py @@ -41,6 +41,8 @@ class LMDBReader(object): self.loss_type = params['loss_type'] self.max_text_length = params['max_text_length'] self.mode = params['mode'] + if "tps" in params: + self.tps = True if params['mode'] == 'train': self.batch_size = params['train_batch_size_per_card'] self.drop_last = params['drop_last'] @@ -109,7 +111,8 @@ class LMDBReader(object): norm_img = process_image( img=img, image_shape=self.image_shape, - char_ops=self.char_ops) + char_ops=self.char_ops, + tps=self.tps) yield norm_img else: lmdb_sets = self.load_hierarchical_lmdb_dataset() diff --git a/ppocr/data/rec/img_tools.py b/ppocr/data/rec/img_tools.py index 303e1fe3..71efea17 100755 --- a/ppocr/data/rec/img_tools.py +++ b/ppocr/data/rec/img_tools.py @@ -92,11 +92,16 @@ def process_image(img, label=None, char_ops=None, loss_type=None, - max_text_length=None): + max_text_length=None, + tps=None): if char_ops.character_type == "en": norm_img = resize_norm_img(img, image_shape) else: - norm_img = resize_norm_img_chinese(img, image_shape) + if tps: + image_shape = [3, 32, 320] + norm_img = resize_norm_img(img, image_shape) + else: + norm_img = resize_norm_img_chinese(img, image_shape) norm_img = norm_img[np.newaxis, :] if label is not None: char_num = char_ops.get_char_num() diff --git a/ppocr/modeling/architectures/rec_model.py b/ppocr/modeling/architectures/rec_model.py index fdc0a641..d729f5d9 100755 --- a/ppocr/modeling/architectures/rec_model.py +++ b/ppocr/modeling/architectures/rec_model.py @@ -30,6 +30,7 @@ class RecModel(object): global_params = params['Global'] char_num = global_params['char_ops'].get_char_num() global_params['char_num'] = char_num + self.char_type = global_params['character_type'] if "TPS" in params: tps_params = deepcopy(params["TPS"]) tps_params.update(global_params) @@ -60,8 +61,8 @@ class RecModel(object): def create_feed(self, mode): image_shape = deepcopy(self.image_shape) image_shape.insert(0, -1) - image = fluid.data(name='image', shape=image_shape, dtype='float32') if mode == "train": + image = fluid.data(name='image', shape=image_shape, dtype='float32') if self.loss_type == "attention": label_in = fluid.data( name='label_in', @@ -86,6 +87,17 @@ class RecModel(object): use_double_buffer=True, iterable=False) else: + if self.char_type == "ch": + image_shape[-1] = -1 + if self.tps != None: + logger.info( + "WARNRNG!!!\n" + "TPS does not support variable shape in chinese!" + "We set default shape=[3,32,320], it may affect the inference effect" + ) + image_shape[-1] = 320 + image = fluid.data( + name='image', shape=image_shape, dtype='float32') labels = None loader = None return image, labels, loader diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index 48553106..914fcde3 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -112,7 +112,7 @@ class TextRecognizer(object): else: preds = rec_idx_batch[rno, 1:end_pos[1]] score = np.mean(predict_batch[rno, 1:end_pos[1]]) - #todo: why index has 2 offset + #attenton index has 2 offset: beg and end preds = preds - 2 preds_text = self.char_ops.decode(preds) rec_res.append([preds_text, score]) @@ -138,7 +138,7 @@ if __name__ == "__main__": except: logger.info( "ERROR!! \nInput image shape is not equal with config. TPS does not support variable shape.\n" - "Please set --rec_image_shape=input_shape and --rec_char_type='ch' ") + "Please set --rec_image_shape=input_shape and --rec_char_type='en' ") exit() for ino in range(len(img_list)): print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino])) -- GitLab