From ab0acf78b4fb2c23a685c5c3ce17ea78870118ec Mon Sep 17 00:00:00 2001 From: tink2123 Date: Sun, 16 Aug 2020 17:09:17 +0800 Subject: [PATCH] polish code --- ppocr/modeling/architectures/rec_model.py | 14 +++++++++----- ppocr/utils/character.py | 4 ++-- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/ppocr/modeling/architectures/rec_model.py b/ppocr/modeling/architectures/rec_model.py index f4e3eea2..91f778ce 100755 --- a/ppocr/modeling/architectures/rec_model.py +++ b/ppocr/modeling/architectures/rec_model.py @@ -98,13 +98,15 @@ class RecModel(object): shape=[ -1, self.num_heads, self.max_text_length, self.max_text_length - ]) + ], + dtype="float32") gsrm_slf_attn_bias2 = fluid.data( name="gsrm_slf_attn_bias2", shape=[ -1, self.num_heads, self.max_text_length, self.max_text_length - ]) + ], + dtype="float32") lbl_weight = fluid.layers.data( name="lbl_weight", shape=[-1, 1], dtype='int64') label = fluid.data( @@ -161,13 +163,15 @@ class RecModel(object): shape=[ -1, self.num_heads, self.max_text_length, self.max_text_length - ]) + ], + dtype="float32") gsrm_slf_attn_bias2 = fluid.data( name="gsrm_slf_attn_bias2", shape=[ -1, self.num_heads, self.max_text_length, self.max_text_length - ]) + ], + dtype="float32") feed_list = [ image, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2 @@ -214,7 +218,7 @@ class RecModel(object): if self.loss_type == "ctc": predict = fluid.layers.softmax(predict) if self.loss_type == "srn": - logger.infor( + raise Exception( "Warning! SRN does not support export model currently") return [image, {'decoded_out': decoded_out, 'predicts': predict}] else: diff --git a/ppocr/utils/character.py b/ppocr/utils/character.py index f27e1b85..2db0151e 100755 --- a/ppocr/utils/character.py +++ b/ppocr/utils/character.py @@ -26,12 +26,12 @@ class CharacterOps(object): self.character_type = config['character_type'] self.loss_type = config['loss_type'] self.max_text_len = config['max_text_length'] + if self.loss_type == "srn" and self.character_type == "ch": + raise Exception("SRN can only support in character_type == en") if self.character_type == "en": self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" dict_character = list(self.character_str) elif self.character_type == "ch": - if self.loss_type == "srn": - raise Exception("SRN can only support in character_type == en") character_dict_path = config['character_dict_path'] add_space = False if 'use_space_char' in config: -- GitLab