diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_mkldnn_pass.py index 2fc9dfac8e7bbd8efdefc52b0f00c72dc2009164..293951200678bfcb4c7d8278282e7259be76bb94 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_mkldnn_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_mkldnn_pass.py @@ -27,9 +27,9 @@ class TransformForMkldnnPass(object): 1. Convert int8 range weights with float32 data type, which are generated by the QuantizationFreezePass, to float32 range weights with float32 data type by using the corresponding scales. This conversion is because MKL-DNN INT8 - conv2d kernel now only supports float32 weights input, will do weights - quantization inside the conv2d kernel. - 2. Create the new conv2d op with the converted weights and link its output + conv2d kernel and mul kernel now only support float32 weights input, hence + weights quantization will happen inside the conv2d and mul INT8 kernel. + 2. Create the new conv2d or mul op with the converted weights and link its output to fake_dequantize_abs_max op's output and set conv2d's attribute "force_fp32 _output" as true 3. Transform fake_quantize_xx op to quantize op @@ -73,13 +73,8 @@ class TransformForMkldnnPass(object): self.InScale = {} self.max_range = {} - self.conv_new_output = {} + self.new_output = {} self.s8_max = 127 - # Temporary code for keeping the mul op as fake quantization - #TODO Intel: Remove the following code when mul int8 mkldnn - # kernel enabled - self.mul_input_id = [] - self.mul_output_id = [] def apply(self, graph): """ @@ -97,7 +92,7 @@ class TransformForMkldnnPass(object): persistable_vars = [p.name() for p in graph.all_persistable_nodes()] # Collect the InScales and max_range to calculate the new scales for MKL-DNN - # INT8 conv2d + # INT8 conv2d and mul for op_node in ops: if op_node.name() in self.dequantize_type: input_name = op_node.input("X")[0] @@ -105,20 +100,14 @@ class TransformForMkldnnPass(object): self.InScale[input_name] = self._load_param(self._scope, scale_name)[0] self.max_range[input_name] = op_node.op().attr("max_range") - self.conv_new_output[input_name] = op_node.output("Out")[0] - # Temporary graph transform on keeping the mul op - # TODO Intel: Remove following code - elif op_node.name() in ['mul']: - input_node = graph._find_node_by_name(op_node.inputs, - op_node.input('X')[0]) - output_node = graph._find_node_by_name(op_node.outputs, - op_node.output('Out')[0]) - self.mul_input_id.append(input_node.id()) - self.mul_output_id.append(output_node.id()) + self.new_output[input_name] = op_node.output("Out")[0] for op_node in ops: - if op_node.name() in self._conv_ops: - self._transform_to_conv_mkldnn(graph, op_node) + if op_node.name() in self._quantizable_ops: + if op_node.name() in self._conv_ops: + self._transform_to_conv_mkldnn(graph, op_node) + else: + self._transform_to_mul_mkldnn(graph, op_node) elif op_node.name() in self.quantize_type: self._transform_to_quantize_mkldnn(graph, op_node) elif op_node.name() in self.dequantize_type: @@ -132,7 +121,7 @@ class TransformForMkldnnPass(object): # Convert int8 range weights to fp32 range weights weight = self._load_param(self._scope, weight_name) w_fp32 = np.divide( - np.multiply(weight, 127), self.max_range[output_name]) + np.multiply(weight, self.s8_max), self.max_range[output_name]) w_fp32 = w_fp32.reshape(weight.shape) self._restore_var(weight_name, w_fp32) input_var_node = graph._find_node_by_name(op_node.inputs, @@ -140,8 +129,8 @@ class TransformForMkldnnPass(object): weight_var_node = graph._find_node_by_name(op_node.inputs, weight_name) # Set fake_dequantize_abs_max's output as new output of conv2d - output_var_node = graph._find_node_by_name( - graph.all_var_nodes(), self.conv_new_output[output_name]) + output_var_node = graph._find_node_by_name(graph.all_var_nodes(), + self.new_output[output_name]) attrs = { name: op_node.op().attr(name) for name in op_node.op().attr_names() @@ -157,7 +146,7 @@ class TransformForMkldnnPass(object): # Based on the QAT's scales to calculate the scales of MKL-DNN INT8 conv2d scale_in = self.s8_max / self.InScale[output_name] scale_w = [] - scale_w.append(self.max_range[output_name] / self.s8_max) + scale_w = [self.max_range[output_name] / self.s8_max] conv_op_node.set_attr("Scale_weights", scale_w) conv_op_node.set_attr("Scale_in", scale_in) @@ -169,6 +158,50 @@ class TransformForMkldnnPass(object): graph.link_to(conv_op_node, output_var_node) graph.safe_remove_nodes(op_node) + def _transform_to_mul_mkldnn(self, graph, op_node): + # For MKL-DNN INT8 mul, input Y should be the weights + weight_name = op_node.input("Y")[0] + output_name = op_node.output("Out")[0] + # Convert int8 range weights to fp32 range weights + weight = self._load_param(self._scope, weight_name) + w_fp32 = np.divide( + np.multiply(weight, self.s8_max), self.max_range[output_name]) + w_fp32 = w_fp32.reshape(weight.shape) + self._restore_var(weight_name, w_fp32) + input_var_node = graph._find_node_by_name(op_node.inputs, + op_node.input("X")[0]) + weight_var_node = graph._find_node_by_name(op_node.inputs, weight_name) + + # Set fake_dequantize_abs_max's output as new output of mul + output_var_node = graph._find_node_by_name(graph.all_var_nodes(), + self.new_output[output_name]) + attrs = { + name: op_node.op().attr(name) + for name in op_node.op().attr_names() + } + + mul_op_node = graph.create_op_node( + op_type='mul', + attrs=attrs, + inputs={'X': input_var_node, + 'Y': weight_var_node}, + outputs={'Out': output_var_node}) + + # Based on the QAT's scales to calculate MKL-DNN INT8 mul's scales + scale_in = self.s8_max / self.InScale[output_name] + scale_w = [] + scale_w = [self.max_range[output_name] / self.s8_max] + + mul_op_node.set_attr("scale_y", scale_w) + mul_op_node.set_attr("scale_x", scale_in) + mul_op_node.set_attr("scale_out", 1.0) + mul_op_node.set_attr("use_mkldnn", 1) + mul_op_node.set_attr("force_fp32_output", 1) + graph.link_to(input_var_node, mul_op_node) + graph.link_to(weight_var_node, mul_op_node) + graph.link_to(mul_op_node, output_var_node) + graph.safe_remove_nodes(op_node) + def _transform_to_quantize_mkldnn(self, graph, op_node): """ Transform fake_quantize_xx op to quantize mkldnn op in the graph. @@ -177,32 +210,26 @@ class TransformForMkldnnPass(object): op_node.input("X")[0]) output_var_node = graph._find_node_by_name(op_node.outputs, op_node.output("Out")[0]) - if output_var_node.id() in self.mul_input_id: - return - else: - scale_in = self.s8_max / self._load_param( - self._scope, op_node.input("InScale")[0])[0] - quant_op_node = graph.create_op_node( - op_type='quantize', - attrs={ - 'data_format': 'MKLDNNLAYOUT', - 'use_mkldnn': 1, - 'Scale': scale_in, - 'is_negative_input': 1 - }, - inputs={'Input': input_var_node}, - outputs={'Output': output_var_node}) - graph.link_to(input_var_node, quant_op_node) - graph.link_to(quant_op_node, output_var_node) - graph.safe_remove_nodes(op_node) + scale_in = self.s8_max / self._load_param( + self._scope, op_node.input("InScale")[0])[0] + quant_op_node = graph.create_op_node( + op_type='quantize', + attrs={ + 'data_format': 'MKLDNNLAYOUT', + 'use_mkldnn': 1, + 'Scale': scale_in, + 'is_negative_input': 1 + }, + inputs={'Input': input_var_node}, + outputs={'Output': output_var_node}) + graph.link_to(input_var_node, quant_op_node) + graph.link_to(quant_op_node, output_var_node) + graph.safe_remove_nodes(op_node) def _remove_fake_dequantize_op(self, graph, op_node): input_var_node = graph._find_node_by_name(op_node.inputs, op_node.input("X")[0]) - if input_var_node.id() in self.mul_output_id: - return - else: - graph.safe_remove_nodes(op_node) + graph.safe_remove_nodes(op_node) def _load_param(self, scope, param_name): return np.array(scope.find_var(param_name).get_tensor()) diff --git a/python/paddle/fluid/contrib/slim/tests/test_quantization_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/tests/test_quantization_mkldnn_pass.py index 90cc28b3aaf95f9696cc177d8e9a2381abae3ea6..81a31ba7d2efb30b8225cd268c8bfd28adcb4323 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_quantization_mkldnn_pass.py +++ b/python/paddle/fluid/contrib/slim/tests/test_quantization_mkldnn_pass.py @@ -55,9 +55,7 @@ class TestMKLDNNTransformBasedFreezePass(unittest.TestCase): self.quantizable_op_and_inputs = { 'conv2d': ['Input', 'Filter'], 'depthwise_conv2d': ['Input', 'Filter'], - # Mul int8 op is under internal test - # TODO Update this when mul op is merged - #'mul': ['X', 'Y'] + 'mul': ['X', 'Y'] } def check_program(self, program): @@ -162,11 +160,15 @@ class TestMKLDNNTransformBasedFreezePass(unittest.TestCase): activation_quant_type + '_' + weight_quant_type, marked_nodes) mkldnn_program = test_graph.to_program() - w_mkldnn = np.array(scope.find_var('conv2d_1.w_0').get_tensor()) + + # Check the transformation weights of conv2d and mul + conv_w_mkldnn = np.array(scope.find_var('conv2d_1.w_0').get_tensor()) + mul_w_mkldnn = np.array(scope.find_var('fc_0.w_0').get_tensor()) # Check if weights are still integer - self.assertFalse(self.isinteger(np.sum(w_mkldnn))) + self.assertFalse(self.isinteger(np.sum(conv_w_mkldnn))) + self.assertFalse(self.isinteger(np.sum(mul_w_mkldnn))) - # Check if the conv2d output is rightly linked to fake_dequantize's + # Check if the conv2d output and mul output are correctly linked to fake_dequantize's # output self.check_program(mkldnn_program) if not for_ci: