Unverified commit 9d8730f6, authored by Liufang Sang, committed by GitHub

add out scale for quantization (#272)

* add out scale for quantization

* update quant_aware unittest

* update quant_aware unittest
Parent f3891c1c
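For context, a minimal sketch of the user-facing flow this commit changes, mirroring the updated unittest in the last hunk below; `main_prog`, `val_prog`, and the training loop are assumed to already exist, and the config values are illustrative, not mandated by this diff:

import paddle.fluid as fluid
from paddleslim.quant import quant_aware, convert

place = fluid.CPUPlace()
config = {'weight_quantize_type': 'channel_wise_abs_max',
          'activation_quantize_type': 'moving_average_abs_max'}

# quant_aware() now also applies OutScaleForTrainingPass, so output
# scales are collected while the quantized program trains.
quant_train_prog = quant_aware(main_prog, place, config, for_test=False)
quant_eval_prog = quant_aware(val_prog, place, config, for_test=True)

# ... run quant-aware training on quant_train_prog ...

# convert() now applies OutScaleForInferencePass before freezing, and
# with save_int8=True returns a float program plus an int8 program.
quant_eval_prog, int8_prog = convert(
    quant_eval_prog, place, config, save_int8=True)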
@@ -25,6 +25,6 @@ try:
 except Exception as e:
     _logger.warning(
         "If you want to use training-aware and post-training quantization, "
-        "please use Paddle >= 1.7.0 or develop version")
+        "please use Paddle >= 2.0.0 or develop version")
 
 from .quant_embedding import quant_embedding
@@ -24,6 +24,8 @@ from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass
 from paddle.fluid.contrib.slim.quantization import TransformForMobilePass
 from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
 from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
+from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass
+from paddle.fluid.contrib.slim.quantization import OutScaleForInferencePass
 from paddle.fluid import core
 from paddle.fluid.contrib.slim.quantization import WeightQuantization
@@ -220,6 +222,10 @@ def quant_aware(program, place, config=None, scope=None, for_test=False):
             quantizable_op_type=quant_dequant_ops)
         quant_dequant_pass.apply(main_graph)
 
+    out_scale_training_pass = OutScaleForTrainingPass(
+        scope=scope, place=place, moving_rate=config['moving_rate'])
+    out_scale_training_pass.apply(main_graph)
+
     if for_test:
         quant_program = main_graph.to_program()
     else:
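At the graph level, the new lines in quant_aware amount to the following; a sketch assuming an already-built static `program` and a `place` (Paddle 1.8/2.0-era fluid API; moving_rate defaults to 0.9 in PaddleSlim's quant config):

import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass

main_graph = IrGraph(core.Graph(program.desc), for_test=False)

# Insert ops that track a moving average of each quantizable op's
# output scale while the graph trains.
out_scale_training_pass = OutScaleForTrainingPass(
    scope=fluid.global_scope(), place=place, moving_rate=0.9)
out_scale_training_pass.apply(main_graph)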
@@ -363,9 +369,11 @@ def convert(program, place, config=None, scope=None, save_int8=False):
     assert isinstance(config, dict), "config must be dict"
     config = _parse_configs(config)
     _logger.info("convert config {}".format(config))
 
     test_graph = IrGraph(core.Graph(program.desc), for_test=True)
+    out_scale_infer_pass = OutScaleForInferencePass(scope=scope)
+    out_scale_infer_pass.apply(test_graph)
 
     # Freeze the graph after training by adjusting the quantize
     # operators' order for the inference.
     freeze_pass = QuantizationFreezePass(
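On the inference side, before the freeze pass runs, the scales collected during training are written back into the graph. A matching sketch, where `test_graph` is the for_test=True IrGraph built as in the hunk above:

from paddle.fluid.contrib.slim.quantization import OutScaleForInferencePass

# Copy the scales accumulated during training out of the scope and
# into attributes on the corresponding ops of the inference graph.
out_scale_infer_pass = OutScaleForInferencePass(scope=fluid.global_scope())
out_scale_infer_pass.apply(test_graph)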
@@ -379,10 +387,7 @@ def convert(program, place, config=None, scope=None, save_int8=False):
     freezed_program = test_graph.to_program()
 
     if save_int8:
-        convert_int8_pass = ConvertToInt8Pass(
-            scope=fluid.global_scope(),
-            place=place,
-            quantizable_op_type=support_op_types)
+        convert_int8_pass = ConvertToInt8Pass(scope=scope, place=place)
         convert_int8_pass.apply(test_graph)
         freezed_program_int8 = test_graph.to_program()
         return freezed_program, freezed_program_int8
...
@@ -132,7 +132,7 @@ class TestQuantAwareCase2(unittest.TestCase):
         def test(program):
             iter = 0
             result = [[], [], []]
-            for data in train_reader():
+            for data in eval_reader():
                 cost, top1, top5 = exe.run(
                     program,
                     feed=feeder.feed(data),
@@ -161,7 +161,8 @@ class TestQuantAwareCase2(unittest.TestCase):
             main_prog, place, config, for_test=False)
         quant_eval_prog = quant_aware(val_prog, place, config, for_test=True)
         train(quant_train_prog)
-        quant_eval_prog = convert(quant_eval_prog, place, config)
+        quant_eval_prog, int8_prog = convert(
+            quant_eval_prog, place, config, save_int8=True)
         top1_2, top5_2 = test(quant_eval_prog)
         # values before quantization and after quantization should be close
         print("before quantization: top1: {}, top5: {}".format(top1_1, top5_1))
...
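With save_int8=True the test now gets both programs back. A typical follow-up, not part of this diff (the feed name 'image', fetch var `out`, and executor `exe` are placeholders for the model's real inputs, outputs, and executor), would be to persist the int8 program for deployment:

import paddle.fluid as fluid

# Save the int8 program as an inference model directory.
fluid.io.save_inference_model(
    dirname='./quant_int8_model',
    feeded_var_names=['image'],
    target_vars=[out],
    executor=exe,
    main_program=int8_prog)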