diff --git a/python/paddle/fluid/contrib/slim/quantization/qat2_int8_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/quantization/qat2_int8_mkldnn_pass.py
index 44098e56e161ef9b71476d6219c3fc125c36b8b7..2c91b7599d4cd4ac3e1805e4714ec33114878434 100644
--- a/python/paddle/fluid/contrib/slim/quantization/qat2_int8_mkldnn_pass.py
+++ b/python/paddle/fluid/contrib/slim/quantization/qat2_int8_mkldnn_pass.py
@@ -15,10 +15,11 @@
 import numpy as np
 from .... import core
 from ....framework import IrGraph
-from ....framework import IrNode
 
 __all__ = ['Qat2Int8MkldnnPass']
 
+OpRole = core.op_proto_and_checker_maker.OpRole
+
 
 class Qat2Int8MkldnnPass(object):
     """
@@ -62,6 +63,7 @@ class Qat2Int8MkldnnPass(object):
         self._pool_ops = ['pool2d']
         self._mul_ops = ['mul']
         self._fc_ops = ['fc']
+        self._matmul_ops = ['matmul']
         self._weight_scales = {}
         # Collect the Input and Output scales from Fake QAT models
         self._var_quant_scales = {}
@@ -79,9 +81,9 @@ class Qat2Int8MkldnnPass(object):
         graph = self._compute_weight_scales(graph)
         graph = self._update_relu_output_scales(graph)
         graph = self._propagate_scales(graph)
-        graph = self._set_dummy_fc_out_scales(graph)
+        graph = self._set_dummy_out_scales(graph)
         graph = self._quantize_fp32_graph(graph)
-        graph = self._remove_unused_var_nodes(graph)
+        graph = self._cleanup(graph)
         return graph
 
     def apply_fp32(self, graph):
@@ -92,7 +94,7 @@ class Qat2Int8MkldnnPass(object):
         graph = self._remove_fake_ops(graph)
         graph = self._dequantize_weights(graph)
         graph = self._optimize_fp32_graph(graph)
-        graph = self._remove_unused_var_nodes(graph)
+        graph = self._cleanup(graph)
         return graph
 
     def _convert_scale2tensor(self, scale):
@@ -176,23 +178,29 @@ class Qat2Int8MkldnnPass(object):
 
         return graph
 
-    def _set_dummy_fc_out_scales(self, graph):
+    def _set_dummy_out_scales(self, graph):
         '''
-        For the output tensors of FC that do not have an assigned scale,
+        For the output tensors of fc, conv2d, and matmul ops that do not have an assigned scale,
         assign a dummy scale (same scale as input), so that the quantize pass
-        won't fail. In the end these scales aren't used, since FCs that
+        won't fail. In the end these scales aren't used, since the ops that
         have an unassigned output scale will have a force_fp32_output
         attr set to True.
         '''
+
+        def _set_scale(op, op_types, input_names, output_name):
+            scales = self._var_quant_scales
+            should_set = op.name() in op_types \
+                and op.output(output_name)[0] not in scales \
+                and all(op.input(input_name)[0] in scales for input_name in input_names)
+            if should_set:
+                output_var_name = op.output(output_name)[0]
+                input_var_name = op.input(input_names[0])[0]
+                scales[output_var_name] = scales[input_var_name]
+
         for op in graph.all_op_nodes():
-            if op.name() in self._fc_ops:
-                input_name = op.input("Input")[0]
-                output_name = op.output("Out")[0]
-                if input_name in self._var_quant_scales and \
-                    output_name not in self._var_quant_scales:
-                    # use input scale as a "dummy" scale
-                    self._var_quant_scales[
-                        output_name] = self._var_quant_scales[input_name]
+            _set_scale(op, self._conv_ops, ["Input"], "Output")
+            _set_scale(op, self._fc_ops, ["Input"], "Out")
+            _set_scale(op, self._matmul_ops, ["X", "Y"], "Out")
 
         return graph
 
@@ -358,6 +366,15 @@ class Qat2Int8MkldnnPass(object):
         self._remove_unused_var_nodes(graph)
         return graph
 
+    def _cleanup(self, graph):
+        # remove dropout ops
+        graph = self._apply_pass(graph, 'simplify_with_basic_ops_pass')
+        # make some MKL-DNN ops work in-place
+        graph = self._apply_pass(graph, 'mkldnn_inplace_pass')
+        graph = self._remove_unused_var_nodes(graph)
+        graph = self._set_op_role_forward(graph)
+        return graph
+
     def _remove_unused_var_nodes(self, graph):
         all_used_vars = set()
         ops = graph.all_op_nodes()
@@ -376,8 +393,14 @@ class Qat2Int8MkldnnPass(object):
         graph.safe_remove_nodes(all_unused_vars)
         return graph
 
+    def _set_op_role_forward(self, graph):
+        ops = graph.all_op_nodes()
+        for op in ops:
+            op.set_attr("op_role", OpRole.Forward)
+        return graph
+
     def _compute_weight_scales(self, graph):
-        def _compute_var_scales(ops, out_name, w_name, axis):
+        def _compute_var_scales(ops, w_name, axis):
             for op in graph.all_op_nodes():
                 if op.op().type() in ops:
                     weight_var_name = op.input(w_name)[0]
@@ -394,8 +417,8 @@ class Qat2Int8MkldnnPass(object):
                     self._var_quant_scales[weight_var_name] = (use_unsigned_int,
                                                                lod_tensor)
 
-        _compute_var_scales(self._conv_ops, "Output", "Filter", axis=1)
-        _compute_var_scales(self._fc_ops, "Out", "W", axis=0)
+        _compute_var_scales(self._conv_ops, "Filter", axis=1)
+        _compute_var_scales(self._fc_ops, "W", axis=0)
         return graph
 
     def _find_avg_pooling_ids(self, graph):
diff --git a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
index 58a3827ce6e2f9ee8de5dd25dc0ba4510f6cf8a4..ec3d23a1c06a4a8555f95167ecec39adb6d45048 100644
--- a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
+++ b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
@@ -231,7 +231,7 @@ if(LINUX AND WITH_MKLDNN)
 
     ### QATv2 for NLP
 
-    set(QAT2_NLP_QUANTIZED_OPS "fc,reshape2,transpose2")
+    set(QAT2_NLP_QUANTIZED_OPS "fc,reshape2,transpose2,matmul")
 
     set(NLP_DATA_ARCHIVE "Ernie_dataset.tar.gz")
     set(NLP_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/Ernie_dataset")
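
For reference, below is a minimal standalone sketch of the dummy output-scale rule that the new _set_dummy_out_scales/_set_scale pair implements, using plain dicts in place of IrGraph op nodes. All names and the function signature in the sketch are illustrative only, not Paddle API; the point is to show why matmul lists both "X" and "Y" as inputs while fc and conv2d list only "Input".

    # Sketch (not Paddle API): `scales` maps variable names to their
    # quantization scales, standing in for self._var_quant_scales.
    scales = {"x": 0.25, "y": 0.5}

    def set_dummy_scale(op_type, op_inputs, op_outputs, quantized_types,
                        input_names, output_name):
        """op_inputs/op_outputs map slot names (e.g. "X") to variable names."""
        out_var = op_outputs[output_name]
        # Assign a dummy scale only when the op is of a quantized type, its
        # output has no scale yet, and *every* listed input already has one;
        # for matmul that means both X and Y must carry scales.
        if (op_type in quantized_types and out_var not in scales
                and all(op_inputs[n] in scales for n in input_names)):
            # reuse the first input's scale as the "dummy" output scale
            scales[out_var] = scales[op_inputs[input_names[0]]]

    set_dummy_scale("matmul", {"X": "x", "Y": "y"}, {"Out": "out"},
                    ["matmul"], ["X", "Y"], "Out")
    assert scales["out"] == 0.25  # copied from the first input, X

As in the pass itself, the copied scale is never used for dequantization: an op whose output scale is only a dummy gets force_fp32_output set to True, so the dummy value merely keeps the quantize pass from failing.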