diff --git a/python/paddle/static/quantization/quantization_pass.py b/python/paddle/static/quantization/quantization_pass.py index 606e88320cb5e0c2b24725cf9445042794b90d84..c53902b20169d6035d588db7ce9f9339b1a82979 100644 --- a/python/paddle/static/quantization/quantization_pass.py +++ b/python/paddle/static/quantization/quantization_pass.py @@ -298,6 +298,7 @@ class QuantizationTransformPass: def _transform_forward(graph, op): op.op()._set_attr("quantization_type", "qat_with_weight") op.op()._set_attr("with_quant_attr", True) + op_role = op.op().attr("op_role") inputs = op.inputs for var_node in inputs: if var_node.name() not in op.input_arg_names(): @@ -368,7 +369,12 @@ class QuantizationTransformPass: quant_var_node, scale_var_node, ) = self._insert_channel_quant_op( - graph, var_node, name, quant_bits, quant_axis + graph, + var_node, + name, + quant_bits, + quant_axis, + op_role, ) dequant_var_node = self._insert_channel_dequant_op( graph, @@ -376,13 +382,23 @@ class QuantizationTransformPass: [scale_var_node], [quant_bits], quant_axis, + op_role, ) else: quant_var_node, scale_var_node = self._insert_quant_op( - graph, var_node, name, quant_bits, quant_type + graph, + var_node, + name, + quant_bits, + quant_type, + op_role, ) dequant_var_node = self._insert_dequant_op( - graph, quant_var_node, scale_var_node, quant_bits + graph, + quant_var_node, + scale_var_node, + quant_bits, + op_role, ) dequantized_vars[name] = dequant_var_node graph.update_input_link(var_node, dequant_var_node, op) @@ -476,24 +492,28 @@ class QuantizationTransformPass: graph.link_to(increment_op, global_step_out) self._global_step = global_step_out - def _insert_quant_op(self, graph, var_node, name, quant_bits, quant_type): + def _insert_quant_op( + self, graph, var_node, name, quant_bits, quant_type, op_role + ): """ Insert fake_quantize_op in the graph. """ if quant_type == 'abs_max': return self._insert_quant_abs_max_op( - graph, var_node, name, quant_bits + graph, var_node, name, quant_bits, op_role ) elif quant_type == 'range_abs_max': return self._insert_quant_range_abs_max_op( - graph, var_node, name, quant_bits + graph, var_node, name, quant_bits, op_role ) elif quant_type == 'moving_average_abs_max': return self._insert_quant_moving_average_abs_max_op( - graph, var_node, name, quant_bits + graph, var_node, name, quant_bits, op_role ) - def _insert_quant_abs_max_op(self, graph, var_node, name, quant_bits): + def _insert_quant_abs_max_op( + self, graph, var_node, name, quant_bits, op_role + ): """ Insert fake_quantize_abs_max op in the graph. """ @@ -528,10 +548,7 @@ class QuantizationTransformPass: quant_op_node = graph.create_op_node( op_type='fake_quantize_abs_max', - attrs={ - 'bit_length': quant_bits, - 'op_role': core.op_proto_and_checker_maker.OpRole.Forward, - }, + attrs={'bit_length': quant_bits, 'op_role': op_role}, inputs={'X': var_node}, outputs={'Out': quant_var_node, 'OutScale': scale_var_node}, ) @@ -540,7 +557,9 @@ class QuantizationTransformPass: graph.link_to(quant_op_node, scale_var_node) return quant_var_node, scale_var_node - def _insert_quant_range_abs_max_op(self, graph, var_node, name, quant_bits): + def _insert_quant_range_abs_max_op( + self, graph, var_node, name, quant_bits, op_role + ): """ Insert fake_quantize_range_abs_max on the graph. """ @@ -605,7 +624,7 @@ class QuantizationTransformPass: 'window_size': self._window_size, 'bit_length': quant_bits, 'is_test': self._is_test, - 'op_role': core.op_proto_and_checker_maker.OpRole.Forward, + 'op_role': op_role, } quant_op_node = graph.create_op_node( op_type='fake_quantize_range_abs_max', @@ -626,7 +645,7 @@ class QuantizationTransformPass: return quant_var_node, scale_out_node def _insert_quant_moving_average_abs_max_op( - self, graph, var_node, name, quant_bits + self, graph, var_node, name, quant_bits, op_role ): """Insert fake_quantize_moving_average_abs_max""" quant_var_node = graph.create_var_node( @@ -706,7 +725,7 @@ class QuantizationTransformPass: 'bit_length': quant_bits, 'moving_rate': self._moving_rate, 'is_test': self._is_test, - 'op_role': core.op_proto_and_checker_maker.OpRole.Forward, + 'op_role': op_role, } quant_op_node = graph.create_op_node( @@ -730,7 +749,7 @@ class QuantizationTransformPass: return quant_var_node, scale_out_node def _insert_channel_quant_op( - self, graph, var_node, name, quant_bits, quant_axis + self, graph, var_node, name, quant_bits, quant_axis, op_role ): """ Insert fake_channel_wise_quantize_abs_max op in the graph. @@ -771,7 +790,7 @@ class QuantizationTransformPass: 'bit_length': quant_bits, 'quant_axis': quant_axis, 'is_test': self._is_test, - 'op_role': core.op_proto_and_checker_maker.OpRole.Forward, + 'op_role': op_role, }, inputs={'X': var_node}, outputs={'Out': quant_var_node, 'OutScale': scale_var_node}, @@ -781,7 +800,9 @@ class QuantizationTransformPass: graph.link_to(quant_op_node, scale_var_node) return quant_var_node, scale_var_node - def _insert_dequant_op(self, graph, var_node, scale_var_node, quant_bits): + def _insert_dequant_op( + self, graph, var_node, scale_var_node, quant_bits, op_role + ): """ Insert fake_dequantize_op in the graph. """ @@ -796,10 +817,7 @@ class QuantizationTransformPass: max_range = (1 << (quant_bits - 1)) - 1 dequant_op_node = graph.create_op_node( op_type='fake_dequantize_max_abs', - attrs={ - 'max_range': float(max_range), - 'op_role': core.op_proto_and_checker_maker.OpRole.Forward, - }, + attrs={'max_range': float(max_range), 'op_role': op_role}, inputs={'X': var_node, 'Scale': scale_var_node}, outputs={'Out': dequant_var_node}, ) @@ -809,7 +827,7 @@ class QuantizationTransformPass: return dequant_var_node def _insert_channel_dequant_op( - self, graph, var_node, scale_var_nodes, quant_bits, quant_axis + self, graph, var_node, scale_var_nodes, quant_bits, quant_axis, op_role ): """ Insert fake_channel_wise_dequantize_max_abs in the graph. @@ -827,7 +845,7 @@ class QuantizationTransformPass: attrs={ 'quant_bits': quant_bits, 'quant_axis': quant_axis, - 'op_role': core.op_proto_and_checker_maker.OpRole.Forward, + 'op_role': op_role, }, inputs={'X': var_node, 'Scales': scale_var_nodes}, outputs={'Out': dequant_var_node}, @@ -1628,11 +1646,15 @@ class OutScaleForTrainingPass: in_node = graph._find_node_by_name( op.outputs, output_var_name ) - if in_node.dtype() not in [ - core.VarDesc.VarType.FP64, - core.VarDesc.VarType.FP32, - core.VarDesc.VarType.FP16, - ]: + if ( + in_node.dtype() + not in [ + core.VarDesc.VarType.FP64, + core.VarDesc.VarType.FP32, + core.VarDesc.VarType.FP16, + ] + or '@GRAD' in in_node.name() + ): continue if in_node.dtype() == core.VarDesc.VarType.FP64: @@ -1710,7 +1732,7 @@ class OutScaleForTrainingPass: attrs = { 'moving_rate': self._moving_rate, 'is_test': self._is_test, - 'op_role': core.op_proto_and_checker_maker.OpRole.Forward, + 'op_role': op.op().attr("op_role"), } scale_op_node = graph.create_op_node( op_type='moving_average_abs_max_scale', @@ -1953,7 +1975,10 @@ class AddQuantDequantPass: quant_var_node, _, ) = self._inser_quant_dequant_moving_average_abs_max_op( - graph, in_node, self._quant_bits + graph, + in_node, + self._quant_bits, + op_node.op().attr("op_role"), ) dequantized_vars_map[arg_name] = quant_var_node graph.update_input_link( @@ -1978,7 +2003,7 @@ class AddQuantDequantPass: return graph def _inser_quant_dequant_moving_average_abs_max_op( - self, graph, var_node, quant_bits + self, graph, var_node, quant_bits, op_role ): """Insert fake_quantize_dequantize_moving_average_abs_max op.""" quant_var_node = graph.create_var_node( @@ -2068,7 +2093,7 @@ class AddQuantDequantPass: 'bit_length': quant_bits, 'moving_rate': self._moving_rate, 'is_test': self._is_test, - 'op_role': core.op_proto_and_checker_maker.OpRole.Forward, + 'op_role': op_role, } quant_op_node = graph.create_op_node( @@ -2131,7 +2156,12 @@ class InsertQuantizeLinear: self._scale_dict = scale_dict def insert_quant_op( - self, graph, var_node, var_name=None, scale_var_node=None + self, + graph, + var_node, + var_name=None, + scale_var_node=None, + op_role=core.op_proto_and_checker_maker.OpRole.Forward, ): assert var_node.is_var(), f'{var_node.name()} is not a var' var_name = var_node.name() if not var_name else var_name @@ -2200,7 +2230,7 @@ class InsertQuantizeLinear: inputs["ZeroPoint"] = zero_point_node attrs = {"quant_axis": self.quant_axis, "bit_length": self.quant_bits} - attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward + attrs["op_role"] = op_role outputs = {"Y": quant_var_node} if not self._is_test: scale_out_node = graph.create_var_node_from_desc( @@ -2271,7 +2301,7 @@ class InsertQuantizeLinear: graph.link_to(quant_op_node, scale_out_node) return quant_var_node, scale_var_node - def insert_dequant_op(self, graph, var_node, scale_var_node): + def insert_dequant_op(self, graph, var_node, scale_var_node, op_role): assert var_node.is_var(), f'{var_node.name()} is not a var' dequant_var_node = graph.create_var_node( @@ -2301,7 +2331,7 @@ class InsertQuantizeLinear: inputs["ZeroPoint"] = zero_point_node attrs = {"quant_axis": self.quant_axis, "bit_length": self.quant_bits} - attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward + attrs["op_role"] = op_role quant_op_node = graph.create_op_node( op_type="dequantize_linear", @@ -2513,6 +2543,7 @@ class QuantizationTransformPassV2(QuantizationTransformPass): def _transform_forward(self, graph, op): op.op()._set_attr("quantization_type", "qat_with_weight") + op_role = op.op().attr("op_role") weight_scale_node = None inputs = op.inputs for var_node in inputs: @@ -2592,10 +2623,10 @@ class QuantizationTransformPassV2(QuantizationTransformPass): quant_var_node, scale_var_node, ) = insert_quant_pass.insert_quant_op( - graph, var_node, var_name=name + graph, var_node, var_name=name, op_role=op_role ) dequant_var_node = insert_quant_pass.insert_dequant_op( - graph, quant_var_node, scale_var_node + graph, quant_var_node, scale_var_node, op_role ) self.dequantized_vars[name] = dequant_var_node @@ -2676,9 +2707,13 @@ class QuantizationTransformPassV2(QuantizationTransformPass): var_node, var_name=var_node.name(), scale_var_node=scale_var_node, + op_role=op.op().attr("op_role"), ) dequant_var_node = insert_quant_pass.insert_dequant_op( - graph, quant_var_node, scale_var_node + graph, + quant_var_node, + scale_var_node, + op.op().attr("op_role"), ) graph.update_input_link(var_node, dequant_var_node, op) @@ -2913,11 +2948,16 @@ class AddQuantDequantPassV2: quant_var_node, scale_var_node, ) = insert_quant_pass.insert_quant_op( - graph, in_node + graph, + in_node, + op_role=op_node.op().attr("op_role"), ) dequant_var_node = ( insert_quant_pass.insert_dequant_op( - graph, quant_var_node, scale_var_node + graph, + quant_var_node, + scale_var_node, + op_node.op().attr("op_role"), ) ) dequantized_vars_map[arg_name] = dequant_var_node