未验证 提交 74199a4a 作者: Liufang Sang 提交者: GitHub

fix quantization for 1.8 (#300)

上级 4b46f61e
@@ -20,11 +20,11 @@ from ..common import get_logger
 
 _logger = get_logger(__name__, level=logging.INFO)
 try:
-    fluid.require_version('2.0.0')
+    fluid.require_version('1.8.0')
     from .quanter import quant_aware, quant_post, convert, quant_post_only_weight
 except Exception as e:
     _logger.warning(
         "If you want to use training-aware and post-training quantization, "
-        "please use Paddle >= 2.0.0 or develop version")
+        "please use Paddle >= 1.8.0 or develop version")
 
 from .quant_embedding import quant_embedding
@@ -24,8 +24,6 @@ from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass
 from paddle.fluid.contrib.slim.quantization import TransformForMobilePass
 from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
 from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
-from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass
-from paddle.fluid.contrib.slim.quantization import OutScaleForInferencePass
 from paddle.fluid import core
 from paddle.fluid.contrib.slim.quantization import WeightQuantization
@@ -222,10 +220,6 @@ def quant_aware(program, place, config=None, scope=None, for_test=False):
             quantizable_op_type=quant_dequant_ops)
         quant_dequant_pass.apply(main_graph)
 
-    out_scale_training_pass = OutScaleForTrainingPass(
-        scope=scope, place=place, moving_rate=config['moving_rate'])
-    out_scale_training_pass.apply(main_graph)
-
     if for_test:
         quant_program = main_graph.to_program()
     else:
@@ -371,9 +365,6 @@ def convert(program, place, config=None, scope=None, save_int8=False):
     _logger.info("convert config {}".format(config))
     test_graph = IrGraph(core.Graph(program.desc), for_test=True)
 
-    out_scale_infer_pass = OutScaleForInferencePass(scope=scope)
-    out_scale_infer_pass.apply(test_graph)
-
     # Freeze the graph after training by adjusting the quantize
     # operators' order for the inference.
     freeze_pass = QuantizationFreezePass(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册