From 3fe71d0a2774e0823e8cb05b8f4c8acf93c004a6 Mon Sep 17 00:00:00 2001 From: cc <52520497+juncaipeng@users.noreply.github.com> Date: Tue, 5 Jan 2021 16:53:10 +0800 Subject: [PATCH] [cherry-pick 2.0] Support dygraph quant model and avoid the scale to be infinity (#30098) * fix ininite scale values (#29386) * Support dygraph quant model (#29927) * Avoid the scale to be infinity in quant2_int8_mkldnn_pass, test=develop * support quantized model for paddle2.0 dygraph, test=develop Co-authored-by: Wojciech Uss --- .../quantization/quant2_int8_mkldnn_pass.py | 45 ++++++++++++------- .../tests/test_quant2_int8_mkldnn_pass.py | 6 +-- 2 files changed, 33 insertions(+), 18 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py index 98123a474c9..7e1db69703c 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py @@ -49,11 +49,14 @@ class Quant2Int8MkldnnPass(object): self._fake_quantize_types = [ 'fake_quantize_moving_average_abs_max', 'fake_quantize_range_abs_max', - 'fake_quantize_dequantize_moving_average_abs_max' ] self._fake_dequantize_types = [ 'fake_dequantize_max_abs', 'fake_channel_wise_dequantize_max_abs' ] + self._fake_quantize_dequantize_types = [ + 'fake_quantize_dequantize_abs_max', + 'fake_quantize_dequantize_moving_average_abs_max' + ] self._ops_to_quantize = _ops_to_quantize self._op_ids_to_skip = _op_ids_to_skip if _op_ids_to_skip is not None else set( [-1]) @@ -137,8 +140,12 @@ class Quant2Int8MkldnnPass(object): for var_name in var_names: scales[var_name] = (use_unsigned_int, lod_tensor) + # fake_quantize_dequantize_abs_max doesn't have scale value + fake_ops = ['fake_quantize_dequantize_moving_average_abs_max'] + fake_ops.extend(self._fake_quantize_types) + for op in graph.all_op_nodes(): - if op.name() in self._fake_quantize_types: + if op.name() in fake_ops: bit_length = op.op().attr("bit_length") assert bit_length == 8, 'Unsupported number quantization bits ({}). Only 8 is supported now.'.format( bit_length) @@ -146,9 +153,10 @@ class Quant2Int8MkldnnPass(object): input_name = op.input("X")[0] scale_name = op.input("InScale")[0] output_name = op.output("Out")[0] - # Gather new weights scale after folding batchnorm in convolution + # Gather new weight scales after folding batchnorm in convolution scale = np.array(1.0 / self._load_param( self._scope, scale_name)[0]).astype(np.float64) + scale[scale == np.Inf] = 0.0 lod_tensor = self._convert_scale2tensor(scale) use_unsigned_int = False _add_scale_for_vars([input_name, output_name], use_unsigned_int, @@ -163,13 +171,14 @@ class Quant2Int8MkldnnPass(object): if op.op().has_attr("max_range"): _max_range = np.array(op.op().attr("max_range")).astype( np.float64) - self._weight_scales[input_name] = _max_range + self._weight_scales[input_name] = np.array( + self._s8_max * self._s8_max / + _max_range).astype(np.float64) else: scale_name = op.input("Scales")[0] - scale = np.array( - self._s8_max * self._s8_max / self._load_param( - self._scope, scale_name)).astype(np.float64) - self._weight_scales[input_name] = scale + self._weight_scales[input_name] = np.array( + self._load_param(self._scope, scale_name)).astype( + np.float64) return graph @@ -179,6 +188,7 @@ class Quant2Int8MkldnnPass(object): attr_scale = op.op().attr("out_threshold") if attr_scale == 0.0: continue scale = np.array(1.0 / attr_scale).astype(np.float64) + scale[scale == np.Inf] = 0.0 scale_lod_tensor = self._convert_scale2tensor(scale) use_unsigned_int = False for output_name in op.op().outputs(): @@ -240,9 +250,9 @@ class Quant2Int8MkldnnPass(object): for op in graph.all_op_nodes(): if op.name() in self._fake_quantize_types: self._remove_fake_quantize(graph, op) - - for op in graph.all_op_nodes(): - if op.name() in self._fake_dequantize_types: + elif op.name() in self._fake_dequantize_types: + self._remove_fake_dequantize(graph, op) + elif op.name() in self._fake_quantize_dequantize_types: self._remove_fake_dequantize(graph, op) return graph @@ -287,10 +297,15 @@ class Quant2Int8MkldnnPass(object): ]) def _dequantize_weights(self, graph): + def _is_int8_weights(op_node, weight_name): + weight_var_name = op_node.input(weight_name)[0] + weight = self._load_param(self._scope, weight_var_name) + return np.all(np.mod(weight, 1) == 0) + for op in graph.all_op_nodes(): - if op.name() in self._conv_ops: + if op.name() in self._conv_ops and _is_int8_weights(op, "Filter"): self._dequantize_op_weights(graph, op, "Filter", "Output") - elif op.name() in self._mul_ops: + elif op.name() in self._mul_ops and _is_int8_weights(op, "Y"): self._dequantize_op_weights(graph, op, "Y", "Out") return graph @@ -301,9 +316,9 @@ class Quant2Int8MkldnnPass(object): scales = self._weight_scales[output_var_name] weight = self._load_param(self._scope, weight_var_name) if scales.size == 1 or scales.size == weight.shape[0]: - w_fp32 = np.divide(np.multiply(weight, self._s8_max).T, scales.T).T + w_fp32 = np.multiply(np.divide(weight, self._s8_max).T, scales.T).T elif len(weight.shape) > 1 and scales.size == weight.shape[1]: - w_fp32 = np.divide(np.multiply(weight, self._s8_max), scales) + w_fp32 = np.multiply(np.divide(weight, self._s8_max), scales) else: raise ValueError( "The size of weight scales vector ({}) does not match the dimensions ({}) of the weights tensor {}." diff --git a/python/paddle/fluid/contrib/slim/tests/test_quant2_int8_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/tests/test_quant2_int8_mkldnn_pass.py index 7f9209c8b3f..0c48f668e54 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_quant2_int8_mkldnn_pass.py +++ b/python/paddle/fluid/contrib/slim/tests/test_quant2_int8_mkldnn_pass.py @@ -187,9 +187,9 @@ class TestQuant2Int8MkldnnPass(unittest.TestCase): assert np.allclose( self.scope.find_var("mul_weights").get_tensor(), - [[127, 63.5, 42.3333, 31.75, 25.4], - [127, 63.5, 42.3333, 31.75, 25.4], - [127, 63.5, 42.3333, 31.75, 25.4]]) + [[1. / 127., 2. / 127., 3. / 127., 4. / 127., 5. / 127.], + [1. / 127., 2. / 127., 3. / 127., 4. / 127., 5. / 127.], + [1. / 127., 2. / 127., 3. / 127., 4. / 127., 5. / 127.]]) param = self.scope.var("mul_weights").get_tensor() param.set(self.variables_mul["mul_weights_bad"], self.place) -- GitLab