提交 b566bcbb 编写于 作者: W WenmuZhou

add max_text_length to export model

上级 38fc1fae
...@@ -47,23 +47,25 @@ def main(): ...@@ -47,23 +47,25 @@ def main():
char_num = len(getattr(post_process_class, 'character')) char_num = len(getattr(post_process_class, 'character'))
config['Architecture']["Head"]['out_channels'] = char_num config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
init_model(config, model, logger) # init_model(config, model, logger)
model.eval() model.eval()
save_path = '{}/inference'.format(config['Global']['save_inference_dir']) save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
if config['Architecture']['algorithm'] == "SRN": if config['Architecture']['algorithm'] == "SRN":
max_text_length = config['Architecture']['Head']['max_text_length']
other_shape = [ other_shape = [
paddle.static.InputSpec( paddle.static.InputSpec(
shape=[None, 1, 64, 256], dtype='float32'), [ shape=[None, 1, 64, 256], dtype='float32'), [
paddle.static.InputSpec( paddle.static.InputSpec(
shape=[None, 256, 1], shape=[None, 256, 1],
dtype="int64"), paddle.static.InputSpec( dtype="int64"), paddle.static.InputSpec(
shape=[None, 25, 1], shape=[None, max_text_length, 1], dtype="int64"),
dtype="int64"), paddle.static.InputSpec(
shape=[None, 8, 25, 25], dtype="int64"),
paddle.static.InputSpec( paddle.static.InputSpec(
shape=[None, 8, 25, 25], dtype="int64") shape=[None, 8, max_text_length, max_text_length],
dtype="int64"), paddle.static.InputSpec(
shape=[None, 8, max_text_length, max_text_length],
dtype="int64")
] ]
] ]
model = to_static(model, input_spec=other_shape) model = to_static(model, input_spec=other_shape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册