未验证 提交 a8cfdd3a 编写于 作者: G Guanghua Yu 提交者: GitHub

fix PACT quant_aware abnormal training accuracy (#1111)

上级 0bb77724
...@@ -95,6 +95,31 @@ _quant_config_default = { ...@@ -95,6 +95,31 @@ _quant_config_default = {
} }
# TODO: Hard-code, remove it when Paddle 2.3.1
class OutScaleForTrainingPassV2(OutScaleForTrainingPass):
def __init__(self, scope=None, place=None, moving_rate=0.9):
OutScaleForTrainingPass.__init__(
self, scope=scope, place=place, moving_rate=moving_rate)
def _scale_name(self, var_name):
"""
Return the scale name for the var named `var_name`.
"""
return "%s@scale" % (var_name)
# TODO: Hard-code, remove it when Paddle 2.3.1
class OutScaleForInferencePassV2(OutScaleForInferencePass):
def __init__(self, scope=None):
OutScaleForInferencePass.__init__(self, scope=scope)
def _scale_name(self, var_name):
"""
Return the scale name for the var named `var_name`.
"""
return "%s@scale" % (var_name)
def load_dict(): def load_dict():
with open(VARS_MAPPING_TABLE, 'r') as file: with open(VARS_MAPPING_TABLE, 'r') as file:
data = file.read() data = file.read()
...@@ -298,7 +323,7 @@ def quant_aware(program, ...@@ -298,7 +323,7 @@ def quant_aware(program,
quantizable_op_type=quant_dequant_ops) quantizable_op_type=quant_dequant_ops)
quant_dequant_pass.apply(main_graph) quant_dequant_pass.apply(main_graph)
out_scale_training_pass = OutScaleForTrainingPass( out_scale_training_pass = OutScaleForTrainingPassV2(
scope=scope, place=place, moving_rate=config['moving_rate']) scope=scope, place=place, moving_rate=config['moving_rate'])
out_scale_training_pass.apply(main_graph) out_scale_training_pass.apply(main_graph)
...@@ -509,7 +534,7 @@ def convert(program, ...@@ -509,7 +534,7 @@ def convert(program,
quant_weight_pass = QuantWeightPass(scope, place) quant_weight_pass = QuantWeightPass(scope, place)
quant_weight_pass.apply(test_graph) quant_weight_pass.apply(test_graph)
else: else:
out_scale_infer_pass = OutScaleForInferencePass(scope=scope) out_scale_infer_pass = OutScaleForInferencePassV2(scope=scope)
out_scale_infer_pass.apply(test_graph) out_scale_infer_pass.apply(test_graph)
# Freeze the graph after training by adjusting the quantize # Freeze the graph after training by adjusting the quantize
# operators' order for the inference. # operators' order for the inference.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册