diff --git a/deploy/slim/quantization/export_model.py b/deploy/slim/quantization/export_model.py index d0d08b300066044d3088f669045e0536006c3140..5f4d91a8eea2d4f63334be571620aa0e1c00fabd 100644 --- a/deploy/slim/quantization/export_model.py +++ b/deploy/slim/quantization/export_model.py @@ -51,6 +51,7 @@ from paddleslim.quant import quant_aware, convert from paddle.fluid.layer_helper import LayerHelper from eval_utils.eval_det_utils import eval_det_run from eval_utils.eval_rec_utils import eval_rec_run +from eval_utils.eval_cls_utils import eval_cls_run def main(): @@ -105,6 +106,8 @@ def main(): if alg_type == 'det': final_metrics = eval_det_run(exe, config, quant_info_dict, "eval") + elif alg_type == 'cls': + final_metrics = eval_cls_run(exe, quant_info_dict) else: final_metrics = eval_rec_run(exe, config, quant_info_dict, "eval") print(final_metrics) diff --git a/deploy/slim/quantization/quant.py b/deploy/slim/quantization/quant.py index b1003ca9c564d01af8eb547e31b366f5833c2a07..e75e84be86867274368b6a0427b3d02d2524e020 100755 --- a/deploy/slim/quantization/quant.py +++ b/deploy/slim/quantization/quant.py @@ -178,9 +178,12 @@ def main(): if train_alg_type == 'det': program.train_eval_det_run( config, exe, train_info_dict, eval_info_dict, is_slim="quant") - else: + elif train_alg_type == 'rec': program.train_eval_rec_run( config, exe, train_info_dict, eval_info_dict, is_slim="quant") + else: + program.train_eval_cls_run( + config, exe, train_info_dict, eval_info_dict, is_slim="quant") if __name__ == '__main__': diff --git a/ppocr/modeling/architectures/cls_model.py b/ppocr/modeling/architectures/cls_model.py index ad3ad0e7cf4010a14c70a700ed02d02ee1f1323b..30f3661b4c5ef4370b542f46dfaee2fdc1fc53e5 100755 --- a/ppocr/modeling/architectures/cls_model.py +++ b/ppocr/modeling/architectures/cls_model.py @@ -65,6 +65,7 @@ class ClsModel(object): labels = None loader = None image = fluid.data(name='image', shape=image_shape, dtype='float32') + image.stop_gradient = False return image, labels, loader def __call__(self, mode):