未验证 提交 5c19bfc8 编写于 作者: C ceci3 提交者: GitHub

Support hybrid parallel in QAT (#52219)

上级 6c01ce8a
...@@ -298,6 +298,7 @@ class QuantizationTransformPass: ...@@ -298,6 +298,7 @@ class QuantizationTransformPass:
def _transform_forward(graph, op): def _transform_forward(graph, op):
op.op()._set_attr("quantization_type", "qat_with_weight") op.op()._set_attr("quantization_type", "qat_with_weight")
op.op()._set_attr("with_quant_attr", True) op.op()._set_attr("with_quant_attr", True)
op_role = op.op().attr("op_role")
inputs = op.inputs inputs = op.inputs
for var_node in inputs: for var_node in inputs:
if var_node.name() not in op.input_arg_names(): if var_node.name() not in op.input_arg_names():
...@@ -368,7 +369,12 @@ class QuantizationTransformPass: ...@@ -368,7 +369,12 @@ class QuantizationTransformPass:
quant_var_node, quant_var_node,
scale_var_node, scale_var_node,
) = self._insert_channel_quant_op( ) = self._insert_channel_quant_op(
graph, var_node, name, quant_bits, quant_axis graph,
var_node,
name,
quant_bits,
quant_axis,
op_role,
) )
dequant_var_node = self._insert_channel_dequant_op( dequant_var_node = self._insert_channel_dequant_op(
graph, graph,
...@@ -376,13 +382,23 @@ class QuantizationTransformPass: ...@@ -376,13 +382,23 @@ class QuantizationTransformPass:
[scale_var_node], [scale_var_node],
[quant_bits], [quant_bits],
quant_axis, quant_axis,
op_role,
) )
else: else:
quant_var_node, scale_var_node = self._insert_quant_op( quant_var_node, scale_var_node = self._insert_quant_op(
graph, var_node, name, quant_bits, quant_type graph,
var_node,
name,
quant_bits,
quant_type,
op_role,
) )
dequant_var_node = self._insert_dequant_op( dequant_var_node = self._insert_dequant_op(
graph, quant_var_node, scale_var_node, quant_bits graph,
quant_var_node,
scale_var_node,
quant_bits,
op_role,
) )
dequantized_vars[name] = dequant_var_node dequantized_vars[name] = dequant_var_node
graph.update_input_link(var_node, dequant_var_node, op) graph.update_input_link(var_node, dequant_var_node, op)
...@@ -476,24 +492,28 @@ class QuantizationTransformPass: ...@@ -476,24 +492,28 @@ class QuantizationTransformPass:
graph.link_to(increment_op, global_step_out) graph.link_to(increment_op, global_step_out)
self._global_step = global_step_out self._global_step = global_step_out
def _insert_quant_op(self, graph, var_node, name, quant_bits, quant_type): def _insert_quant_op(
self, graph, var_node, name, quant_bits, quant_type, op_role
):
""" """
Insert fake_quantize_op in the graph. Insert fake_quantize_op in the graph.
""" """
if quant_type == 'abs_max': if quant_type == 'abs_max':
return self._insert_quant_abs_max_op( return self._insert_quant_abs_max_op(
graph, var_node, name, quant_bits graph, var_node, name, quant_bits, op_role
) )
elif quant_type == 'range_abs_max': elif quant_type == 'range_abs_max':
return self._insert_quant_range_abs_max_op( return self._insert_quant_range_abs_max_op(
graph, var_node, name, quant_bits graph, var_node, name, quant_bits, op_role
) )
elif quant_type == 'moving_average_abs_max': elif quant_type == 'moving_average_abs_max':
return self._insert_quant_moving_average_abs_max_op( return self._insert_quant_moving_average_abs_max_op(
graph, var_node, name, quant_bits graph, var_node, name, quant_bits, op_role
) )
def _insert_quant_abs_max_op(self, graph, var_node, name, quant_bits): def _insert_quant_abs_max_op(
self, graph, var_node, name, quant_bits, op_role
):
""" """
Insert fake_quantize_abs_max op in the graph. Insert fake_quantize_abs_max op in the graph.
""" """
...@@ -528,10 +548,7 @@ class QuantizationTransformPass: ...@@ -528,10 +548,7 @@ class QuantizationTransformPass:
quant_op_node = graph.create_op_node( quant_op_node = graph.create_op_node(
op_type='fake_quantize_abs_max', op_type='fake_quantize_abs_max',
attrs={ attrs={'bit_length': quant_bits, 'op_role': op_role},
'bit_length': quant_bits,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
},
inputs={'X': var_node}, inputs={'X': var_node},
outputs={'Out': quant_var_node, 'OutScale': scale_var_node}, outputs={'Out': quant_var_node, 'OutScale': scale_var_node},
) )
...@@ -540,7 +557,9 @@ class QuantizationTransformPass: ...@@ -540,7 +557,9 @@ class QuantizationTransformPass:
graph.link_to(quant_op_node, scale_var_node) graph.link_to(quant_op_node, scale_var_node)
return quant_var_node, scale_var_node return quant_var_node, scale_var_node
def _insert_quant_range_abs_max_op(self, graph, var_node, name, quant_bits): def _insert_quant_range_abs_max_op(
self, graph, var_node, name, quant_bits, op_role
):
""" """
Insert fake_quantize_range_abs_max on the graph. Insert fake_quantize_range_abs_max on the graph.
""" """
...@@ -605,7 +624,7 @@ class QuantizationTransformPass: ...@@ -605,7 +624,7 @@ class QuantizationTransformPass:
'window_size': self._window_size, 'window_size': self._window_size,
'bit_length': quant_bits, 'bit_length': quant_bits,
'is_test': self._is_test, 'is_test': self._is_test,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward, 'op_role': op_role,
} }
quant_op_node = graph.create_op_node( quant_op_node = graph.create_op_node(
op_type='fake_quantize_range_abs_max', op_type='fake_quantize_range_abs_max',
...@@ -626,7 +645,7 @@ class QuantizationTransformPass: ...@@ -626,7 +645,7 @@ class QuantizationTransformPass:
return quant_var_node, scale_out_node return quant_var_node, scale_out_node
def _insert_quant_moving_average_abs_max_op( def _insert_quant_moving_average_abs_max_op(
self, graph, var_node, name, quant_bits self, graph, var_node, name, quant_bits, op_role
): ):
"""Insert fake_quantize_moving_average_abs_max""" """Insert fake_quantize_moving_average_abs_max"""
quant_var_node = graph.create_var_node( quant_var_node = graph.create_var_node(
...@@ -706,7 +725,7 @@ class QuantizationTransformPass: ...@@ -706,7 +725,7 @@ class QuantizationTransformPass:
'bit_length': quant_bits, 'bit_length': quant_bits,
'moving_rate': self._moving_rate, 'moving_rate': self._moving_rate,
'is_test': self._is_test, 'is_test': self._is_test,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward, 'op_role': op_role,
} }
quant_op_node = graph.create_op_node( quant_op_node = graph.create_op_node(
...@@ -730,7 +749,7 @@ class QuantizationTransformPass: ...@@ -730,7 +749,7 @@ class QuantizationTransformPass:
return quant_var_node, scale_out_node return quant_var_node, scale_out_node
def _insert_channel_quant_op( def _insert_channel_quant_op(
self, graph, var_node, name, quant_bits, quant_axis self, graph, var_node, name, quant_bits, quant_axis, op_role
): ):
""" """
Insert fake_channel_wise_quantize_abs_max op in the graph. Insert fake_channel_wise_quantize_abs_max op in the graph.
...@@ -771,7 +790,7 @@ class QuantizationTransformPass: ...@@ -771,7 +790,7 @@ class QuantizationTransformPass:
'bit_length': quant_bits, 'bit_length': quant_bits,
'quant_axis': quant_axis, 'quant_axis': quant_axis,
'is_test': self._is_test, 'is_test': self._is_test,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward, 'op_role': op_role,
}, },
inputs={'X': var_node}, inputs={'X': var_node},
outputs={'Out': quant_var_node, 'OutScale': scale_var_node}, outputs={'Out': quant_var_node, 'OutScale': scale_var_node},
...@@ -781,7 +800,9 @@ class QuantizationTransformPass: ...@@ -781,7 +800,9 @@ class QuantizationTransformPass:
graph.link_to(quant_op_node, scale_var_node) graph.link_to(quant_op_node, scale_var_node)
return quant_var_node, scale_var_node return quant_var_node, scale_var_node
def _insert_dequant_op(self, graph, var_node, scale_var_node, quant_bits): def _insert_dequant_op(
self, graph, var_node, scale_var_node, quant_bits, op_role
):
""" """
Insert fake_dequantize_op in the graph. Insert fake_dequantize_op in the graph.
""" """
...@@ -796,10 +817,7 @@ class QuantizationTransformPass: ...@@ -796,10 +817,7 @@ class QuantizationTransformPass:
max_range = (1 << (quant_bits - 1)) - 1 max_range = (1 << (quant_bits - 1)) - 1
dequant_op_node = graph.create_op_node( dequant_op_node = graph.create_op_node(
op_type='fake_dequantize_max_abs', op_type='fake_dequantize_max_abs',
attrs={ attrs={'max_range': float(max_range), 'op_role': op_role},
'max_range': float(max_range),
'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
},
inputs={'X': var_node, 'Scale': scale_var_node}, inputs={'X': var_node, 'Scale': scale_var_node},
outputs={'Out': dequant_var_node}, outputs={'Out': dequant_var_node},
) )
...@@ -809,7 +827,7 @@ class QuantizationTransformPass: ...@@ -809,7 +827,7 @@ class QuantizationTransformPass:
return dequant_var_node return dequant_var_node
def _insert_channel_dequant_op( def _insert_channel_dequant_op(
self, graph, var_node, scale_var_nodes, quant_bits, quant_axis self, graph, var_node, scale_var_nodes, quant_bits, quant_axis, op_role
): ):
""" """
Insert fake_channel_wise_dequantize_max_abs in the graph. Insert fake_channel_wise_dequantize_max_abs in the graph.
...@@ -827,7 +845,7 @@ class QuantizationTransformPass: ...@@ -827,7 +845,7 @@ class QuantizationTransformPass:
attrs={ attrs={
'quant_bits': quant_bits, 'quant_bits': quant_bits,
'quant_axis': quant_axis, 'quant_axis': quant_axis,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward, 'op_role': op_role,
}, },
inputs={'X': var_node, 'Scales': scale_var_nodes}, inputs={'X': var_node, 'Scales': scale_var_nodes},
outputs={'Out': dequant_var_node}, outputs={'Out': dequant_var_node},
...@@ -1628,11 +1646,15 @@ class OutScaleForTrainingPass: ...@@ -1628,11 +1646,15 @@ class OutScaleForTrainingPass:
in_node = graph._find_node_by_name( in_node = graph._find_node_by_name(
op.outputs, output_var_name op.outputs, output_var_name
) )
if in_node.dtype() not in [ if (
core.VarDesc.VarType.FP64, in_node.dtype()
core.VarDesc.VarType.FP32, not in [
core.VarDesc.VarType.FP16, core.VarDesc.VarType.FP64,
]: core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16,
]
or '@GRAD' in in_node.name()
):
continue continue
if in_node.dtype() == core.VarDesc.VarType.FP64: if in_node.dtype() == core.VarDesc.VarType.FP64:
...@@ -1710,7 +1732,7 @@ class OutScaleForTrainingPass: ...@@ -1710,7 +1732,7 @@ class OutScaleForTrainingPass:
attrs = { attrs = {
'moving_rate': self._moving_rate, 'moving_rate': self._moving_rate,
'is_test': self._is_test, 'is_test': self._is_test,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward, 'op_role': op.op().attr("op_role"),
} }
scale_op_node = graph.create_op_node( scale_op_node = graph.create_op_node(
op_type='moving_average_abs_max_scale', op_type='moving_average_abs_max_scale',
...@@ -1953,7 +1975,10 @@ class AddQuantDequantPass: ...@@ -1953,7 +1975,10 @@ class AddQuantDequantPass:
quant_var_node, quant_var_node,
_, _,
) = self._inser_quant_dequant_moving_average_abs_max_op( ) = self._inser_quant_dequant_moving_average_abs_max_op(
graph, in_node, self._quant_bits graph,
in_node,
self._quant_bits,
op_node.op().attr("op_role"),
) )
dequantized_vars_map[arg_name] = quant_var_node dequantized_vars_map[arg_name] = quant_var_node
graph.update_input_link( graph.update_input_link(
...@@ -1978,7 +2003,7 @@ class AddQuantDequantPass: ...@@ -1978,7 +2003,7 @@ class AddQuantDequantPass:
return graph return graph
def _inser_quant_dequant_moving_average_abs_max_op( def _inser_quant_dequant_moving_average_abs_max_op(
self, graph, var_node, quant_bits self, graph, var_node, quant_bits, op_role
): ):
"""Insert fake_quantize_dequantize_moving_average_abs_max op.""" """Insert fake_quantize_dequantize_moving_average_abs_max op."""
quant_var_node = graph.create_var_node( quant_var_node = graph.create_var_node(
...@@ -2068,7 +2093,7 @@ class AddQuantDequantPass: ...@@ -2068,7 +2093,7 @@ class AddQuantDequantPass:
'bit_length': quant_bits, 'bit_length': quant_bits,
'moving_rate': self._moving_rate, 'moving_rate': self._moving_rate,
'is_test': self._is_test, 'is_test': self._is_test,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward, 'op_role': op_role,
} }
quant_op_node = graph.create_op_node( quant_op_node = graph.create_op_node(
...@@ -2131,7 +2156,12 @@ class InsertQuantizeLinear: ...@@ -2131,7 +2156,12 @@ class InsertQuantizeLinear:
self._scale_dict = scale_dict self._scale_dict = scale_dict
def insert_quant_op( def insert_quant_op(
self, graph, var_node, var_name=None, scale_var_node=None self,
graph,
var_node,
var_name=None,
scale_var_node=None,
op_role=core.op_proto_and_checker_maker.OpRole.Forward,
): ):
assert var_node.is_var(), f'{var_node.name()} is not a var' assert var_node.is_var(), f'{var_node.name()} is not a var'
var_name = var_node.name() if not var_name else var_name var_name = var_node.name() if not var_name else var_name
...@@ -2200,7 +2230,7 @@ class InsertQuantizeLinear: ...@@ -2200,7 +2230,7 @@ class InsertQuantizeLinear:
inputs["ZeroPoint"] = zero_point_node inputs["ZeroPoint"] = zero_point_node
attrs = {"quant_axis": self.quant_axis, "bit_length": self.quant_bits} attrs = {"quant_axis": self.quant_axis, "bit_length": self.quant_bits}
attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward attrs["op_role"] = op_role
outputs = {"Y": quant_var_node} outputs = {"Y": quant_var_node}
if not self._is_test: if not self._is_test:
scale_out_node = graph.create_var_node_from_desc( scale_out_node = graph.create_var_node_from_desc(
...@@ -2271,7 +2301,7 @@ class InsertQuantizeLinear: ...@@ -2271,7 +2301,7 @@ class InsertQuantizeLinear:
graph.link_to(quant_op_node, scale_out_node) graph.link_to(quant_op_node, scale_out_node)
return quant_var_node, scale_var_node return quant_var_node, scale_var_node
def insert_dequant_op(self, graph, var_node, scale_var_node): def insert_dequant_op(self, graph, var_node, scale_var_node, op_role):
assert var_node.is_var(), f'{var_node.name()} is not a var' assert var_node.is_var(), f'{var_node.name()} is not a var'
dequant_var_node = graph.create_var_node( dequant_var_node = graph.create_var_node(
...@@ -2301,7 +2331,7 @@ class InsertQuantizeLinear: ...@@ -2301,7 +2331,7 @@ class InsertQuantizeLinear:
inputs["ZeroPoint"] = zero_point_node inputs["ZeroPoint"] = zero_point_node
attrs = {"quant_axis": self.quant_axis, "bit_length": self.quant_bits} attrs = {"quant_axis": self.quant_axis, "bit_length": self.quant_bits}
attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward attrs["op_role"] = op_role
quant_op_node = graph.create_op_node( quant_op_node = graph.create_op_node(
op_type="dequantize_linear", op_type="dequantize_linear",
...@@ -2513,6 +2543,7 @@ class QuantizationTransformPassV2(QuantizationTransformPass): ...@@ -2513,6 +2543,7 @@ class QuantizationTransformPassV2(QuantizationTransformPass):
def _transform_forward(self, graph, op): def _transform_forward(self, graph, op):
op.op()._set_attr("quantization_type", "qat_with_weight") op.op()._set_attr("quantization_type", "qat_with_weight")
op_role = op.op().attr("op_role")
weight_scale_node = None weight_scale_node = None
inputs = op.inputs inputs = op.inputs
for var_node in inputs: for var_node in inputs:
...@@ -2592,10 +2623,10 @@ class QuantizationTransformPassV2(QuantizationTransformPass): ...@@ -2592,10 +2623,10 @@ class QuantizationTransformPassV2(QuantizationTransformPass):
quant_var_node, quant_var_node,
scale_var_node, scale_var_node,
) = insert_quant_pass.insert_quant_op( ) = insert_quant_pass.insert_quant_op(
graph, var_node, var_name=name graph, var_node, var_name=name, op_role=op_role
) )
dequant_var_node = insert_quant_pass.insert_dequant_op( dequant_var_node = insert_quant_pass.insert_dequant_op(
graph, quant_var_node, scale_var_node graph, quant_var_node, scale_var_node, op_role
) )
self.dequantized_vars[name] = dequant_var_node self.dequantized_vars[name] = dequant_var_node
...@@ -2676,9 +2707,13 @@ class QuantizationTransformPassV2(QuantizationTransformPass): ...@@ -2676,9 +2707,13 @@ class QuantizationTransformPassV2(QuantizationTransformPass):
var_node, var_node,
var_name=var_node.name(), var_name=var_node.name(),
scale_var_node=scale_var_node, scale_var_node=scale_var_node,
op_role=op.op().attr("op_role"),
) )
dequant_var_node = insert_quant_pass.insert_dequant_op( dequant_var_node = insert_quant_pass.insert_dequant_op(
graph, quant_var_node, scale_var_node graph,
quant_var_node,
scale_var_node,
op.op().attr("op_role"),
) )
graph.update_input_link(var_node, dequant_var_node, op) graph.update_input_link(var_node, dequant_var_node, op)
...@@ -2913,11 +2948,16 @@ class AddQuantDequantPassV2: ...@@ -2913,11 +2948,16 @@ class AddQuantDequantPassV2:
quant_var_node, quant_var_node,
scale_var_node, scale_var_node,
) = insert_quant_pass.insert_quant_op( ) = insert_quant_pass.insert_quant_op(
graph, in_node graph,
in_node,
op_role=op_node.op().attr("op_role"),
) )
dequant_var_node = ( dequant_var_node = (
insert_quant_pass.insert_dequant_op( insert_quant_pass.insert_dequant_op(
graph, quant_var_node, scale_var_node graph,
quant_var_node,
scale_var_node,
op_node.op().attr("op_role"),
) )
) )
dequantized_vars_map[arg_name] = dequant_var_node dequantized_vars_map[arg_name] = dequant_var_node
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册