From b80d89691c1bd5078cdf35a4a1446fe6736db194 Mon Sep 17 00:00:00 2001 From: slf12 Date: Fri, 15 Nov 2019 15:15:50 +0800 Subject: [PATCH] add elementwise_add, pool2d --- paddleslim/quant/quanter.py | 65 ++++++++++++++++++++++++++++--------- 1 file changed, 49 insertions(+), 16 deletions(-) mode change 100644 => 100755 paddleslim/quant/quanter.py diff --git a/paddleslim/quant/quanter.py b/paddleslim/quant/quanter.py old mode 100644 new mode 100755 index 0db22772..eca57c9c --- a/paddleslim/quant/quanter.py +++ b/paddleslim/quant/quanter.py @@ -20,11 +20,19 @@ from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass from paddle.fluid.contrib.slim.quantization import TransformForMobilePass +from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass from paddle.fluid import core -WEIGHT_QUANTIZATION_TYPES=['abs_max', 'channel_wise_abs_max', 'range_abs_max', 'moving_average_abs_max'] -ACTIVATION_QUANTIZATION_TYPES=['abs_max','range_abs_max', 'moving_average_abs_max'] +WEIGHT_QUANTIZATION_TYPES = [ + 'abs_max', 'channel_wise_abs_max', 'range_abs_max', + 'moving_average_abs_max' +] +ACTIVATION_QUANTIZATION_TYPES = [ + 'abs_max', 'range_abs_max', 'moving_average_abs_max' +] VALID_DTYPES = ['int8'] +TRANSFORM_PASS_OP_TYPES = ['conv2d', 'depthwise_conv2d', 'mul'] +QUANT_DEQUANT_PASS_OP_TYPES = ['elementwise_add', 'pool2d'] _quant_config_default = { # weight quantize type, default is 'abs_max' @@ -38,7 +46,8 @@ _quant_config_default = { # ops of name_scope in not_quant_pattern list, will not be quantized 'not_quant_pattern': ['skip_quant'], # ops of type in quantize_op_types, will be quantized - 'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'], + 'quantize_op_types': + ['conv2d', 'depthwise_conv2d', 'mul', 'elementwise_add', 'pool2d'], # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8' 'dtype': 'int8', # window size for 'range_abs_max' quantization. defaulf is 10000 @@ -88,6 +97,12 @@ def _parse_configs(user_config): assert isinstance(configs['quantize_op_types'], list), \ "quantize_op_types must be a list" + for op_type in configs['quantize_op_types']: + assert (op_type in QUANT_DEQUANT_PASS_OP_TYPES) or ( + op_type in TRANSFORM_PASS_OP_TYPES), "{} is not support, \ + now support op types are {}".format( + op_type, TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES) + assert isinstance(configs['dtype'], str), \ "dtype must be a str." @@ -132,19 +147,37 @@ def quant_aware(program, place, config, scope=None, for_test=False): config = _parse_configs(config) main_graph = IrGraph(core.Graph(program.desc), for_test=for_test) - transform_pass = QuantizationTransformPass( - scope=scope, - place=place, - weight_bits=config['weight_bits'], - activation_bits=config['activation_bits'], - activation_quantize_type=config['activation_quantize_type'], - weight_quantize_type=config['weight_quantize_type'], - window_size=config['window_size'], - moving_rate=config['moving_rate'], - quantizable_op_type=config['quantize_op_types'], - skip_pattern=config['not_quant_pattern']) - - transform_pass.apply(main_graph) + transform_pass_ops = [] + quant_dequant_ops = [] + for op_type in config['quantize_op_types']: + if op_type in TRANSFORM_PASS_OP_TYPES: + transform_pass_ops.append(op_type) + elif op_type in QUANT_DEQUANT_PASS_OP_TYPES: + quant_dequant_ops.append(op_type) + if len(transform_pass_ops) > 0: + transform_pass = QuantizationTransformPass( + scope=scope, + place=place, + weight_bits=config['weight_bits'], + activation_bits=config['activation_bits'], + activation_quantize_type=config['activation_quantize_type'], + weight_quantize_type=config['weight_quantize_type'], + window_size=config['window_size'], + moving_rate=config['moving_rate'], + quantizable_op_type=transform_pass_ops, + skip_pattern=config['not_quant_pattern']) + + transform_pass.apply(main_graph) + + if len(quant_dequant_ops) > 0: + quant_dequant_pass = AddQuantDequantPass( + scope=scope, + place=place, + moving_rate=config['moving_rate'], + quant_bits=config['activation_bits'], + skip_pattern=config['not_quant_pattern'], + quantizable_op_type=quant_dequant_ops) + quant_dequant_pass.apply(main_graph) if for_test: quant_program = main_graph.to_program() -- GitLab