From 5db482cd5c8735a136a32c153f718c8750f726b5 Mon Sep 17 00:00:00 2001 From: andyjpaddle Date: Thu, 9 Dec 2021 06:52:24 +0000 Subject: [PATCH] fix det slim export --- configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml | 1 + deploy/slim/quantization/export_model.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml b/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml index ab484a44..bb6a1968 100644 --- a/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml +++ b/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml @@ -18,6 +18,7 @@ Global: Architecture: name: DistillationModel algorithm: Distillation + model_type: det Models: Teacher: freeze_params: true diff --git a/deploy/slim/quantization/export_model.py b/deploy/slim/quantization/export_model.py index 7c3abb30..0cb86108 100755 --- a/deploy/slim/quantization/export_model.py +++ b/deploy/slim/quantization/export_model.py @@ -111,7 +111,7 @@ def main(): valid_dataloader = build_dataloader(config, 'Eval', device, logger) use_srn = config['Architecture']['algorithm'] == "SRN" - model_type = None + model_type = config['Architecture'].get('model_type', None) # start eval metric = program.eval(model, valid_dataloader, post_process_class, eval_class, model_type, use_srn) @@ -120,7 +120,7 @@ def main(): for k, v in metric.items(): logger.info('{}:{}'.format(k, v)) - infer_shape = [3, 32, 100] if model_type != "det" else [3, 640, 640] + infer_shape = [3, 32, 100] if model_type == "rec" else [3, 640, 640] save_path = config["Global"]["save_inference_dir"] -- GitLab