diff --git a/demo/quant/quant_aware/train.py b/demo/quant/quant_aware/train.py
index 7cb088b440f4037203c837830b1b853ad99ab100..be13cc5868f8acf51d31985a3b0a815baa37fa82 100644
--- a/demo/quant/quant_aware/train.py
+++ b/demo/quant/quant_aware/train.py
@@ -12,7 +12,7 @@ sys.path.append(sys.path[0] + "/../../../")
 sys.path.append(sys.path[0] + "/../../")
 from paddleslim.common import get_logger
 from paddleslim.analysis import flops
-from paddleslim.quant import quant_aware, quant_post, convert
+from paddleslim.quant import quant_aware, convert
 import models
 from utility import add_arguments, print_arguments
 
diff --git a/paddleslim/quant/quanter.py b/paddleslim/quant/quanter.py
index ff748b06f3c1f084d78f69f712ac363f6ab3140d..b8505d9864aab54de08da183f971a738d2923264 100755
--- a/paddleslim/quant/quanter.py
+++ b/paddleslim/quant/quanter.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import copy
+import logging
 import paddle
 import paddle.fluid as fluid
 from paddle.fluid.framework import IrGraph
@@ -22,6 +23,9 @@ from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass
 from paddle.fluid.contrib.slim.quantization import TransformForMobilePass
 from paddle.fluid import core
 
+from ..common import get_logger
+_logger = get_logger(__name__, level=logging.INFO)
+
 WEIGHT_QUANTIZATION_TYPES = [
     'abs_max', 'channel_wise_abs_max', 'range_abs_max',
     'moving_average_abs_max'
@@ -40,10 +44,8 @@ _quant_config_default = {
     'weight_bits': 8,
     # activation quantize bit num, default is 8
     'activation_bits': 8,
-    # ops of name_scope in not_quant_pattern list, will not be quantized
-    'not_quant_pattern': ['skip_quant'],
-    # ops of type in quantize_op_types, will be quantized
-    'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
+    # ops of name_scope in not_quant_pattern , will not be quantized
+    'not_quant_pattern': 'skip_quant',
     # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
     'dtype': 'int8',
     # window size for 'range_abs_max' quantization. defaulf is 10000
@@ -84,12 +86,9 @@ def _parse_configs(user_config):
     assert (configs['activation_bits'] >= 1 and
             configs['activation_bits'] <= 16), \
         "activation_bits should be between 1 and 16."
 
-    assert isinstance(configs['not_quant_pattern'], list), \
+    assert isinstance(configs['not_quant_pattern'], str), \
         "not_quant_pattern must be a list"
 
-    assert isinstance(configs['quantize_op_types'], list), \
-        "quantize_op_types must be a list"
-
     assert isinstance(configs['dtype'], str), \
         "dtype must be a str."
 
@@ -102,11 +101,6 @@ def _parse_configs(user_config):
     assert isinstance(configs['moving_rate'], float), \
         "moving_rate must be float value, The decay coefficient of moving average, default is 0.9."
 
-    assert isinstance(configs['quant_weight_only'], bool), \
-        "quant_weight_only must be bool value, if set quant_weight_only True, " \
-        "then only quantize parameters of layers which need to be quantized, " \
-        " and activations will not be quantized."
-
     return configs
 
 
@@ -142,7 +136,6 @@ def quant_aware(program, place, config=None, scope=None, for_test=False):
         weight_quantize_type=config['weight_quantize_type'],
         window_size=config['window_size'],
         moving_rate=config['moving_rate'],
-        quantizable_op_type=config['quantize_op_types'],
         skip_pattern=config['not_quant_pattern'])
 
     transform_pass.apply(main_graph)
@@ -185,6 +178,8 @@ def convert(program, place, config=None, scope=None, save_int8=False):
     freeze_pass = QuantizationFreezePass(
         scope=scope,
         place=place,
+        weight_bits=config['weight_bits'],
+        activation_bits=config['activation_bits'],
         weight_quantize_type=config['weight_quantize_type'])
     freeze_pass.apply(test_graph)
     freezed_program = test_graph.to_program()
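
For context, a minimal usage sketch of the API surface this patch touches. It assumes that `_parse_configs` fills unspecified keys from `_quant_config_default`, and `train_program` / `val_program` are hypothetical, already-built fluid Programs (e.g. like those constructed in demo/quant/quant_aware/train.py). After this change, `not_quant_pattern` is a single name_scope string, the `quantize_op_types` and `quant_weight_only` keys are gone, and `convert` forwards `weight_bits` / `activation_bits` into `QuantizationFreezePass`.

```python
import paddle.fluid as fluid
from paddleslim.quant import quant_aware, convert

place = fluid.CPUPlace()

# Keys not listed here are assumed to fall back to _quant_config_default.
quant_config = {
    'weight_bits': 8,
    'activation_bits': 8,
    'not_quant_pattern': 'skip_quant',  # now a single name_scope string, not a list
}

# Insert fake quant/dequant ops: one graph for training, one for evaluation.
# train_program / val_program are hypothetical, pre-built fluid Programs.
quant_train_program = quant_aware(train_program, place, quant_config, for_test=False)
quant_eval_program = quant_aware(val_program, place, quant_config, for_test=True)

# ... run quantization-aware training on quant_train_program ...

# Freeze the evaluation graph for inference; with this patch the configured
# weight_bits / activation_bits are passed through to QuantizationFreezePass.
inference_program = convert(quant_eval_program, place, quant_config)
```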