提交 ab0acf78 编写于 作者: T tink2123

polish code

上级 fe8ce9af
......@@ -98,13 +98,15 @@ class RecModel(object):
shape=[
-1, self.num_heads, self.max_text_length,
self.max_text_length
])
],
dtype="float32")
gsrm_slf_attn_bias2 = fluid.data(
name="gsrm_slf_attn_bias2",
shape=[
-1, self.num_heads, self.max_text_length,
self.max_text_length
])
],
dtype="float32")
lbl_weight = fluid.layers.data(
name="lbl_weight", shape=[-1, 1], dtype='int64')
label = fluid.data(
......@@ -161,13 +163,15 @@ class RecModel(object):
shape=[
-1, self.num_heads, self.max_text_length,
self.max_text_length
])
],
dtype="float32")
gsrm_slf_attn_bias2 = fluid.data(
name="gsrm_slf_attn_bias2",
shape=[
-1, self.num_heads, self.max_text_length,
self.max_text_length
])
],
dtype="float32")
feed_list = [
image, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2
......@@ -214,7 +218,7 @@ class RecModel(object):
if self.loss_type == "ctc":
predict = fluid.layers.softmax(predict)
if self.loss_type == "srn":
logger.infor(
raise Exception(
"Warning! SRN does not support export model currently")
return [image, {'decoded_out': decoded_out, 'predicts': predict}]
else:
......
......@@ -26,12 +26,12 @@ class CharacterOps(object):
self.character_type = config['character_type']
self.loss_type = config['loss_type']
self.max_text_len = config['max_text_length']
if self.loss_type == "srn" and self.character_type == "ch":
raise Exception("SRN can only support in character_type == en")
if self.character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
elif self.character_type == "ch":
if self.loss_type == "srn":
raise Exception("SRN can only support in character_type == en")
character_dict_path = config['character_dict_path']
add_space = False
if 'use_space_char' in config:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册