From a8cfdd3adcdabcd4b694da3b199a78f0511d3d5e Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Tue, 17 May 2022 19:28:31 +0800 Subject: [PATCH] fix PACT quant_aware abnormal training accuracy (#1111) --- paddleslim/quant/quanter.py | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/paddleslim/quant/quanter.py b/paddleslim/quant/quanter.py index 9668ed34..2faa9a07 100755 --- a/paddleslim/quant/quanter.py +++ b/paddleslim/quant/quanter.py @@ -95,6 +95,31 @@ _quant_config_default = { } +# TODO: Hard-code, remove it when Paddle 2.3.1 +class OutScaleForTrainingPassV2(OutScaleForTrainingPass): + def __init__(self, scope=None, place=None, moving_rate=0.9): + OutScaleForTrainingPass.__init__( + self, scope=scope, place=place, moving_rate=moving_rate) + + def _scale_name(self, var_name): + """ + Return the scale name for the var named `var_name`. + """ + return "%s@scale" % (var_name) + + +# TODO: Hard-code, remove it when Paddle 2.3.1 +class OutScaleForInferencePassV2(OutScaleForInferencePass): + def __init__(self, scope=None): + OutScaleForInferencePass.__init__(self, scope=scope) + + def _scale_name(self, var_name): + """ + Return the scale name for the var named `var_name`. + """ + return "%s@scale" % (var_name) + + def load_dict(): with open(VARS_MAPPING_TABLE, 'r') as file: data = file.read() @@ -298,7 +323,7 @@ def quant_aware(program, quantizable_op_type=quant_dequant_ops) quant_dequant_pass.apply(main_graph) - out_scale_training_pass = OutScaleForTrainingPass( + out_scale_training_pass = OutScaleForTrainingPassV2( scope=scope, place=place, moving_rate=config['moving_rate']) out_scale_training_pass.apply(main_graph) @@ -509,7 +534,7 @@ def convert(program, quant_weight_pass = QuantWeightPass(scope, place) quant_weight_pass.apply(test_graph) else: - out_scale_infer_pass = OutScaleForInferencePass(scope=scope) + out_scale_infer_pass = OutScaleForInferencePassV2(scope=scope) out_scale_infer_pass.apply(test_graph) # Freeze the graph after training by adjusting the quantize # operators' order for the inference. -- GitLab