From 3398f996080f0a8268207a06be0d67c8a3b433d9 Mon Sep 17 00:00:00 2001
From: Zhen Wang
Date: Wed, 22 May 2019 10:56:44 +0800
Subject: [PATCH] Adding AddQuantDequantPass for TensorRT int8 (#17529)

* add quant_dequant_pass, test=develop

* Add quant_dequant before some ops, such as the elementwise_add op. This is
  required by TensorRT. test=develop
---
 .../slim/quantization/quantization_pass.py    | 141 +++++++++++++++++-
 .../tests/test_quantization_scale_pass.py     |   8 +
 2 files changed, 148 insertions(+), 1 deletion(-)

diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
index 0d989903a9a..1ea2f080c64 100644
--- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
+++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
@@ -22,7 +22,8 @@ from .... import unique_name
 
 __all__ = [
     'QuantizationTransformPass', 'QuantizationFreezePass', 'ConvertToInt8Pass',
-    'TransformForMobilePass', 'ScaleForTrainingPass', 'ScaleForInferencePass'
+    'TransformForMobilePass', 'ScaleForTrainingPass', 'ScaleForInferencePass',
+    'AddQuantDequantPass'
 ]
 
 
@@ -994,6 +995,8 @@ class ScaleForTrainingPass(object):
         Args:
             graph(IrGraph): the target graph.
         """
+        assert isinstance(graph,
+                          IrGraph), 'graph must be an instance of IrGraph.'
         self._is_test = graph.is_test()
         ops = graph.all_op_nodes()
         for op_node in ops:
@@ -1099,6 +1102,8 @@ class ScaleForInferencePass(object):
         Args:
             graph(IrGraph): the target graph.
         """
+        assert isinstance(graph,
+                          IrGraph), 'graph must be an instance of IrGraph.'
         ops = graph.all_op_nodes()
         for op_node in ops:
             name = op_node.name()
@@ -1117,3 +1122,137 @@ class ScaleForInferencePass(object):
         Return the scale name for the var named `var_name`.
         """
         return "%s@scale" % (var_name)
+
+
+class AddQuantDequantPass(object):
+    def __init__(self, scope=None, place=None, moving_rate=0.9, quant_bits=8):
+        """
+        This pass adds a quant_dequant op before some ops, such as the
+        `elementwise_add` op.
+        """
+        self._scope = scope
+        self._place = place
+        self._moving_rate = moving_rate
+        self._quant_bits = quant_bits
+        self._is_test = None
+        self._target_ops = ["elementwise_add", "pool2d"]
+
+    def apply(self, graph):
+        """
+        Add quant_dequant before some ops, such as the `elementwise_add` op.
+        This is required by TensorRT.
+        Args:
+            graph(IrGraph): the target graph.
+        """
+        assert isinstance(graph,
+                          IrGraph), 'graph must be an instance of IrGraph.'
+        self._is_test = graph.is_test()
+        ops = graph.all_op_nodes()
+        for op_node in ops:
+            name = op_node.name()
+            if name in self._target_ops:
+                in_nodes_all_not_persistable = True
+                for input_name in op_node.input_arg_names():
+                    in_node = graph._find_node_by_name(op_node.inputs,
+                                                       input_name)
+                    in_nodes_all_not_persistable = (
+                        in_nodes_all_not_persistable and
+                        not in_node.persistable())
+                if not in_nodes_all_not_persistable:
+                    continue
+                input_names = op_node.input_arg_names()
+                for input_name in input_names:
+                    in_node = graph._find_node_by_name(op_node.inputs,
+                                                       input_name)
+                    quant_var_node, scale_var_node = self._insert_quant_dequant_moving_average_abs_max_op(
+                        graph, in_node, self._quant_bits)
+                    graph.update_input_link(in_node, quant_var_node, op_node)
+        graph.resolve_hazard()
+        return graph
+
+    def _insert_quant_dequant_moving_average_abs_max_op(self, graph, var_node,
+                                                        quant_bits):
+        """Insert fake_quantize_dequantize_moving_average_abs_max op.
+        """
+        quant_var_node = graph.create_var_node(
+            name="{}.quant_dequant".format(var_node.name()),
+            var_type=var_node.type(),
+            shape=var_node.shape(),
+            var_dtype=var_node.dtype())
+        scale_in_node = graph.create_persistable_node(
+            name="{}.quant_dequant.scale".format(var_node.name()),
+            var_type=core.VarDesc.VarType.LOD_TENSOR,
+            shape=[1],
+            var_dtype=var_node.dtype())
+        data_type = 'float64' if var_node.dtype(
+        ) == core.VarDesc.VarType.FP64 else 'float32'
+        _init_var_node(
+            scale_in_node,
+            np.array(
+                [0.001], dtype=data_type),
+            self._scope,
+            self._place)
+
+        scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
+        ins = {'X': var_node, 'InScale': scale_in_node}
+        outs = {'Out': quant_var_node, 'OutScale': scale_out_node}
+        if not self._is_test:
+            state_in_node = graph.create_persistable_node(
+                name=unique_name.generate('quant_dequant.state'),
+                var_type=core.VarDesc.VarType.LOD_TENSOR,
+                var_dtype=var_node.dtype(),
+                shape=[1])
+            data_type = 'float64' if var_node.dtype(
+            ) == core.VarDesc.VarType.FP64 else 'float32'
+            _init_var_node(
+                state_in_node,
+                np.ones(
+                    [1], dtype=data_type),
+                self._scope,
+                self._place)
+            accum_in_node = graph.create_persistable_node(
+                name=unique_name.generate('quant_dequant.accum'),
+                var_type=core.VarDesc.VarType.LOD_TENSOR,
+                var_dtype=var_node.dtype(),
+                shape=[1])
+            _init_var_node(
+                accum_in_node,
+                np.ones(
+                    [1], dtype=data_type),
+                self._scope,
+                self._place)
+            state_out_node = graph.create_var_node_from_desc(state_in_node.var(
+            ))
+            accum_out_node = graph.create_var_node_from_desc(accum_in_node.var(
+            ))
+
+            ins['InState'] = state_in_node
+            ins['InAccum'] = accum_in_node
+            outs['OutState'] = state_out_node
+            outs['OutAccum'] = accum_out_node
+
+        attrs = {
+            'bit_length': quant_bits,
+            'moving_rate': self._moving_rate,
+            'is_test': self._is_test,
+            'op_role': core.op_proto_and_checker_maker.OpRole.Forward
+        }
+
+        quant_op_node = graph.create_op_node(
+            op_type='fake_quantize_dequantize_moving_average_abs_max',
+            attrs=attrs,
+            inputs=ins,
+            outputs=outs)
+
+        graph.link_to(var_node, quant_op_node)
+        graph.link_to(scale_in_node, quant_op_node)
+        graph.link_to(quant_op_node, quant_var_node)
+        graph.link_to(quant_op_node, scale_out_node)
+
+        if not self._is_test:
+            graph.link_to(state_in_node, quant_op_node)
+            graph.link_to(accum_in_node, quant_op_node)
+            graph.link_to(quant_op_node, state_out_node)
+            graph.link_to(quant_op_node, accum_out_node)
+
+        return quant_var_node, scale_out_node
diff --git a/python/paddle/fluid/contrib/slim/tests/test_quantization_scale_pass.py b/python/paddle/fluid/contrib/slim/tests/test_quantization_scale_pass.py
index 1ed41da0f84..0739c9c1f7b 100644
--- a/python/paddle/fluid/contrib/slim/tests/test_quantization_scale_pass.py
+++ b/python/paddle/fluid/contrib/slim/tests/test_quantization_scale_pass.py
@@ -24,6 +24,7 @@ from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
 from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
 from paddle.fluid.contrib.slim.quantization import ScaleForTrainingPass
 from paddle.fluid.contrib.slim.quantization import ScaleForInferencePass
+from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
 from paddle.fluid import core
 
 os.environ["CUDA_VISIBLE_DEVICES"] = "0"
@@ -98,6 +99,7 @@ class TestQuantizationScalePass(unittest.TestCase):
         scope = fluid.Scope()
         with fluid.scope_guard(scope):
             exe.run(startup)
+
         transform_pass = QuantizationTransformPass(
             scope=scope,
             place=place,
@@ -105,8 +107,14 @@ class TestQuantizationScalePass(unittest.TestCase):
             weight_quantize_type=weight_quant_type)
         transform_pass.apply(main_graph)
         transform_pass.apply(test_graph)
+
+        add_quant_dequant_pass = AddQuantDequantPass(scope=scope, place=place)
+        add_quant_dequant_pass.apply(main_graph)
+        add_quant_dequant_pass.apply(test_graph)
+
         scale_training_pass = ScaleForTrainingPass(scope=scope, place=place)
         scale_training_pass.apply(main_graph)
+
         dev_name = '_gpu' if use_cuda else '_cpu'
         if not for_ci:
             marked_nodes = set()
-- 
GitLab
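
A minimal usage sketch of the new pass, modeled on the test change above. It is hypothetical: the network definition is omitted, and names such as main_program and startup_program are placeholders; only the pass construction and apply() calls mirror what this patch adds.

# Sketch of applying AddQuantDequantPass after QuantizationTransformPass,
# following the order used in test_quantization_scale_pass.py. The network
# built inside main_program is assumed to exist and is not shown here.
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass

main_program = fluid.Program()
startup_program = fluid.Program()
# ... build the network inside main_program/startup_program here ...

place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda() else fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.Scope()
with fluid.scope_guard(scope):
    exe.run(startup_program)

# Wrap the program in an IrGraph so the passes can rewrite it.
main_graph = IrGraph(core.Graph(main_program.desc), for_test=False)

# Quantize the usual weighted ops first, then insert
# fake_quantize_dequantize_moving_average_abs_max ops before the inputs of
# elementwise_add/pool2d, whose scales TensorRT int8 needs.
transform_pass = QuantizationTransformPass(scope=scope, place=place)
transform_pass.apply(main_graph)

add_quant_dequant_pass = AddQuantDequantPass(scope=scope, place=place)
add_quant_dequant_pass.apply(main_graph)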