From 5eff6f0147bbf5fff491ac2d56fe719a7c921592 Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Tue, 13 Dec 2022 10:31:36 +0800 Subject: [PATCH] support conv1d quant & skip calibrate zero-size tensor (#48912) --- .../post_training_quantization.py | 120 +++++++++------ .../slim/quantization/quantization_pass.py | 138 +++++++++++++----- 2 files changed, 180 insertions(+), 78 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py index fa57a9bd746..5ed3be2622a 100644 --- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py +++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py @@ -398,6 +398,9 @@ class PostTrainingQuantization: self._best_calibration_loss = {} # The threshold for algo = abs_max, mse or avg self._quantized_threshold = {} + # If the tensor is zero-size during any calibration step, + # it will be stored in self._zero_size_var_names + self._zero_size_var_names = set() self._same_scale_tensor_list = same_scale_tensor_list self._freeze_model = freeze_model self._scale_dict = scale_dict @@ -465,9 +468,12 @@ class PostTrainingQuantization: if self._algo == 'avg': for var_name in self._quantized_act_var_name: + if var_name not in self._quantized_var_avg: + continue self._quantized_threshold[var_name] = np.array( self._quantized_var_avg[var_name] ).mean() + if self._algo in ["KL", "hist"]: self._calculate_kl_hist_threshold() @@ -741,6 +747,9 @@ class PostTrainingQuantization: _logger.info("MSE searching stage ...") for var_name in self._quantized_act_var_name: var_tensor = utils.load_variable_data(self._scope, var_name) + if not var_tensor.any(): + self._zero_size_var_names.add(var_name) + continue var_tensor = var_tensor.flatten() abs_max_value = float(np.max(np.abs(var_tensor))) abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value @@ -792,6 +801,9 @@ class PostTrainingQuantization: _logger.info("EMD searching stage ...") for var_name in self._quantized_act_var_name: var_tensor = utils.load_variable_data(self._scope, var_name) + if not var_tensor.any(): + self._zero_size_var_names.add(var_name) + continue var_tensor = var_tensor.flatten() abs_max_value = float(np.max(np.abs(var_tensor))) abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value @@ -845,6 +857,9 @@ class PostTrainingQuantization: for var_name in self._quantized_act_var_name: var_tensor = utils.load_variable_data(self._scope, var_name) + if not var_tensor.any(): + self._zero_size_var_names.add(var_name) + continue abs_max_value = float(np.max(np.abs(var_tensor))) if var_name not in self._quantized_var_avg: self._quantized_var_avg[var_name] = [] @@ -857,7 +872,6 @@ class PostTrainingQuantization: ) ) self._quantized_var_avg[var_name].append(abs_avg_value) - continue def _sample_abs_max(self): if self._quantized_threshold == {}: @@ -884,6 +898,9 @@ class PostTrainingQuantization: for var_name in self._quantized_act_var_name: var_tensor = utils.load_variable_data(self._scope, var_name) + if not var_tensor.any(): + self._zero_size_var_names.add(var_name) + continue abs_max_value = float(np.max(np.abs(var_tensor))) if (var_name not in self._quantized_threshold) or ( abs_max_value > self._quantized_threshold[var_name] @@ -916,6 +933,9 @@ class PostTrainingQuantization: for var_name in self._quantized_act_var_name: var_tensor = utils.load_variable_data(self._scope, var_name) + if not var_tensor.any(): + self._zero_size_var_names.add(var_name) + continue min_value = float(np.min(var_tensor)) max_value = float(np.max(var_tensor)) if (var_name not in self._quantized_var_min) or ( @@ -930,6 +950,11 @@ class PostTrainingQuantization: def _sample_histogram(self): for var_name in self._quantized_act_var_name: var_tensor = utils.load_variable_data(self._scope, var_name) + if (not var_tensor.any()) or ( + var_name not in self._sampling_act_histogram + ): + self._zero_size_var_names.add(var_name) + continue var_tensor_abs = np.abs(var_tensor) bins = self._sampling_act_histogram[var_name][1] hist, _ = np.histogram(var_tensor_abs, bins=bins) @@ -964,6 +989,9 @@ class PostTrainingQuantization: for var_name in self._quantized_act_var_name: var_tensor = utils.load_variable_data(self._scope, var_name) + if not var_tensor.any(): + self._zero_size_var_names.add(var_name) + continue abs_max_value = float(np.max(np.abs(var_tensor))) q_max = 2 ** (self._activation_bits - 1) - 1 scale8 = abs_max_value / q_max @@ -1020,6 +1048,9 @@ class PostTrainingQuantization: ''' for var_name in self._quantized_act_var_name: var_tensor = utils.load_variable_data(self._scope, var_name) + if not var_tensor.any(): + self._zero_size_var_names.add(var_name) + continue var_tensor = np.abs(var_tensor) min_value = float(np.min(var_tensor)) max_value = float(np.max(var_tensor)) @@ -1039,6 +1070,10 @@ class PostTrainingQuantization: Based on the min/max value, init the sampling_act_histogram. ''' for var_name in self._quantized_act_var_name: + if (var_name in self._zero_size_var_names) and ( + var_name not in self._sampling_act_abs_min_max + ): + continue if var_name not in self._sampling_act_histogram: min_val = self._sampling_act_abs_min_max[var_name][0] max_val = self._sampling_act_abs_min_max[var_name][1] @@ -1077,6 +1112,10 @@ class PostTrainingQuantization: self._quantized_var_threshold[var_name] = weight_threshold for var_name in self._quantized_act_var_name: + if (var_name in self._zero_size_var_names) and ( + var_name not in self._sampling_act_histogram + ): + continue hist, hist_edeges = self._sampling_act_histogram[var_name] if self._algo == "KL": bin_width = hist_edeges[1] - hist_edeges[0] @@ -1162,7 +1201,6 @@ class PostTrainingQuantization: if self._same_scale_tensor_list is not None: for tensor_list in self._same_scale_tensor_list: max_scale = None - tmp_tensor_list = [] for tensor_name in tensor_list: if '#' in tensor_name: real_tensor_name, opera, scalar = tensor_name.split( @@ -1261,21 +1299,40 @@ class PostTrainingQuantization: self._calibration_scales = {} def save_info( - op_node, out_var_name, threshold_map, out_info_name, quantized_type + op_node, + out_var_name, + threshold_map, + out_info_name, + argname_index, + quantized_type, ): - assert ( - out_var_name in threshold_map - ), "The output ({}) of {} node does not have threshold.".format( - out_var_name, op_node.type - ) + if (out_var_name in self._zero_size_var_names) and ( + out_var_name not in threshold_map + ): + _logger.warning( + "{} is zero-size tensor and unable to calibrate, so skip quant it.".format( + out_var_name + ) + ) + return + else: + assert ( + out_var_name in threshold_map + ), "The output ({}) of {} node does not have threshold.".format( + out_var_name, op_node.type + ) if self._onnx_format: # For easy extension, every var_node set a dict to save parameters of quant. - self._calibration_scales[var_name] = {} - self._calibration_scales[var_name]['scale'] = threshold_map[ - var_name + self._calibration_scales[out_var_name] = {} + self._calibration_scales[out_var_name]['scale'] = threshold_map[ + out_var_name ] else: - op_node._set_attr(out_info_name, threshold_map[var_name]) + op_node._set_attr(out_info_name, threshold_map[out_var_name]) + op_node._set_attr( + argname_index[0] + str(argname_index[1]) + "_threshold", + threshold_map[out_var_name], + ) op_node._set_attr("with_quant_attr", True) if op_node.type in self._quantizable_op_type: op._set_attr("quantization_type", quantized_type) @@ -1285,52 +1342,23 @@ class PostTrainingQuantization: assert argname_index is not None, ( out_var_name + " is not the output of the op" ) - if self._algo == "KL": - # For compatibility, we save output threshold by two methods. - save_info( - op_node, - out_var_name, - self._quantized_var_threshold, - "out_threshold", - "post_kl", - ) - save_info( - op_node, - out_var_name, - self._quantized_var_threshold, - argname_index[0] + str(argname_index[1]) + "_threshold", - "post_kl", - ) - elif self._algo == "hist": + if self._algo in ["KL", "hist"]: # For compatibility, we save output threshold by two methods. save_info( op_node, out_var_name, self._quantized_var_threshold, "out_threshold", - "post_hist", - ) - save_info( - op_node, - out_var_name, - self._quantized_var_threshold, - argname_index[0] + str(argname_index[1]) + "_threshold", - "post_hist", + argname_index, + "post_" + str(self._algo).lower(), ) - elif self._algo in ["avg", "abs_max", "mse", "emd", "ptf"]: save_info( op_node, out_var_name, self._quantized_threshold, "out_threshold", - "post_" + str(self._algo), - ) - save_info( - op_node, - out_var_name, - self._quantized_threshold, - argname_index[0] + str(argname_index[1]) + "_threshold", + argname_index, "post_" + str(self._algo), ) elif self._algo == "min_max": @@ -1339,6 +1367,7 @@ class PostTrainingQuantization: out_var_name, self._quantized_var_min, "out_min", + argname_index, "post_min_max", ) save_info( @@ -1346,6 +1375,7 @@ class PostTrainingQuantization: out_var_name, self._quantized_var_max, "out_max", + argname_index, "post_min_max", ) diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index eddbf68fe1a..81556d83f3a 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -2134,7 +2134,9 @@ class InsertQuantizeLinear: self._moving_rate = moving_rate self._scale_dict = scale_dict - def insert_quant_op(self, graph, var_node, var_name=None): + def insert_quant_op( + self, graph, var_node, var_name=None, scale_var_node=None + ): assert var_node.is_var(), '{} is not a var'.format(var_node.name()) var_name = var_node.name() if not var_name else var_name quant_var_node = graph.create_var_node( @@ -2143,40 +2145,43 @@ class InsertQuantizeLinear: shape=var_node.shape(), var_dtype=var_node.dtype(), ) - data_type = ( - 'float64' - if var_node.dtype() == core.VarDesc.VarType.FP64 - else 'float32' - ) - scale_name = self._quantized_scale_name(var_name) - if self.channel_wise: - scale_var_shape = var_node.shape()[self.quant_axis] - scale_var_type = core.VarDesc.VarType.LOD_TENSOR - init_scale_value = ( - np.ones(scale_var_shape, dtype=data_type) * _SCALE_DEFAULT_VALUE + if not scale_var_node: + data_type = ( + 'float64' + if var_node.dtype() == core.VarDesc.VarType.FP64 + else 'float32' ) - else: - scale_var_shape = 1 - scale_var_type = var_node.type() - init_scale_value = np.array([_SCALE_DEFAULT_VALUE], dtype=data_type) + scale_name = self._quantized_scale_name(var_name) + if self.channel_wise: + scale_var_shape = var_node.shape()[self.quant_axis] + scale_var_type = core.VarDesc.VarType.LOD_TENSOR + init_scale_value = ( + np.ones(scale_var_shape, dtype=data_type) + * _SCALE_DEFAULT_VALUE + ) + else: + scale_var_shape = 1 + scale_var_type = var_node.type() + init_scale_value = np.array( + [_SCALE_DEFAULT_VALUE], dtype=data_type + ) - if ( - self._scale_dict is not None - and var_node.name() in self._scale_dict.keys() - ): - init_scale_value = np.array( - [self._scale_dict[var_node.name()]], dtype=data_type + if ( + self._scale_dict is not None + and var_node.name() in self._scale_dict.keys() + ): + init_scale_value = np.array( + [self._scale_dict[var_node.name()]], dtype=data_type + ) + scale_var_node = graph.create_persistable_node( + name=scale_name, + var_type=scale_var_type, + shape=[scale_var_shape], + var_dtype=var_node.dtype(), + ) + _init_var_node( + scale_var_node, init_scale_value, self._scope, self._place ) - - scale_var_node = graph.create_persistable_node( - name=scale_name, - var_type=scale_var_type, - shape=[scale_var_shape], - var_dtype=var_node.dtype(), - ) - _init_var_node( - scale_var_node, init_scale_value, self._scope, self._place - ) zero_point_node = None if zero_point_node is None: @@ -2510,6 +2515,7 @@ class QuantizationTransformPassV2(QuantizationTransformPass): def _transform_forward(self, graph, op): op.op()._set_attr("quantization_type", "qat_with_weight") + weight_scale_node = None inputs = op.inputs for var_node in inputs: if var_node.name() not in op.input_arg_names(): @@ -2595,7 +2601,10 @@ class QuantizationTransformPassV2(QuantizationTransformPass): ) self.dequantized_vars[name] = dequant_var_node + if is_weight: + weight_scale_node = scale_var_node graph.update_input_link(var_node, dequant_var_node, op) + return weight_scale_node def _transform_backward(self, graph, op): for var_node in op.inputs: @@ -2610,11 +2619,71 @@ class QuantizationTransformPassV2(QuantizationTransformPass): for var_node in op.inputs: if var_node.name() not in op.input_arg_names(): continue - name = var_node.name() if var_node.name() in self.persistable_vars: has_weight = True return has_weight + def _quant_conv1d(self, graph, op): + # conv1d in inference is a combination of unsqueeze2 + conv2d + if ("conv2d" not in op.name()) or ( + "unsqueeze2" not in op.input("Filter")[0] + ): + return + conv_weight_var_name = op.input("Filter")[0] + # unsqueeze2 and conv2d will share weight scale + weight_scale_node = None + # quant unsqueeze2 + for _op in graph.all_op_nodes(): + var_names = utils._get_op_output_var_names(_op) + if conv_weight_var_name in var_names and self._has_weight(_op): + weight_scale_node = self._transform_forward(graph, _op) + # insert qdq before conv2d + for var_node in op.inputs: + quant_bits = ( + self._weight_bits + if var_node.name() == conv_weight_var_name + else self._activation_bits + ) + quant_type = ( + self._weight_quantize_type + if var_node.name() == conv_weight_var_name + else self._activation_quantize_type + ) + quant_axis = -1 + channel_wise = False + if quant_type == 'channel_wise_abs_max': + channel_wise = True + quant_axis = ( + 1 if op.name() in utils._channelwise_quant_axis1_ops else 0 + ) + insert_quant_pass = InsertQuantizeLinear( + self._place, + self._scope, + quant_bits=quant_bits, + quant_axis=quant_axis, + channel_wise=channel_wise, + moving_rate=self._moving_rate, + is_test=self._is_test, + ) + scale_var_node = ( + weight_scale_node + if var_node.name() == conv_weight_var_name + else None + ) + ( + quant_var_node, + scale_var_node, + ) = insert_quant_pass.insert_quant_op( + graph, + var_node, + var_name=var_node.name(), + scale_var_node=scale_var_node, + ) + dequant_var_node = insert_quant_pass.insert_dequant_op( + graph, quant_var_node, scale_var_node + ) + graph.update_input_link(var_node, dequant_var_node, op) + def apply(self, graph): """ Quantize the graph for training process. According to weight and @@ -2664,6 +2733,9 @@ class QuantizationTransformPassV2(QuantizationTransformPass): op ): self._transform_forward(graph, op) + else: # op is not persistable + # support conv1d quantization + self._quant_conv1d(graph, op) t.update() # The loop for renaming the inputs of backward op. for op in ops: -- GitLab