You need to sign in or sign up before continuing.
提交 6a38af58 编写于 作者: W WenmuZhou

fix starnet export

上级 25bf9229
...@@ -47,7 +47,7 @@ def main(): ...@@ -47,7 +47,7 @@ def main():
char_num = len(getattr(post_process_class, 'character')) char_num = len(getattr(post_process_class, 'character'))
config['Architecture']["Head"]['out_channels'] = char_num config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
init_model(config, model, logger) # init_model(config, model, logger)
model.eval() model.eval()
save_path = '{}/inference'.format(config['Global']['save_inference_dir']) save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
...@@ -55,6 +55,10 @@ def main(): ...@@ -55,6 +55,10 @@ def main():
infer_shape = [3, -1, -1] infer_shape = [3, -1, -1]
if config['Architecture']['model_type'] == "rec": if config['Architecture']['model_type'] == "rec":
infer_shape = [3, 32, -1] infer_shape = [3, 32, -1]
if 'Transform' in config['Architecture'] and config['Architecture'][
'Transform'] is not None and config['Architecture'][
'Transform']['name'] == 'TPS':
infer_shape[-1] = 100
model = to_static( model = to_static(
model, model,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册