From aa731e63373c2b370ba0b19e31c9d73e780a1759 Mon Sep 17 00:00:00 2001 From: Wojciech Uss Date: Wed, 24 Mar 2021 08:49:48 +0100 Subject: [PATCH] update scale collection and propagation algorithm (#31783) (#31810) --- .../quantization/quant2_int8_mkldnn_pass.py | 50 +++++++++---------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py index d93a2059bd..68cc8106c9 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py @@ -62,9 +62,8 @@ class Quant2Int8MkldnnPass(object): self._ops_to_quantize = _ops_to_quantize self._op_ids_to_skip = _op_ids_to_skip if _op_ids_to_skip is not None else set( [-1]) - self._scale_immutable_ops = [ - 'transpose2', 'reshape2', 'pool2d', 'scale' - ] + self._scale_immutable_ops = ['transpose2', 'reshape2', 'pool2d'] + self._scale_ops = ['scale'] self._conv_ops = ['conv2d', 'depthwise_conv2d'] self._pool_ops = ['pool2d'] self._mul_ops = ['mul'] @@ -87,8 +86,8 @@ class Quant2Int8MkldnnPass(object): self._reset_pass_idx_and_group('int8') graph = self._label_skip_quantized_op(graph) graph = self._gather_weight_thresholds_from_fake(graph) - graph = self._gather_output_scales_from_attr(graph) graph = self._gather_input_scales_from_fake(graph) + graph = self._gather_output_scales_from_attr(graph) graph = self._remove_fake_ops(graph) graph = self._dequantize_weights(graph) graph = self._optimize_fp32_graph(graph) @@ -160,12 +159,16 @@ class Quant2Int8MkldnnPass(object): op_node.op()._set_attr("skip_quant", True) return graph - def _gather_input_scales_from_fake(self, graph): - def _add_scale_for_vars(var_names, use_unsigned_int, lod_tensor): - scales = self._var_quant_scales - for var_name in var_names: + def _add_scale_for_vars(self, var_names, use_unsigned_int, lod_tensor): + """ + Save quantization scales for variables. Do not overwrite. + """ + scales = self._var_quant_scales + for var_name in var_names: + if var_name not in scales: scales[var_name] = (use_unsigned_int, lod_tensor) + def _gather_input_scales_from_fake(self, graph): # fake_quantize_dequantize_abs_max doesn't have scale value fake_ops = ['fake_quantize_dequantize_moving_average_abs_max'] fake_ops.extend(self._fake_quantize_types) @@ -185,8 +188,8 @@ class Quant2Int8MkldnnPass(object): scale[scale == np.Inf] = 0.0 lod_tensor = self._convert_scale2tensor(scale) use_unsigned_int = False - _add_scale_for_vars([input_name, output_name], use_unsigned_int, - lod_tensor) + self._add_scale_for_vars([input_name, output_name], + use_unsigned_int, lod_tensor) return graph @@ -219,8 +222,8 @@ class Quant2Int8MkldnnPass(object): use_unsigned_int = False for output_name in op.op().outputs(): for out_var_name in op.op().output(output_name): - self._var_quant_scales[out_var_name] = ( - use_unsigned_int, scale_lod_tensor) + self._add_scale_for_vars( + [out_var_name], use_unsigned_int, scale_lod_tensor) return graph @@ -239,24 +242,21 @@ class Quant2Int8MkldnnPass(object): output_name = op.output("Out")[0] tensor_names = [input_name, output_name] - # Scale is not quantized, so if it doesn't have any scales - # to propagate, its tensors won't be added to the waiting list. - if all(name not in self._var_quant_scales for name in tensor_names) \ - and op.name() != 'scale': + if all(name not in self._var_quant_scales + for name in tensor_names): waiting_for_scale.update(tensor_names) continue - - if input_name in self._var_quant_scales: + elif input_name in self._var_quant_scales: self._var_quant_scales[ output_name] = self._var_quant_scales[input_name] elif output_name in self._var_quant_scales: - if op.name() == 'scale': - _update_scale_op_in_scale(op, input_name, - output_name) - else: - self._var_quant_scales[ - input_name] = self._var_quant_scales[ - output_name] + self._var_quant_scales[ + input_name] = self._var_quant_scales[output_name] + elif op.name() in self._scale_ops: + input_name = op.input("X")[0] + output_name = op.output("Out")[0] + if output_name in self._var_quant_scales: + _update_scale_op_in_scale(op, input_name, output_name) return waiting_for_scale waiting_for_scale = _update_scales(graph) -- GitLab