未验证 提交 a8cfdd3a 编写于 作者: G Guanghua Yu 提交者: GitHub

fix PACT quant_aware abnormal training accuracy (#1111)

上级 0bb77724
......@@ -95,6 +95,31 @@ _quant_config_default = {
}
# TODO: Hard-code, remove it when Paddle 2.3.1
class OutScaleForTrainingPassV2(OutScaleForTrainingPass):
    """Training-time out-scale pass that overrides the scale-variable
    naming scheme to use the ``<var_name>@scale`` suffix.

    Behaves exactly like ``OutScaleForTrainingPass`` except for
    ``_scale_name``; kept as a temporary shim (see TODO above).
    """

    def __init__(self, scope=None, place=None, moving_rate=0.9):
        # Delegate straight to the parent pass; no extra state is kept here.
        super().__init__(scope=scope, place=place, moving_rate=moving_rate)

    def _scale_name(self, var_name):
        """Return the scale name for the var named `var_name`."""
        return f"{var_name}@scale"
# TODO: Hard-code, remove it when Paddle 2.3.1
class OutScaleForInferencePassV2(OutScaleForInferencePass):
    """Inference-time counterpart of ``OutScaleForTrainingPassV2``:
    overrides ``_scale_name`` so the ``<var_name>@scale`` naming matches
    what the training pass recorded. Temporary shim (see TODO above).
    """

    def __init__(self, scope=None):
        # No additional state; forward construction to the parent pass.
        super().__init__(scope=scope)

    def _scale_name(self, var_name):
        """Return the scale name for the var named `var_name`."""
        return f"{var_name}@scale"
def load_dict():
with open(VARS_MAPPING_TABLE, 'r') as file:
data = file.read()
......@@ -298,7 +323,7 @@ def quant_aware(program,
quantizable_op_type=quant_dequant_ops)
quant_dequant_pass.apply(main_graph)
out_scale_training_pass = OutScaleForTrainingPass(
out_scale_training_pass = OutScaleForTrainingPassV2(
scope=scope, place=place, moving_rate=config['moving_rate'])
out_scale_training_pass.apply(main_graph)
......@@ -509,7 +534,7 @@ def convert(program,
quant_weight_pass = QuantWeightPass(scope, place)
quant_weight_pass.apply(test_graph)
else:
out_scale_infer_pass = OutScaleForInferencePass(scope=scope)
out_scale_infer_pass = OutScaleForInferencePassV2(scope=scope)
out_scale_infer_pass.apply(test_graph)
# Freeze the graph after training by adjusting the quantize
# operators' order for the inference.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册