From 0e235e585182e4f88ef5db9c678783c284507934 Mon Sep 17 00:00:00 2001
From: "joanna.wozna.intel"
Date: Thu, 27 Jan 2022 07:47:40 +0100
Subject: [PATCH] Update passes in quant2_int8_mkldnn_pass (#38912)

* Upadate pass in quant2_int8_mkldnn_pass

* Back to the previous scale_matmul order

* Change place of cpu_quantize_placement_pass
---
 .../quantization/quant2_int8_mkldnn_pass.py | 26 ++++++++++++++++---
 1 file changed, 22 insertions(+), 4 deletions(-)

diff --git a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py
index 0251dd693f6..92a335a73dc 100644
--- a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py
+++ b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py
@@ -397,6 +397,7 @@ class Quant2Int8MkldnnPass(object):
     def _optimize_fp32_graph(self, graph):
         graph = self._update_activations(graph)
         graph = self._remove_ctrl_vars(graph)
+        graph = self._apply_pass(graph, 'layer_norm_fuse_pass')
         graph = self._apply_pass(graph, 'attention_lstm_fuse_pass')
         graph = self._apply_pass(graph, 'seqconv_eltadd_relu_fuse_pass')
         # graph = self._apply_pass(graph, 'seqpool_concat_fuse_pass')
@@ -409,24 +410,39 @@ class Quant2Int8MkldnnPass(object):
         graph = self._apply_pass(graph, 'multi_gru_fuse_pass')
         graph = self._apply_pass(graph, 'multi_gru_seq_fuse_pass')
         graph = self._apply_pass(graph, 'seq_concat_fc_fuse_pass')
+        graph = self._apply_pass(graph, 'squeeze2_matmul_fuse_pass')
+        graph = self._apply_pass(graph, 'reshape2_matmul_fuse_pass')
+        graph = self._apply_pass(graph, 'flatten2_matmul_fuse_pass')
+        graph = self._apply_pass(graph, 'matmul_v2_scale_fuse_pass')
         graph = self._apply_pass(graph, 'squared_mat_sub_fuse_pass')
         graph = self._apply_pass(graph, 'is_test_pass')
         graph = self._apply_pass(graph, 'map_matmul_v2_to_mul_pass')
         graph = self._apply_pass(graph, 'map_matmul_v2_to_matmul_pass')
+        graph = self._apply_pass(graph, 'matmul_scale_fuse_pass')
         graph = self._apply_pass(graph, 'map_matmul_to_mul_pass')
+        graph = self._apply_pass(graph, 'repeated_fc_relu_fuse_pass')
         graph = self._apply_pass(graph, 'mkldnn_placement_pass',
                                  ['mkldnn_enabled_op_types'], [set()])
         graph = self._apply_pass(graph, 'depthwise_conv_mkldnn_pass')
         graph = self._apply_pass(graph, 'conv_bn_fuse_pass')
         graph = self._apply_pass(graph, 'conv_eltwiseadd_bn_fuse_pass')
+        graph = self._apply_pass(graph, 'conv_affine_channel_fuse_pass')
+        graph = self._apply_pass(graph,
+                                 'conv_eltwiseadd_affine_channel_fuse_pass')
         graph = self._apply_pass(graph, 'conv_transpose_bn_fuse_pass')
         graph = self._apply_pass(graph,
                                  'conv_transpose_eltwiseadd_bn_fuse_pass')
         graph = self._apply_pass(graph, 'conv_bias_mkldnn_fuse_pass')
+        graph = self._apply_pass(graph, 'conv_transpose_bias_mkldnn_fuse_pass')
         graph = self._apply_pass(graph, 'conv_elementwise_add_mkldnn_fuse_pass')
+        graph = self._apply_pass(graph, 'conv_concat_relu_mkldnn_fuse_pass')
         graph = self._apply_pass(graph, 'conv_relu_mkldnn_fuse_pass')
+        graph = self._apply_pass(graph, 'conv_leaky_relu_mkldnn_fuse_pass')
         graph = self._apply_pass(graph, 'conv_relu6_mkldnn_fuse_pass')
+        graph = self._apply_pass(graph, 'conv_swish_mkldnn_fuse_pass')
         graph = self._apply_pass(graph, 'conv_hard_swish_mkldnn_fuse_pass')
+        graph = self._apply_pass(graph, 'conv_hard_sigmoid_mkldnn_fuse_pass')
+        graph = self._apply_pass(graph, 'conv_gelu_mkldnn_fuse_pass')
         graph = self._apply_pass(graph, 'fc_fuse_pass',
                                  ['use_gpu', 'use_fc_padding'], [False, False])
         graph = self._apply_pass(graph, 'repeated_fc_relu_fuse_pass')
@@ -436,6 +452,8 @@ class Quant2Int8MkldnnPass(object):
         graph = self._apply_pass(graph, 'fc_act_mkldnn_fuse_pass')
         graph = self._apply_pass(graph, 'matmul_transpose_reshape_fuse_pass')
         graph = self._apply_pass(graph, 'matmul_v2_transpose_reshape_fuse_pass')
+        graph = self._apply_pass(graph, 'batch_norm_act_fuse_pass')
+        graph = self._apply_pass(graph, 'softplus_activation_mkldnn_fuse_pass')
         # the following pass should be the last one since it will work on all fused ops.
         graph = self._apply_pass(graph, 'runtime_context_cache_pass')
         return graph
@@ -638,15 +656,15 @@ class Quant2Int8MkldnnPass(object):
         return 'NHWC' if self._is_conv_quantized(graph) else 'NCHW'

     def _quantize_fp32_graph(self, graph):
-        graph = self._apply_pass(
-            graph, 'cpu_quantize_placement_pass',
-            ['quantize_enabled_op_types', 'quantize_excluded_op_ids'],
-            [self._ops_to_quantize, self._find_avg_pooling_ids(graph)])
         graph = self._apply_pass(graph, 'scale_matmul_fuse_pass')
         graph = self._apply_pass(graph,
                                  'reshape_transpose_matmul_mkldnn_fuse_pass')
         graph = self._apply_pass(
             graph, 'reshape_transpose_matmul_v2_mkldnn_fuse_pass')
+        graph = self._apply_pass(
+            graph, 'cpu_quantize_placement_pass',
+            ['quantize_enabled_op_types', 'quantize_excluded_op_ids'],
+            [self._ops_to_quantize, self._find_avg_pooling_ids(graph)])
         graph = self._apply_pass(
             graph, 'cpu_quantize_pass', ['quant_var_scales', 'data_layout'],
             [self._var_quant_scales, self._get_data_layout(graph)])
--
GitLab
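
For context, every `self._apply_pass(graph, ...)` call touched by this diff delegates to Paddle's C++ IR pass registry, and the only functional change in `_quantize_fp32_graph` is ordering: `cpu_quantize_placement_pass` now runs after the scale/reshape/transpose matmul fuse passes, so quantization placement sees the already-fused operators. Below is a minimal, illustrative sketch of that delegation pattern, not the exact helper from `quant2_int8_mkldnn_pass.py`; the function name `apply_ir_pass` is made up for illustration, and the real method additionally attaches the parameter scope and handles graph serialization for debugging.

```python
from paddle.fluid import core


def apply_ir_pass(ir_graph, pass_name, attrs=None, attr_values=None):
    """Sketch: look up a registered C++ IR pass by name, set optional
    attributes, and apply it to the core.Graph wrapped by an IrGraph."""
    ir_pass = core.get_pass(pass_name)  # registry lookup by pass name
    if attrs and attr_values:
        for attr, value in zip(attrs, attr_values):
            # e.g. 'quantize_enabled_op_types' -> self._ops_to_quantize
            ir_pass.set(attr, value)
    ir_pass.apply(ir_graph.graph)  # mutates the underlying C++ graph in place
    return ir_graph
```

Under this reading, the pass list in `_optimize_fp32_graph` is simply a sequential pipeline of such calls, which is why the patch can insert new fuse passes or reorder existing ones without touching any other logic.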