From 11814d1c283aaa2816dc4e18250da230e83fb0e4 Mon Sep 17 00:00:00 2001
From: Guanghua Yu <742925032@qq.com>
Date: Mon, 29 Aug 2022 17:55:50 +0800
Subject: [PATCH] fix fake quant demo (#1397)

---
 paddleslim/auto_compression/utils/fake_ptq.py | 95 ++++++++++++++-----
 1 file changed, 71 insertions(+), 24 deletions(-)

diff --git a/paddleslim/auto_compression/utils/fake_ptq.py b/paddleslim/auto_compression/utils/fake_ptq.py
index e86dd848..9e506c8f 100644
--- a/paddleslim/auto_compression/utils/fake_ptq.py
+++ b/paddleslim/auto_compression/utils/fake_ptq.py
@@ -2,7 +2,7 @@ import os
 import paddle
 from paddle.fluid.framework import IrGraph
 from paddle.framework import core
-from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass, AddQuantDequantPass, QuantizationFreezePass
+from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass, QuantizationTransformPassV2, AddQuantDequantPass, AddQuantDequantPassV2, QuantizationFreezePass, QuantWeightPass
 
 try:
     from paddle.fluid.contrib.slim.quantization import utils
@@ -23,7 +23,8 @@ def post_quant_fake(executor,
                     quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
                     is_full_quantize=False,
                     activation_bits=8,
-                    weight_bits=8):
+                    weight_bits=8,
+                    onnx_format=False):
     """
     Utilizing post training quantization methon to quantize the FP32 model,
     and it not uses calibrate data and the fake model cannot be used in practice.
@@ -67,14 +68,24 @@ def post_quant_fake(executor,
     for op_type in _weight_supported_quantizable_op_type:
         if op_type in _quantizable_op_type:
             major_quantizable_op_types.append(op_type)
-    transform_pass = QuantizationTransformPass(
-        scope=_scope,
-        place=_place,
-        weight_bits=weight_bits,
-        activation_bits=activation_bits,
-        activation_quantize_type=activation_quantize_type,
-        weight_quantize_type=weight_quantize_type,
-        quantizable_op_type=major_quantizable_op_types)
+    if onnx_format:
+        transform_pass = QuantizationTransformPassV2(
+            scope=_scope,
+            place=_place,
+            weight_bits=weight_bits,
+            activation_bits=activation_bits,
+            activation_quantize_type=activation_quantize_type,
+            weight_quantize_type=weight_quantize_type,
+            quantizable_op_type=major_quantizable_op_types)
+    else:
+        transform_pass = QuantizationTransformPass(
+            scope=_scope,
+            place=_place,
+            weight_bits=weight_bits,
+            activation_bits=activation_bits,
+            activation_quantize_type=activation_quantize_type,
+            weight_quantize_type=weight_quantize_type,
+            quantizable_op_type=major_quantizable_op_types)
 
     for sub_graph in graph.all_sub_graphs():
         # Insert fake_quant/fake_dequantize op must in test graph, so
@@ -87,30 +98,66 @@
     for op_type in _act_supported_quantizable_op_type:
         if op_type in _quantizable_op_type:
             minor_quantizable_op_types.append(op_type)
-    add_quant_dequant_pass = AddQuantDequantPass(
-        scope=_scope,
-        place=_place,
-        quantizable_op_type=minor_quantizable_op_types)
+    if onnx_format:
+        add_quant_dequant_pass = AddQuantDequantPassV2(
+            scope=_scope,
+            place=_place,
+            quantizable_op_type=minor_quantizable_op_types,
+            is_full_quantized=is_full_quantize)
+    else:
+        add_quant_dequant_pass = AddQuantDequantPass(
+            scope=_scope,
+            place=_place,
+            quantizable_op_type=minor_quantizable_op_types)
 
     for sub_graph in graph.all_sub_graphs():
         sub_graph._for_test = True
         add_quant_dequant_pass.apply(sub_graph)
 
     # apply QuantizationFreezePass, and obtain the final quant model
-    freeze_pass = QuantizationFreezePass(
-        scope=_scope,
-        place=_place,
-        weight_bits=weight_bits,
-        activation_bits=activation_bits,
-        weight_quantize_type=weight_quantize_type,
-        quantizable_op_type=major_quantizable_op_types)
+    if onnx_format:
+        quant_weight_pass = QuantWeightPass(_scope, _place)
+        for sub_graph in graph.all_sub_graphs():
+            sub_graph._for_test = True
+            quant_weight_pass.apply(sub_graph)
+    else:
+        freeze_pass = QuantizationFreezePass(
+            scope=_scope,
+            place=_place,
+            weight_bits=weight_bits,
+            activation_bits=activation_bits,
+            weight_quantize_type=weight_quantize_type,
+            quantizable_op_type=major_quantizable_op_types)
 
-    for sub_graph in graph.all_sub_graphs():
-        sub_graph._for_test = True
-        freeze_pass.apply(sub_graph)
+        for sub_graph in graph.all_sub_graphs():
+            sub_graph._for_test = True
+            freeze_pass.apply(sub_graph)
 
     _program = graph.to_program()
 
+    def save_info(op_node, out_var_name, out_info_name, quantized_type):
+        op_node._set_attr(out_info_name, 0.001)
+        op_node._set_attr("with_quant_attr", True)
+        if op_node.type in _quantizable_op_type:
+            op._set_attr("quantization_type", quantized_type)
+
+    def analysis_and_save_info(op_node, out_var_name):
+        argname_index = utils._get_output_name_index(op_node, out_var_name)
+        assert argname_index is not None, \
+            out_var_name + " is not the output of the op"
+
+        save_info(op_node, out_var_name, "out_threshold", "post_avg")
+        save_info(op_node, out_var_name,
+                  argname_index[0] + str(argname_index[1]) + "_threshold",
+                  "post_avg")
+
+    for block_id in range(len(_program.blocks)):
+        for op in _program.blocks[block_id].ops:
+            if op.type in (_quantizable_op_type + utils._out_scale_op_list):
+                out_var_names = utils._get_op_output_var_names(op)
+                for var_name in out_var_names:
+                    analysis_and_save_info(op, var_name)
+
     feed_vars = [_program.global_block().var(name) for name in _feed_list]
     model_name = model_filename.split('.')[
         0] if model_filename is not None else 'model'
-- 
GitLab
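
Usage note (not part of the patch): a minimal sketch of calling the updated entry point with the new onnx_format flag. Only executor, model_filename, quantizable_op_type, is_full_quantize, activation_bits, weight_bits, and onnx_format are visible in this diff; the remaining parameter names below (model_dir, params_filename, save_model_path) are assumptions about the unchanged part of the signature.

    import paddle
    from paddleslim.auto_compression.utils.fake_ptq import post_quant_fake

    paddle.enable_static()
    exe = paddle.static.Executor(paddle.CPUPlace())

    # With onnx_format=True the function takes the QuantizationTransformPassV2 /
    # AddQuantDequantPassV2 / QuantWeightPass branch added by this patch;
    # with the default onnx_format=False it keeps the original passes.
    post_quant_fake(
        exe,
        model_dir="./inference_model",          # assumed parameter name
        model_filename="model.pdmodel",
        params_filename="model.pdiparams",      # assumed parameter name
        save_model_path="./fake_quant_model",   # assumed parameter name
        quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
        is_full_quantize=False,
        activation_bits=8,
        weight_bits=8,
        onnx_format=True)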