提交 ab0acf78 编写于 作者: T tink2123

polish code

上级 fe8ce9af
...@@ -98,13 +98,15 @@ class RecModel(object): ...@@ -98,13 +98,15 @@ class RecModel(object):
shape=[ shape=[
-1, self.num_heads, self.max_text_length, -1, self.num_heads, self.max_text_length,
self.max_text_length self.max_text_length
]) ],
dtype="float32")
gsrm_slf_attn_bias2 = fluid.data( gsrm_slf_attn_bias2 = fluid.data(
name="gsrm_slf_attn_bias2", name="gsrm_slf_attn_bias2",
shape=[ shape=[
-1, self.num_heads, self.max_text_length, -1, self.num_heads, self.max_text_length,
self.max_text_length self.max_text_length
]) ],
dtype="float32")
lbl_weight = fluid.layers.data( lbl_weight = fluid.layers.data(
name="lbl_weight", shape=[-1, 1], dtype='int64') name="lbl_weight", shape=[-1, 1], dtype='int64')
label = fluid.data( label = fluid.data(
...@@ -161,13 +163,15 @@ class RecModel(object): ...@@ -161,13 +163,15 @@ class RecModel(object):
shape=[ shape=[
-1, self.num_heads, self.max_text_length, -1, self.num_heads, self.max_text_length,
self.max_text_length self.max_text_length
]) ],
dtype="float32")
gsrm_slf_attn_bias2 = fluid.data( gsrm_slf_attn_bias2 = fluid.data(
name="gsrm_slf_attn_bias2", name="gsrm_slf_attn_bias2",
shape=[ shape=[
-1, self.num_heads, self.max_text_length, -1, self.num_heads, self.max_text_length,
self.max_text_length self.max_text_length
]) ],
dtype="float32")
feed_list = [ feed_list = [
image, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, image, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2 gsrm_slf_attn_bias2
...@@ -214,7 +218,7 @@ class RecModel(object): ...@@ -214,7 +218,7 @@ class RecModel(object):
if self.loss_type == "ctc": if self.loss_type == "ctc":
predict = fluid.layers.softmax(predict) predict = fluid.layers.softmax(predict)
if self.loss_type == "srn": if self.loss_type == "srn":
logger.infor( raise Exception(
"Warning! SRN does not support export model currently") "Warning! SRN does not support export model currently")
return [image, {'decoded_out': decoded_out, 'predicts': predict}] return [image, {'decoded_out': decoded_out, 'predicts': predict}]
else: else:
......
...@@ -26,12 +26,12 @@ class CharacterOps(object): ...@@ -26,12 +26,12 @@ class CharacterOps(object):
self.character_type = config['character_type'] self.character_type = config['character_type']
self.loss_type = config['loss_type'] self.loss_type = config['loss_type']
self.max_text_len = config['max_text_length'] self.max_text_len = config['max_text_length']
if self.loss_type == "srn" and self.character_type == "ch":
raise Exception("SRN can only support in character_type == en")
if self.character_type == "en": if self.character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str) dict_character = list(self.character_str)
elif self.character_type == "ch": elif self.character_type == "ch":
if self.loss_type == "srn":
raise Exception("SRN can only support in character_type == en")
character_dict_path = config['character_dict_path'] character_dict_path = config['character_dict_path']
add_space = False add_space = False
if 'use_space_char' in config: if 'use_space_char' in config:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册