未验证 提交 5c19bfc8 编写于 作者: C ceci3 提交者: GitHub

Support hybrid parallel in QAT (#52219)

上级 6c01ce8a
......@@ -298,6 +298,7 @@ class QuantizationTransformPass:
def _transform_forward(graph, op):
op.op()._set_attr("quantization_type", "qat_with_weight")
op.op()._set_attr("with_quant_attr", True)
op_role = op.op().attr("op_role")
inputs = op.inputs
for var_node in inputs:
if var_node.name() not in op.input_arg_names():
......@@ -368,7 +369,12 @@ class QuantizationTransformPass:
quant_var_node,
scale_var_node,
) = self._insert_channel_quant_op(
graph, var_node, name, quant_bits, quant_axis
graph,
var_node,
name,
quant_bits,
quant_axis,
op_role,
)
dequant_var_node = self._insert_channel_dequant_op(
graph,
......@@ -376,13 +382,23 @@ class QuantizationTransformPass:
[scale_var_node],
[quant_bits],
quant_axis,
op_role,
)
else:
quant_var_node, scale_var_node = self._insert_quant_op(
graph, var_node, name, quant_bits, quant_type
graph,
var_node,
name,
quant_bits,
quant_type,
op_role,
)
dequant_var_node = self._insert_dequant_op(
graph, quant_var_node, scale_var_node, quant_bits
graph,
quant_var_node,
scale_var_node,
quant_bits,
op_role,
)
dequantized_vars[name] = dequant_var_node
graph.update_input_link(var_node, dequant_var_node, op)
......@@ -476,24 +492,28 @@ class QuantizationTransformPass:
graph.link_to(increment_op, global_step_out)
self._global_step = global_step_out
def _insert_quant_op(self, graph, var_node, name, quant_bits, quant_type):
def _insert_quant_op(
self, graph, var_node, name, quant_bits, quant_type, op_role
):
"""
Insert fake_quantize_op in the graph.
"""
if quant_type == 'abs_max':
return self._insert_quant_abs_max_op(
graph, var_node, name, quant_bits
graph, var_node, name, quant_bits, op_role
)
elif quant_type == 'range_abs_max':
return self._insert_quant_range_abs_max_op(
graph, var_node, name, quant_bits
graph, var_node, name, quant_bits, op_role
)
elif quant_type == 'moving_average_abs_max':
return self._insert_quant_moving_average_abs_max_op(
graph, var_node, name, quant_bits
graph, var_node, name, quant_bits, op_role
)
def _insert_quant_abs_max_op(self, graph, var_node, name, quant_bits):
def _insert_quant_abs_max_op(
self, graph, var_node, name, quant_bits, op_role
):
"""
Insert fake_quantize_abs_max op in the graph.
"""
......@@ -528,10 +548,7 @@ class QuantizationTransformPass:
quant_op_node = graph.create_op_node(
op_type='fake_quantize_abs_max',
attrs={
'bit_length': quant_bits,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
},
attrs={'bit_length': quant_bits, 'op_role': op_role},
inputs={'X': var_node},
outputs={'Out': quant_var_node, 'OutScale': scale_var_node},
)
......@@ -540,7 +557,9 @@ class QuantizationTransformPass:
graph.link_to(quant_op_node, scale_var_node)
return quant_var_node, scale_var_node
def _insert_quant_range_abs_max_op(self, graph, var_node, name, quant_bits):
def _insert_quant_range_abs_max_op(
self, graph, var_node, name, quant_bits, op_role
):
"""
Insert fake_quantize_range_abs_max on the graph.
"""
......@@ -605,7 +624,7 @@ class QuantizationTransformPass:
'window_size': self._window_size,
'bit_length': quant_bits,
'is_test': self._is_test,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
'op_role': op_role,
}
quant_op_node = graph.create_op_node(
op_type='fake_quantize_range_abs_max',
......@@ -626,7 +645,7 @@ class QuantizationTransformPass:
return quant_var_node, scale_out_node
def _insert_quant_moving_average_abs_max_op(
self, graph, var_node, name, quant_bits
self, graph, var_node, name, quant_bits, op_role
):
"""Insert fake_quantize_moving_average_abs_max"""
quant_var_node = graph.create_var_node(
......@@ -706,7 +725,7 @@ class QuantizationTransformPass:
'bit_length': quant_bits,
'moving_rate': self._moving_rate,
'is_test': self._is_test,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
'op_role': op_role,
}
quant_op_node = graph.create_op_node(
......@@ -730,7 +749,7 @@ class QuantizationTransformPass:
return quant_var_node, scale_out_node
def _insert_channel_quant_op(
self, graph, var_node, name, quant_bits, quant_axis
self, graph, var_node, name, quant_bits, quant_axis, op_role
):
"""
Insert fake_channel_wise_quantize_abs_max op in the graph.
......@@ -771,7 +790,7 @@ class QuantizationTransformPass:
'bit_length': quant_bits,
'quant_axis': quant_axis,
'is_test': self._is_test,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
'op_role': op_role,
},
inputs={'X': var_node},
outputs={'Out': quant_var_node, 'OutScale': scale_var_node},
......@@ -781,7 +800,9 @@ class QuantizationTransformPass:
graph.link_to(quant_op_node, scale_var_node)
return quant_var_node, scale_var_node
def _insert_dequant_op(self, graph, var_node, scale_var_node, quant_bits):
def _insert_dequant_op(
self, graph, var_node, scale_var_node, quant_bits, op_role
):
"""
Insert fake_dequantize_op in the graph.
"""
......@@ -796,10 +817,7 @@ class QuantizationTransformPass:
max_range = (1 << (quant_bits - 1)) - 1
dequant_op_node = graph.create_op_node(
op_type='fake_dequantize_max_abs',
attrs={
'max_range': float(max_range),
'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
},
attrs={'max_range': float(max_range), 'op_role': op_role},
inputs={'X': var_node, 'Scale': scale_var_node},
outputs={'Out': dequant_var_node},
)
......@@ -809,7 +827,7 @@ class QuantizationTransformPass:
return dequant_var_node
def _insert_channel_dequant_op(
self, graph, var_node, scale_var_nodes, quant_bits, quant_axis
self, graph, var_node, scale_var_nodes, quant_bits, quant_axis, op_role
):
"""
Insert fake_channel_wise_dequantize_max_abs in the graph.
......@@ -827,7 +845,7 @@ class QuantizationTransformPass:
attrs={
'quant_bits': quant_bits,
'quant_axis': quant_axis,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
'op_role': op_role,
},
inputs={'X': var_node, 'Scales': scale_var_nodes},
outputs={'Out': dequant_var_node},
......@@ -1628,11 +1646,15 @@ class OutScaleForTrainingPass:
in_node = graph._find_node_by_name(
op.outputs, output_var_name
)
if in_node.dtype() not in [
core.VarDesc.VarType.FP64,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16,
]:
if (
in_node.dtype()
not in [
core.VarDesc.VarType.FP64,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16,
]
or '@GRAD' in in_node.name()
):
continue
if in_node.dtype() == core.VarDesc.VarType.FP64:
......@@ -1710,7 +1732,7 @@ class OutScaleForTrainingPass:
attrs = {
'moving_rate': self._moving_rate,
'is_test': self._is_test,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
'op_role': op.op().attr("op_role"),
}
scale_op_node = graph.create_op_node(
op_type='moving_average_abs_max_scale',
......@@ -1953,7 +1975,10 @@ class AddQuantDequantPass:
quant_var_node,
_,
) = self._inser_quant_dequant_moving_average_abs_max_op(
graph, in_node, self._quant_bits
graph,
in_node,
self._quant_bits,
op_node.op().attr("op_role"),
)
dequantized_vars_map[arg_name] = quant_var_node
graph.update_input_link(
......@@ -1978,7 +2003,7 @@ class AddQuantDequantPass:
return graph
def _inser_quant_dequant_moving_average_abs_max_op(
self, graph, var_node, quant_bits
self, graph, var_node, quant_bits, op_role
):
"""Insert fake_quantize_dequantize_moving_average_abs_max op."""
quant_var_node = graph.create_var_node(
......@@ -2068,7 +2093,7 @@ class AddQuantDequantPass:
'bit_length': quant_bits,
'moving_rate': self._moving_rate,
'is_test': self._is_test,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
'op_role': op_role,
}
quant_op_node = graph.create_op_node(
......@@ -2131,7 +2156,12 @@ class InsertQuantizeLinear:
self._scale_dict = scale_dict
def insert_quant_op(
self, graph, var_node, var_name=None, scale_var_node=None
self,
graph,
var_node,
var_name=None,
scale_var_node=None,
op_role=core.op_proto_and_checker_maker.OpRole.Forward,
):
assert var_node.is_var(), f'{var_node.name()} is not a var'
var_name = var_node.name() if not var_name else var_name
......@@ -2200,7 +2230,7 @@ class InsertQuantizeLinear:
inputs["ZeroPoint"] = zero_point_node
attrs = {"quant_axis": self.quant_axis, "bit_length": self.quant_bits}
attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward
attrs["op_role"] = op_role
outputs = {"Y": quant_var_node}
if not self._is_test:
scale_out_node = graph.create_var_node_from_desc(
......@@ -2271,7 +2301,7 @@ class InsertQuantizeLinear:
graph.link_to(quant_op_node, scale_out_node)
return quant_var_node, scale_var_node
def insert_dequant_op(self, graph, var_node, scale_var_node):
def insert_dequant_op(self, graph, var_node, scale_var_node, op_role):
assert var_node.is_var(), f'{var_node.name()} is not a var'
dequant_var_node = graph.create_var_node(
......@@ -2301,7 +2331,7 @@ class InsertQuantizeLinear:
inputs["ZeroPoint"] = zero_point_node
attrs = {"quant_axis": self.quant_axis, "bit_length": self.quant_bits}
attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward
attrs["op_role"] = op_role
quant_op_node = graph.create_op_node(
op_type="dequantize_linear",
......@@ -2513,6 +2543,7 @@ class QuantizationTransformPassV2(QuantizationTransformPass):
def _transform_forward(self, graph, op):
op.op()._set_attr("quantization_type", "qat_with_weight")
op_role = op.op().attr("op_role")
weight_scale_node = None
inputs = op.inputs
for var_node in inputs:
......@@ -2592,10 +2623,10 @@ class QuantizationTransformPassV2(QuantizationTransformPass):
quant_var_node,
scale_var_node,
) = insert_quant_pass.insert_quant_op(
graph, var_node, var_name=name
graph, var_node, var_name=name, op_role=op_role
)
dequant_var_node = insert_quant_pass.insert_dequant_op(
graph, quant_var_node, scale_var_node
graph, quant_var_node, scale_var_node, op_role
)
self.dequantized_vars[name] = dequant_var_node
......@@ -2676,9 +2707,13 @@ class QuantizationTransformPassV2(QuantizationTransformPass):
var_node,
var_name=var_node.name(),
scale_var_node=scale_var_node,
op_role=op.op().attr("op_role"),
)
dequant_var_node = insert_quant_pass.insert_dequant_op(
graph, quant_var_node, scale_var_node
graph,
quant_var_node,
scale_var_node,
op.op().attr("op_role"),
)
graph.update_input_link(var_node, dequant_var_node, op)
......@@ -2913,11 +2948,16 @@ class AddQuantDequantPassV2:
quant_var_node,
scale_var_node,
) = insert_quant_pass.insert_quant_op(
graph, in_node
graph,
in_node,
op_role=op_node.op().attr("op_role"),
)
dequant_var_node = (
insert_quant_pass.insert_dequant_op(
graph, quant_var_node, scale_var_node
graph,
quant_var_node,
scale_var_node,
op_node.op().attr("op_role"),
)
)
dequantized_vars_map[arg_name] = dequant_var_node
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册