From b0d1dca68800729511999b6dcee3a8ab2c4e4952 Mon Sep 17 00:00:00 2001 From: zhoujun Date: Thu, 28 Jan 2021 18:16:54 +0800 Subject: [PATCH] fix starnet export (#1850) * fix starnet export * fix bug * add note --- tools/export_model.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tools/export_model.py b/tools/export_model.py index b7d61a59..a9b9e7dd 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -52,9 +52,16 @@ def main(): save_path = '{}/inference'.format(config['Global']['save_inference_dir']) - infer_shape = [3, -1, -1] + infer_shape = [3, -1, -1] if config['Architecture']['model_type'] == "rec": - infer_shape = [3, 32, -1] + infer_shape = [3, 32, -1] # for rec model, H must be 32 + if 'Transform' in config['Architecture'] and config['Architecture'][ + 'Transform'] is not None and config['Architecture'][ + 'Transform']['name'] == 'TPS': + logger.info( + 'When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training' + ) + infer_shape[-1] = 100 model = to_static( model, -- GitLab