未验证 提交 9d8730f6 编写于 作者: L Liufang Sang 提交者: GitHub

add out scale for quantization (#272)

* add out scale for quantization

* update quant_aware unittest

* update quant_aware unittest
上级 f3891c1c
......@@ -25,6 +25,6 @@ try:
except Exception as e:
_logger.warning(
"If you want to use training-aware and post-training quantization, "
"please use Paddle >= 1.7.0 or develop version")
"please use Paddle >= 2.0.0 or develop version")
from .quant_embedding import quant_embedding
......@@ -24,6 +24,8 @@ from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass
from paddle.fluid.contrib.slim.quantization import TransformForMobilePass
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass
from paddle.fluid.contrib.slim.quantization import OutScaleForInferencePass
from paddle.fluid import core
from paddle.fluid.contrib.slim.quantization import WeightQuantization
......@@ -220,6 +222,10 @@ def quant_aware(program, place, config=None, scope=None, for_test=False):
quantizable_op_type=quant_dequant_ops)
quant_dequant_pass.apply(main_graph)
out_scale_training_pass = OutScaleForTrainingPass(
scope=scope, place=place, moving_rate=config['moving_rate'])
out_scale_training_pass.apply(main_graph)
if for_test:
quant_program = main_graph.to_program()
else:
......@@ -363,9 +369,11 @@ def convert(program, place, config=None, scope=None, save_int8=False):
assert isinstance(config, dict), "config must be dict"
config = _parse_configs(config)
_logger.info("convert config {}".format(config))
test_graph = IrGraph(core.Graph(program.desc), for_test=True)
out_scale_infer_pass = OutScaleForInferencePass(scope=scope)
out_scale_infer_pass.apply(test_graph)
# Freeze the graph after training by adjusting the quantize
# operators' order for the inference.
freeze_pass = QuantizationFreezePass(
......@@ -379,10 +387,7 @@ def convert(program, place, config=None, scope=None, save_int8=False):
freezed_program = test_graph.to_program()
if save_int8:
convert_int8_pass = ConvertToInt8Pass(
scope=fluid.global_scope(),
place=place,
quantizable_op_type=support_op_types)
convert_int8_pass = ConvertToInt8Pass(scope=scope, place=place)
convert_int8_pass.apply(test_graph)
freezed_program_int8 = test_graph.to_program()
return freezed_program, freezed_program_int8
......
......@@ -132,7 +132,7 @@ class TestQuantAwareCase2(unittest.TestCase):
def test(program):
iter = 0
result = [[], [], []]
for data in train_reader():
for data in eval_reader():
cost, top1, top5 = exe.run(
program,
feed=feeder.feed(data),
......@@ -161,7 +161,8 @@ class TestQuantAwareCase2(unittest.TestCase):
main_prog, place, config, for_test=False)
quant_eval_prog = quant_aware(val_prog, place, config, for_test=True)
train(quant_train_prog)
quant_eval_prog = convert(quant_eval_prog, place, config)
quant_eval_prog, int8_prog = convert(
quant_eval_prog, place, config, save_int8=True)
top1_2, top5_2 = test(quant_eval_prog)
# values before quantization and after quantization should be close
print("before quantization: top1: {}, top5: {}".format(top1_1, top5_1))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册