diff --git a/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml b/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml index ab484a44833a405513d7f2b4079a4da4c2e403c8..bb6a196864b6e9e7525f2b5217f0c90ea2ca05a4 100644 --- a/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml +++ b/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml @@ -18,6 +18,7 @@ Global: Architecture: name: DistillationModel algorithm: Distillation + model_type: det Models: Teacher: freeze_params: true diff --git a/deploy/slim/quantization/export_model.py b/deploy/slim/quantization/export_model.py index dddae923de223178665e3bfb55a2e7a8c0d5ba17..0cb86108d2275dc6ee1a74e118c27b94131975d3 100755 --- a/deploy/slim/quantization/export_model.py +++ b/deploy/slim/quantization/export_model.py @@ -111,7 +111,7 @@ def main(): valid_dataloader = build_dataloader(config, 'Eval', device, logger) use_srn = config['Architecture']['algorithm'] == "SRN" - model_type = config['Architecture']['model_type'] + model_type = config['Architecture'].get('model_type', None) # start eval metric = program.eval(model, valid_dataloader, post_process_class, eval_class, model_type, use_srn) @@ -120,8 +120,7 @@ def main(): for k, v in metric.items(): logger.info('{}:{}'.format(k, v)) - infer_shape = [3, 32, 100] if config['Architecture'][ - 'model_type'] != "det" else [3, 640, 640] + infer_shape = [3, 32, 100] if model_type == "rec" else [3, 640, 640] save_path = config["Global"]["save_inference_dir"] diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_det/train_infer_python.txt b/test_tipc/configs/ch_ppocr_server_v2.0_det/train_infer_python.txt index d62de35a9d40af10ae9f3e89da6f92f9b18c991c..ca52eeb1bc6a1853fa7015478fb9028d8dec71c3 100644 --- a/test_tipc/configs/ch_ppocr_server_v2.0_det/train_infer_python.txt +++ b/test_tipc/configs/ch_ppocr_server_v2.0_det/train_infer_python.txt @@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/ null:null ## trainer:norm_train -norm_train:tools/train.py -c test_tipc/configs/ppocr_det_server/det_r50_vd_db.yml -o +norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o quant_train:null fpgm_train:null distill_train:null @@ -21,13 +21,13 @@ null:null null:null ## ===========================eval_params=========================== -eval:tools/eval.py -c test_tipc/configs/ppocr_det_server/det_r50_vd_db.yml -o +eval:tools/eval.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o null:null ## ===========================infer_params=========================== Global.save_inference_dir:./output/ Global.pretrained_model: -norm_export:tools/export_model.py -c test_tipc/configs/ppocr_det_server/det_r50_vd_db.yml -o +norm_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o quant_export:null fpgm_export:null distill_export:null