From 00b11a4a1e05f7d5637dc2864484add3e243ebb9 Mon Sep 17 00:00:00 2001 From: juncaipeng <52520497+juncaipeng@users.noreply.github.com> Date: Sat, 16 Nov 2019 07:41:46 +0800 Subject: [PATCH] Support more ops in post training quantization, test=develop (#21073) * Support more ops in post training quantization, and save the output scale in quantized op. * Update docs in post training quantization and qat --- .../post_training_quantization.py | 156 +++++++---- .../slim/quantization/quantization_pass.py | 253 ++++++++++++------ .../tests/test_post_training_quantization.py | 12 +- 3 files changed, 285 insertions(+), 136 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py index 59b77ea6a9b..d0188b9b93e 100644 --- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py +++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py @@ -23,6 +23,7 @@ from ....log_helper import get_logger from .quantization_pass import QuantizationTransformPass from .quantization_pass import QuantizationFreezePass from .quantization_pass import AddQuantDequantPass +from .quantization_pass import _op_real_in_out_name __all__ = ['PostTrainingQuantization'] @@ -39,10 +40,8 @@ class PostTrainingQuantization(object): batch_nums=None, scope=None, algo="KL", - quantizable_op_type=[ - "conv2d", "depthwise_conv2d", "mul", "pool2d", - "elementwise_add" - ]): + quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"], + is_full_quantize=False): ''' The class utilizes post training quantization methon to quantize the fp32 model. It uses calibrate data to calculate the scale factor of @@ -66,7 +65,14 @@ class PostTrainingQuantization(object): abs_max methon to get the scale factor. Default is KL. quantizable_op_type(list[str], optional): List the type of ops that will be quantized. Default is ["conv2d", "depthwise_conv2d", - "mul", "pool2d", "elementwise_add"]. + "mul"]. + is_full_quantized(bool, optional): If set is_full_quantized as True, + apply quantization to all supported quantizable op type. If set + is_full_quantized as False, only apply quantization to the op type + according to the input quantizable_op_type. + Returns: + None + Examples: .. code-block:: python import paddle.fluid as fluid @@ -98,14 +104,19 @@ class PostTrainingQuantization(object): self._batch_size = batch_size self._batch_nums = batch_nums self._scope = global_scope() if scope == None else scope - self._quantizable_op_type = quantizable_op_type self._algo = algo - supported_quantizable_op_type = [ - "conv2d", "depthwise_conv2d", "mul", "pool2d", "elementwise_add" - ] - for op_type in self._quantizable_op_type: - assert op_type in supported_quantizable_op_type, \ - op_type + " is not supported for quantization." + + supported_quantizable_op_type = \ + QuantizationTransformPass._supported_quantizable_op_type + \ + AddQuantDequantPass._supported_quantizable_op_type + if is_full_quantize: + self._quantizable_op_type = supported_quantizable_op_type + else: + self._quantizable_op_type = quantizable_op_type + for op_type in self._quantizable_op_type: + assert op_type in supported_quantizable_op_type + \ + AddQuantDequantPass._activation_type, \ + op_type + " is not supported for quantization." self._place = self._executor.place self._program = None @@ -113,6 +124,7 @@ class PostTrainingQuantization(object): self._fetch_list = None self._data_loader = None + self._op_real_in_out_name = _op_real_in_out_name self._bit_length = 8 self._quantized_weight_var_name = [] self._quantized_act_var_name = [] @@ -124,11 +136,13 @@ class PostTrainingQuantization(object): Quantize the fp32 model. Use calibrate data to calculate the scale factor of quantized variables, and inserts fake quant/dequant op to obtain the quantized model. - - Return: + + Args: + None + Returns: the program of quantized model. ''' - self._prepare() + self._preprocess() batch_id = 0 for data in self._data_loader(): @@ -136,7 +150,6 @@ class PostTrainingQuantization(object): feed=data, fetch_list=self._fetch_list) self._sample_data() - if batch_id % 5 == 0: _logger.info("run batch: " + str(batch_id)) batch_id += 1 @@ -144,9 +157,13 @@ class PostTrainingQuantization(object): break _logger.info("all run batch: " + str(batch_id)) + _logger.info("calculate scale factor ...") self._calculate_scale_factor() + + _logger.info("update the program ...") self._update_program() + self._save_output_scale() return self._program def save_quantized_model(self, save_model_path): @@ -155,7 +172,7 @@ class PostTrainingQuantization(object): Args: save_model_path(str): The path to save the quantized model - Return: + Returns: None ''' io.save_inference_model( @@ -165,7 +182,7 @@ class PostTrainingQuantization(object): executor=self._executor, main_program=self._program) - def _prepare(self): + def _preprocess(self): ''' Load model and set data loader, collect the variable names for sampling, and set activation variables to be persistable. @@ -183,14 +200,13 @@ class PostTrainingQuantization(object): drop_last=True, places=self._place) - #collect the variable names for sampling + # collect the variable names for sampling persistable_var_names = [] for var in self._program.list_vars(): if var.persistable: persistable_var_names.append(var.name) - block = self._program.global_block() - for op in block.ops: + for op in self._program.global_block().ops: op_type = op.type if op_type in self._quantizable_op_type: if op_type in ("conv2d", "depthwise_conv2d"): @@ -199,29 +215,30 @@ class PostTrainingQuantization(object): op.input("Filter")[0]) self._quantized_act_var_name.append(op.output("Output")[0]) elif op_type == "mul": - x_var_name = op.input("X")[0] - y_var_name = op.input("Y")[0] - if x_var_name not in persistable_var_names and \ - y_var_name not in persistable_var_names: + if self._is_input_all_not_persistable( + op, persistable_var_names): op._set_attr("skip_quant", True) - _logger.warning("A mul op skip quant for two " + _logger.warning("Skip quant a mul op for two " "input variables are not persistable") else: - self._quantized_act_var_name.append(x_var_name) - self._quantized_weight_var_name.append(y_var_name) + self._quantized_act_var_name.append(op.input("X")[0]) + self._quantized_weight_var_name.append(op.input("Y")[0]) self._quantized_act_var_name.append(op.output("Out")[0]) - elif op_type == "pool2d": - self._quantized_act_var_name.append(op.input("X")[0]) - elif op_type == "elementwise_add": - x_var_name = op.input("X")[0] - y_var_name = op.input("Y")[0] - if x_var_name not in persistable_var_names and \ - y_var_name not in persistable_var_names: - self._quantized_act_var_name.append(x_var_name) - self._quantized_act_var_name.append(y_var_name) - - # set activation variables to be persistable, - # so can obtain the tensor data in sample_data stage + else: + # process other quantizable op type, the input must all not persistable + if self._is_input_all_not_persistable( + op, persistable_var_names): + input_output_name_list = self._op_real_in_out_name[ + op_type] + for input_name in input_output_name_list[0]: + for var_name in op.input(input_name): + self._quantized_act_var_name.append(var_name) + for output_name in input_output_name_list[1]: + for var_name in op.output(output_name): + self._quantized_act_var_name.append(var_name) + + # set activation variables to be persistable, so can obtain + # the tensor data in sample_data for var in self._program.list_vars(): if var.name in self._quantized_act_var_name: var.persistable = True @@ -246,8 +263,7 @@ class PostTrainingQuantization(object): ''' Calculate the scale factor of quantized variables. ''' - _logger.info("calculate scale factor ...") - + # apply channel_wise_abs_max quantization for weights for var_name in self._quantized_weight_var_name: data = self._sampling_data[var_name] scale_factor_per_channel = [] @@ -257,6 +273,7 @@ class PostTrainingQuantization(object): self._quantized_var_scale_factor[ var_name] = scale_factor_per_channel + # apply kl quantization for activation for var_name in self._quantized_act_var_name: if self._algo == "KL": self._quantized_var_scale_factor[var_name] = \ @@ -269,8 +286,7 @@ class PostTrainingQuantization(object): ''' Insert fake_quantize/fake_dequantize op to the program. ''' - _logger.info("update the program ...") - + # reset quantized activation variable for var in self._program.list_vars(): if var.name in self._quantized_act_var_name: var.persistable = False @@ -278,10 +294,10 @@ class PostTrainingQuantization(object): # use QuantizationTransformPass to insert fake_quantize/fake_dequantize op graph = IrGraph(core.Graph(self._program.desc), for_test=True) - qtp_quantizable_op_type = [] - for op_type in ["conv2d", "depthwise_conv2d", "mul"]: + major_quantizable_op_types = [] + for op_type in QuantizationTransformPass._supported_quantizable_op_type: if op_type in self._quantizable_op_type: - qtp_quantizable_op_type.append(op_type) + major_quantizable_op_types.append(op_type) transform_pass = QuantizationTransformPass( scope=self._scope, place=self._place, @@ -289,18 +305,18 @@ class PostTrainingQuantization(object): activation_bits=self._bit_length, activation_quantize_type='moving_average_abs_max', weight_quantize_type='channel_wise_abs_max', - quantizable_op_type=qtp_quantizable_op_type) + quantizable_op_type=major_quantizable_op_types) transform_pass.apply(graph) # use AddQuantDequantPass to insert fake_quant_dequant op - aqdp_quantizable_op_type = [] - for op_type in ["pool2d", "elementwise_add"]: + minor_quantizable_op_types = [] + for op_type in AddQuantDequantPass._supported_quantizable_op_type: if op_type in self._quantizable_op_type: - aqdp_quantizable_op_type.append(op_type) + minor_quantizable_op_types.append(op_type) add_quant_dequant_pass = AddQuantDequantPass( scope=self._scope, place=self._place, - quantizable_op_type=aqdp_quantizable_op_type) + quantizable_op_type=minor_quantizable_op_types) add_quant_dequant_pass.apply(graph) # save scale factor to scale var node @@ -319,10 +335,25 @@ class PostTrainingQuantization(object): weight_bits=self._bit_length, activation_bits=self._bit_length, weight_quantize_type='channel_wise_abs_max', - quantizable_op_type=qtp_quantizable_op_type) + quantizable_op_type=major_quantizable_op_types) freeze_pass.apply(graph) self._program = graph.to_program() + def _save_output_scale(self): + ''' + Save output scale to the quantized op. + ''' + output_scale_name = "output_scale" + for op in self._program.global_block().ops: + if op.type in self._quantizable_op_type: + output_name_list = self._op_real_in_out_name[op.type][1] + for output_name in output_name_list: + output_var_name = op.output(output_name)[0] + if output_var_name in self._quantized_var_scale_factor: + op._set_attr( + output_scale_name, + self._quantized_var_scale_factor[output_var_name]) + def _load_var_value(self, var_name): ''' Load variable value from scope @@ -331,7 +362,7 @@ class PostTrainingQuantization(object): def _set_var_node_value(self, var_node_name, np_value): ''' - Set the value of var node by name, if the node is not exits, + Set the value of var node by name, if the node exits, ''' assert isinstance(np_value, np.ndarray), \ 'The type of value should be numpy array.' @@ -340,6 +371,19 @@ class PostTrainingQuantization(object): tensor = var_node.get_tensor() tensor.set(np_value, self._place) + def _is_input_all_not_persistable(self, op, persistable_var_names): + ''' + Analyze the real inputs of the op are all not persistable. + ''' + is_input_all_not_persistable = True + input_name_list = self._op_real_in_out_name[op.type][0] + for input_name in input_name_list: + for var_name in op.input(input_name): + if var_name in persistable_var_names: + is_input_all_not_persistable = False + break + return is_input_all_not_persistable + def _get_kl_scaling_factor(self, activation_blob, num_quantized_bins=255): ''' Using the KL-divergenc method to get the more precise scaling factor. @@ -441,8 +485,8 @@ class PostTrainingQuantization(object): tmp_sum2 += 0 else: if q_idx == 0: - print("Fatal error!, idx = " + str(idx) + - " qindex = 0! p_idx = " + str(p_idx)) + _logger.error("Fatal error!, idx = " + str(idx) + + " qindex = 0! p_idx = " + str(p_idx)) tmp_sum1 += p_idx * (math.log(Q_sum * p_idx)) tmp_sum2 += p_idx * (math.log(P_sum * q_idx)) return (tmp_sum1 - tmp_sum2) / P_sum diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index 4e9924260b4..ab3ca4cc102 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -41,6 +41,40 @@ _out_scale_op_list = [ "dropout", "split", "prelu", "conv2d_transpose", "leaky_relu" ] +# list op real input and output names, to avoid processing input such as AxisTensor. +_op_real_in_out_name = { + "conv2d": [["Input", "Filter"], ["Output"]], + "depthwise_conv2d": [["Input"], ["Output"]], + "mul": [["X", "Y"], ["Out"]], + "pool2d": [["X"], ["Out"]], + "elementwise_add": [["X", "Y"], ["Out"]], + "concat": [["X"], ["Out"]], + "softmax": [["X"], ["Out"]], + "argmax": [["X"], ["Out"]], + "transpose": [["X"], ["Out"]], + "equal": [["X", "Y"], ["Out"]], + "gather": [["X"], ["Out"]], + "greater_equal": [["X", "Y"], ["Out"]], + "greater_than": [["X", "Y"], ["Out"]], + "less_equal": [["X", "Y"], ["Out"]], + "less_than": [["X", "Y"], ["Out"]], + "mean": [["X"], ["Out"]], + "not_equal": [["X", "Y"], ["Out"]], + "reshape": [["X"], ["Out"]], + "reshape2": [["X"], ["Out"]], + "bilinear_interp": [["X"], ["Out"]], + "nearest_interp": [["X"], ["Out"]], + "trilinear_interp": [["X"], ["Out"]], + "slice": [["Input"], ["Out"]], + "squeeze": [["X"], ["Out"]], + "elementwise_sub": [["X", "Y"], ["Out"]], + "relu": [["X"], ["Out"]], + "relu6": [["X"], ["Out"]], + "leaky_relu": [["X"], ["Out"]], + "tanh": [["X"], ["Out"]], + "swish": [["X"], ["Out"]], +} + def _init_var_node(var_node, value, scope, place): assert isinstance(value, @@ -54,6 +88,8 @@ def _init_var_node(var_node, value, scope, place): class QuantizationTransformPass(object): + _supported_quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul'] + def __init__(self, scope=None, place=None, @@ -75,25 +111,27 @@ class QuantizationTransformPass(object): initialize these new parameters. place(fluid.CPUPlace|fluid.CUDAPlace): place is used to initialize new parameters described above. - weight_bits (int): quantization bit number for weights, + weight_bits(int): quantization bit number for weights, the bias is not quantized. - activation_bits (int): quantization bit number for activation. - activation_quantize_type (str): quantization type for activation, + activation_bits(int): quantization bit number for activation. + activation_quantize_type(str): quantization type for activation, now support 'abs_max', 'range_abs_max' and 'moving_average_abs_max'. If use 'abs_max' mode, the quantization scale will be calculated dynamically each step in both training and testing period. If use 'range_abs_max', a static quantization scale will be calculated during training and used in inference. - weight_quantize_type (str): quantization type for weights, + weight_quantize_type(str): quantization type for weights, support 'abs_max' and 'channel_wise_abs_max'. The 'range_abs_max' usually is not used for weight, since weights are fixed once the model is well trained. - window_size (int): the window size for 'range_abs_max' quantization. + window_size(int): the window size for 'range_abs_max' quantization. + moving_rate(float): the param for 'moving_average_abs_max' quantization. skip_pattern(str): The user-defined quantization skip pattern, which will be presented in the name scope of an op. When the skip pattern is detected in an op's name scope, the corresponding op will not be quantized. quantizable_op_type(list[str]): List the type of ops that will be quantized. - Default is ["conv2d", "depthwise_conv2d", "mul"]. + Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in + QuantizationFreezePass and ConvertToInt8Pass must be the same as this. Examples: .. code-block:: python @@ -139,9 +177,8 @@ class QuantizationTransformPass(object): self._moving_rate = moving_rate self._quantizable_ops = quantizable_op_type - supported_quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul'] for op in self._quantizable_ops: - assert op in supported_quantizable_ops, \ + assert op in QuantizationTransformPass._supported_quantizable_op_type, \ op + " is not supported for quantization." self._conv_ops = ['conv2d', 'depthwise_conv2d'] self._quantizable_grad_ops = [ @@ -158,6 +195,8 @@ class QuantizationTransformPass(object): Args: graph(IrGraph): the applied graph. + Returns: + None """ assert isinstance(graph, IrGraph), 'graph must be the instance of IrGraph.' @@ -589,24 +628,8 @@ class QuantizationTransformPass(object): class QuantizationFreezePass(object): - """ - The freeze pass is used to adjust the quantize operator order, for example: - 1) `activation -> quant -> dequant -> conv2d` will be freezed into - `activation -> quant -> conv2d -> dequant` - 2) `weight -> quant -> dequant -> conv2d` will be freezed into `weight -> conv2d`, - and weight will be sacled offline. - - Args: - scope(fluid.Scope): scope is used to get the weight tensor values. - place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the weight tensors. - weight_bits (int): quantization bit number for weights. - activation_bits (int): quantization bit number for activation. - weight_quantize_type (str): quantization type for weights, support 'abs_max' and - 'channel_wise_abs_max'. The 'range_abs_max' usually is not used for weight, - since weights are fixed once the model is well trained. - quantizable_op_type(list[str]): List the type of ops that will be quantized. - Default is ["conv2d", "depthwise_conv2d", "mul"]. - """ + _supported_quantizable_op_type = \ + QuantizationTransformPass._supported_quantizable_op_type def __init__(self, scope, @@ -615,6 +638,25 @@ class QuantizationFreezePass(object): activation_bits=8, weight_quantize_type='abs_max', quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul']): + """ + The freeze pass is used to adjust the quantize operator order, for example: + 1) `activation -> quant -> dequant -> conv2d` will be freezed into + `activation -> quant -> conv2d -> dequant` + 2) `weight -> quant -> dequant -> conv2d` will be freezed into `weight -> conv2d`, + and weight will be sacled offline. + + Args: + scope(fluid.Scope): scope is used to get the weight tensor values. + place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the weight tensors. + weight_bits(int): quantization bit number for weights. + activation_bits(int): quantization bit number for activation. + weight_quantize_type(str): quantization type for weights, support 'abs_max' and + 'channel_wise_abs_max'. The 'range_abs_max' usually is not used for weight, + since weights are fixed once the model is well trained. + quantizable_op_type(list[str]): List the type of ops that will be quantized. + Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in + QuantizationTransformPass and ConvertToInt8Pass must be the same as this. + """ assert scope is not None, \ 'The scope cannot be set None.' assert place is not None, \ @@ -625,9 +667,8 @@ class QuantizationFreezePass(object): self._activation_bits = activation_bits self._weight_quantize_type = weight_quantize_type self._quantizable_ops = quantizable_op_type - supported_quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul'] for op in self._quantizable_ops: - assert op in supported_quantizable_ops, \ + assert op in QuantizationFreezePass._supported_quantizable_op_type, \ op + " is not supported for quantization." self._conv_ops = ['conv2d', 'depthwise_conv2d'] self._fake_quant_op_names = _fake_quant_op_list @@ -642,6 +683,8 @@ class QuantizationFreezePass(object): Args: graph(IrGraph): the applied graph. + Returns: + None """ persistable_vars = [p.name() for p in graph.all_persistable_nodes()] ops = graph.all_op_nodes() @@ -895,21 +938,24 @@ class QuantizationFreezePass(object): class ConvertToInt8Pass(object): - """ - Convert the weights into int8_t type. - - Args: - scope(fluid.Scope): scope is used to get the weight tensor values. - place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the - 8bits weight tensors. - quantizable_op_type(list[str]): List the type of ops that will be quantized. - Default is ["conv2d", "depthwise_conv2d", "mul"]. - """ + _supported_quantizable_op_type = \ + QuantizationTransformPass._supported_quantizable_op_type def __init__(self, scope, place, quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul']): + """ + Convert the weights into int8_t type. + + Args: + scope(fluid.Scope): scope is used to get the weight tensor values. + place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the + 8bits weight tensors. + quantizable_op_type(list[str]): List the type of ops that will be quantized. + Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in + QuantizationTransformPass and QuantizationFreezePass must be the same as this. + """ assert scope is not None, \ 'The scope cannot be set None.' assert place is not None, \ @@ -917,9 +963,8 @@ class ConvertToInt8Pass(object): self._scope = scope self._place = place self._quantizable_ops = quantizable_op_type - supported_quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul'] for op in self._quantizable_ops: - assert op in supported_quantizable_ops, \ + assert op in ConvertToInt8Pass._supported_quantizable_op_type, \ op + " is not supported for quantization." def apply(self, graph): @@ -929,6 +974,8 @@ class ConvertToInt8Pass(object): Args: graph(IrGraph): the applied graph. + Returns: + None """ persistable_vars = [p.name() for p in graph.all_persistable_nodes()] ops = graph.all_op_nodes() @@ -993,11 +1040,10 @@ class ConvertToInt8Pass(object): class TransformForMobilePass(object): - """ - This pass is used to convert the freezed graph for paddle-mobile execution. - """ - def __init__(self): + """ + This pass is used to convert the freezed graph for paddle-mobile execution. + """ self._fake_quant_op_names = _fake_quant_op_list self._fake_dequant_op_names = _fake_dequant_op_list @@ -1009,6 +1055,8 @@ class TransformForMobilePass(object): Args: graph(IrGraph): the graph will be transformed. + Returns: + None """ ops = graph.all_op_nodes() for op_node in ops: @@ -1183,16 +1231,45 @@ class ScaleForInferencePass(object): class AddQuantDequantPass(object): + _supported_quantizable_op_type = [ + "pool2d", "elementwise_add", "concat", "softmax", "argmax", "transpose", + "equal", "gather", "greater_equal", "greater_than", "less_equal", + "less_than", "mean", "not_equal", "reshape", "reshape2", + "bilinear_interp", "nearest_interp", "trilinear_interp", "slice", + "squeeze", "elementwise_sub" + ] + _activation_type = ["relu", "relu6", "leaky_relu", "tanh", "swish"] + def __init__(self, scope=None, place=None, moving_rate=0.9, quant_bits=8, skip_pattern='skip_quant', - quantizable_op_type=["elementwise_add", "pool2d"]): + quantizable_op_type=["elementwise_add", "pool2d", "concat"], + is_full_quantized=False): """ - This pass is used to add quant_dequant op for some ops, such as the - 'elementwise_add' and 'pool2d' op. + This pass add quant_dequant op for some ops, of which all the inputs must be + not persistable. + The input scales can be obtained from the quant_dequant op. + + Args: + scope(fluid.Scope): The scope is used to initialize these new parameters. + place(fluid.CPUPlace|fluid.CUDAPlace): place is used to initialize new + parameters described above. + moving_rate(float, optional): the param for 'quant_dequant_moving_average_abs_max' + quantization. Default is 0.9. + quant_bits(int, optional): quantization bit number for activation. Default is 8. + skip_pattern(str, optional): The user-defined quantization skip pattern, which + will be presented in the name scope of an op. When the skip pattern is + detected in an op's name scope, the corresponding op will not be quantized. + Default is 'skip_quant'. + quantizable_op_type(list[str], optional): List the type of ops that will be + quantized. Default is ["elementwise_add", "pool2d", "concat"]. + is_full_quantized(bool, optional): If set is_full_quantized as True, apply + quantization to all supported quantizable op type. If set is_full_quantized + as False, only apply quantization to the op type according to the input + quantizable_op_type. """ self._scope = scope self._place = place @@ -1200,60 +1277,67 @@ class AddQuantDequantPass(object): self._quant_bits = quant_bits self._is_test = None self._skip_pattern = skip_pattern - self._quantizable_op_type = quantizable_op_type + + if is_full_quantized: + self._quantizable_op_type = \ + AddQuantDequantPass._supported_quantizable_op_type + else: + self._quantizable_op_type = quantizable_op_type + for op_type in quantizable_op_type: + assert op_type in AddQuantDequantPass._supported_quantizable_op_type + \ + AddQuantDequantPass._activation_type, \ + op_type + " is not supported for quantization." self._quantizable_grad_op_type = [ '%s_grad' % (op) for op in self._quantizable_op_type ] - supported_quantizable_op_type = ["elementwise_add", "pool2d"] - for op_type in quantizable_op_type: - assert op_type in supported_quantizable_op_type, \ - op_type + " is not supported for quantization." + assert self._scope != None, "scope must not be None." + assert self._place != None, "place must not be None." def apply(self, graph): """ - Add quant_dequant before some ops, such as the 'elementwise_add' - and 'pool2d' op. + Add quant_dequant before some ops, such as the 'elementwise_add', + 'pool2d' and 'concat' op. + Args: graph(IrGraph): the target graph. + Returns: + None """ assert isinstance(graph, IrGraph), 'graph must be the instance of IrGraph.' self._is_test = graph.is_test() dequantized_vars_map = collections.OrderedDict() - ops = graph.all_op_nodes() - for op_node in ops: + # Forward stage, insert quant_dequant op + all_op_nodes = graph.all_op_nodes() + for op_node in all_op_nodes: if op_node.name() in self._quantizable_op_type: if isinstance(self._skip_pattern, str) and \ op_node.op().has_attr("op_namescope") and \ op_node.op().attr("op_namescope").find(self._skip_pattern) != -1: continue - in_nodes_all_not_persistable = True - for input_name in op_node.input_arg_names(): - in_node = graph._find_node_by_name(op_node.inputs, - input_name) - in_nodes_all_not_persistable = ( - in_nodes_all_not_persistable and - not in_node.persistable()) - if not in_nodes_all_not_persistable: + if not self._is_input_all_not_persistable(graph, op_node): continue - input_names = op_node.input_arg_names() - for input_name in input_names: - in_node = graph._find_node_by_name(op_node.inputs, - input_name) - if input_name in dequantized_vars_map: - quant_var_node = dequantized_vars_map[input_name] - else: - quant_var_node, scale_var_node = \ - self._inser_quant_dequant_moving_average_abs_max_op( - graph, in_node, self._quant_bits) - dequantized_vars_map[input_name] = quant_var_node - graph.update_input_link(in_node, quant_var_node, op_node) + input_name_list = _op_real_in_out_name[op_node.name()][0] + for input_name in input_name_list: + for arg_name in op_node.input(input_name): + in_node = graph._find_node_by_name(op_node.inputs, + arg_name) + if arg_name in dequantized_vars_map: + quant_var_node = dequantized_vars_map[arg_name] + else: + quant_var_node, _ = \ + self._inser_quant_dequant_moving_average_abs_max_op( + graph, in_node, self._quant_bits) + dequantized_vars_map[arg_name] = quant_var_node + graph.update_input_link(in_node, quant_var_node, + op_node) - for op_node in ops: + # Backward stage, update input link + for op_node in all_op_nodes: if op_node.name() in self._quantizable_grad_op_type: for input_name in op_node.input_arg_names(): if input_name in dequantized_vars_map: @@ -1266,6 +1350,21 @@ class AddQuantDequantPass(object): graph.resolve_hazard() return graph + def _is_input_all_not_persistable(self, graph, op_node): + ''' + Analyse the real inputs of the op node are all not persistable. + ''' + is_input_all_not_persistable = True + op_node_name = op_node.name() + + input_name_list = _op_real_in_out_name[op_node_name][0] + for input_name in input_name_list: + for arg_name in op_node.input(input_name): + in_node = graph._find_node_by_name(op_node.inputs, arg_name) + is_input_all_not_persistable = (is_input_all_not_persistable and \ + (not in_node.persistable())) + return is_input_all_not_persistable + def _inser_quant_dequant_moving_average_abs_max_op(self, graph, var_node, quant_bits): """Insert fake_quantize_dequantize_moving_average_abs_max op. diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization.py index 7cf5b96c01a..7c473e491d6 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization.py +++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization.py @@ -233,7 +233,10 @@ class TestPostTrainingQuantization(unittest.TestCase): acc1 = np.sum(test_info) / cnt return (throughput, latency, acc1) - def generate_quantized_model(self, model_path, algo="KL"): + def generate_quantized_model(self, + model_path, + algo="KL", + is_full_quantize=False): self.int8_model = os.path.join(os.getcwd(), "post_training_" + self.timestamp) try: @@ -257,7 +260,8 @@ class TestPostTrainingQuantization(unittest.TestCase): model_path=model_path, data_reader=val_reader, algo=algo, - quantizable_op_type=quantizable_op_type) + quantizable_op_type=quantizable_op_type, + is_full_quantize=is_full_quantize) ptq.quantize() ptq.save_quantized_model(self.int8_model) @@ -285,7 +289,9 @@ class TestPostTrainingForMobilenetv1(TestPostTrainingQuantization): print("Start INT8 post training quantization for {0} on {1} images ...". format(self.model, self.sample_iterations * self.batch_size)) self.generate_quantized_model( - self.model_cache_folder + "/model", algo=self.algo) + self.model_cache_folder + "/model", + algo=self.algo, + is_full_quantize=True) print("Start INT8 inference for {0} on {1} images ...".format( self.model, self.infer_iterations * self.batch_size)) -- GitLab