diff --git a/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml b/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml index ab484a44833a405513d7f2b4079a4da4c2e403c8..bb6a196864b6e9e7525f2b5217f0c90ea2ca05a4 100644 --- a/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml +++ b/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml @@ -18,6 +18,7 @@ Global: Architecture: name: DistillationModel algorithm: Distillation + model_type: det Models: Teacher: freeze_params: true diff --git a/deploy/slim/quantization/export_model.py b/deploy/slim/quantization/export_model.py index 7c3abb3086aaf3cebdf940dd0e018fc44a3c05e2..0cb86108d2275dc6ee1a74e118c27b94131975d3 100755 --- a/deploy/slim/quantization/export_model.py +++ b/deploy/slim/quantization/export_model.py @@ -111,7 +111,7 @@ def main(): valid_dataloader = build_dataloader(config, 'Eval', device, logger) use_srn = config['Architecture']['algorithm'] == "SRN" - model_type = None + model_type = config['Architecture'].get('model_type', None) # start eval metric = program.eval(model, valid_dataloader, post_process_class, eval_class, model_type, use_srn) @@ -120,7 +120,7 @@ def main(): for k, v in metric.items(): logger.info('{}:{}'.format(k, v)) - infer_shape = [3, 32, 100] if model_type != "det" else [3, 640, 640] + infer_shape = [3, 32, 100] if model_type == "rec" else [3, 640, 640] save_path = config["Global"]["save_inference_dir"]