diff --git a/ppocr/modeling/architectures/rec_model.py b/ppocr/modeling/architectures/rec_model.py index f4e3eea22750fa6a50a33b3d70a2f7f88d865288..91f778ce97768bc8079affb0e974d10da6c195d4 100755 --- a/ppocr/modeling/architectures/rec_model.py +++ b/ppocr/modeling/architectures/rec_model.py @@ -98,13 +98,15 @@ class RecModel(object): shape=[ -1, self.num_heads, self.max_text_length, self.max_text_length - ]) + ], + dtype="float32") gsrm_slf_attn_bias2 = fluid.data( name="gsrm_slf_attn_bias2", shape=[ -1, self.num_heads, self.max_text_length, self.max_text_length - ]) + ], + dtype="float32") lbl_weight = fluid.layers.data( name="lbl_weight", shape=[-1, 1], dtype='int64') label = fluid.data( @@ -161,13 +163,15 @@ class RecModel(object): shape=[ -1, self.num_heads, self.max_text_length, self.max_text_length - ]) + ], + dtype="float32") gsrm_slf_attn_bias2 = fluid.data( name="gsrm_slf_attn_bias2", shape=[ -1, self.num_heads, self.max_text_length, self.max_text_length - ]) + ], + dtype="float32") feed_list = [ image, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2 @@ -214,7 +218,7 @@ class RecModel(object): if self.loss_type == "ctc": predict = fluid.layers.softmax(predict) if self.loss_type == "srn": - logger.infor( + raise Exception( "Warning! SRN does not support export model currently") return [image, {'decoded_out': decoded_out, 'predicts': predict}] else: diff --git a/ppocr/utils/character.py b/ppocr/utils/character.py index f27e1b85877fd36fe97d54cc529c313f86ff6787..2db0151e62ced347cfbda2d8be9d49e1e6b40690 100755 --- a/ppocr/utils/character.py +++ b/ppocr/utils/character.py @@ -26,12 +26,12 @@ class CharacterOps(object): self.character_type = config['character_type'] self.loss_type = config['loss_type'] self.max_text_len = config['max_text_length'] + if self.loss_type == "srn" and self.character_type == "ch": + raise Exception("SRN can only support in character_type == en") if self.character_type == "en": self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" dict_character = list(self.character_str) elif self.character_type == "ch": - if self.loss_type == "srn": - raise Exception("SRN can only support in character_type == en") character_dict_path = config['character_dict_path'] add_space = False if 'use_space_char' in config: