提交 6a38af58 编写于 作者: W WenmuZhou

fix starnet export

上级 25bf9229
......@@ -47,14 +47,18 @@ def main():
char_num = len(getattr(post_process_class, 'character'))
config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture'])
init_model(config, model, logger)
# init_model(config, model, logger)
model.eval()
save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
infer_shape = [3, -1, -1]
infer_shape = [3, -1, -1]
if config['Architecture']['model_type'] == "rec":
infer_shape = [3, 32, -1]
infer_shape = [3, 32, -1]
if 'Transform' in config['Architecture'] and config['Architecture'][
'Transform'] is not None and config['Architecture'][
'Transform']['name'] == 'TPS':
infer_shape[-1] = 100
model = to_static(
model,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册