From a2e932c0fb037360507dc8436b937388aeb18510 Mon Sep 17 00:00:00 2001 From: baiyfbupt Date: Mon, 26 Oct 2020 14:50:03 +0800 Subject: [PATCH] add cls model quant code --- deploy/slim/quantization/export_model.py | 3 +++ deploy/slim/quantization/quant.py | 5 ++++- ppocr/modeling/architectures/cls_model.py | 1 + 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/deploy/slim/quantization/export_model.py b/deploy/slim/quantization/export_model.py index d0d08b30..5f4d91a8 100644 --- a/deploy/slim/quantization/export_model.py +++ b/deploy/slim/quantization/export_model.py @@ -51,6 +51,7 @@ from paddleslim.quant import quant_aware, convert from paddle.fluid.layer_helper import LayerHelper from eval_utils.eval_det_utils import eval_det_run from eval_utils.eval_rec_utils import eval_rec_run +from eval_utils.eval_cls_utils import eval_cls_run def main(): @@ -105,6 +106,8 @@ def main(): if alg_type == 'det': final_metrics = eval_det_run(exe, config, quant_info_dict, "eval") + elif alg_type == 'cls': + final_metrics = eval_cls_run(exe, quant_info_dict) else: final_metrics = eval_rec_run(exe, config, quant_info_dict, "eval") print(final_metrics) diff --git a/deploy/slim/quantization/quant.py b/deploy/slim/quantization/quant.py index b1003ca9..e75e84be 100755 --- a/deploy/slim/quantization/quant.py +++ b/deploy/slim/quantization/quant.py @@ -178,9 +178,12 @@ def main(): if train_alg_type == 'det': program.train_eval_det_run( config, exe, train_info_dict, eval_info_dict, is_slim="quant") - else: + elif train_alg_type == 'rec': program.train_eval_rec_run( config, exe, train_info_dict, eval_info_dict, is_slim="quant") + else: + program.train_eval_cls_run( + config, exe, train_info_dict, eval_info_dict, is_slim="quant") if __name__ == '__main__': diff --git a/ppocr/modeling/architectures/cls_model.py b/ppocr/modeling/architectures/cls_model.py index ad3ad0e7..30f3661b 100755 --- a/ppocr/modeling/architectures/cls_model.py +++ b/ppocr/modeling/architectures/cls_model.py @@ -65,6 +65,7 @@ class ClsModel(object): labels = None loader = None image = fluid.data(name='image', shape=image_shape, dtype='float32') + image.stop_gradient = False return image, labels, loader def __call__(self, mode): -- GitLab