未验证 提交 b0d1dca6 编写于 作者: Z zhoujun 提交者: GitHub

fix starnet export (#1850)

* fix starnet export

* fix bug

* add note
上级 25bf9229
...@@ -52,9 +52,16 @@ def main(): ...@@ -52,9 +52,16 @@ def main():
save_path = '{}/inference'.format(config['Global']['save_inference_dir']) save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
infer_shape = [3, -1, -1] infer_shape = [3, -1, -1]
if config['Architecture']['model_type'] == "rec": if config['Architecture']['model_type'] == "rec":
infer_shape = [3, 32, -1] infer_shape = [3, 32, -1] # for rec model, H must be 32
if 'Transform' in config['Architecture'] and config['Architecture'][
'Transform'] is not None and config['Architecture'][
'Transform']['name'] == 'TPS':
logger.info(
'When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training'
)
infer_shape[-1] = 100
model = to_static( model = to_static(
model, model,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册