From b0ceed6fb4f15891599e0fe62f971a623f949da8 Mon Sep 17 00:00:00 2001 From: juncaipeng <52520497+juncaipeng@users.noreply.github.com> Date: Mon, 23 Sep 2019 21:22:31 +0800 Subject: [PATCH] add fake_quant_dequant_op for average pool2d, test=develop (#19880) * add fake_quant_dequant_op for average pool2d * add test --- .../slim/quantization/quantization_pass.py | 37 +++++++++-- .../slim/tests/test_quantization_pass.py | 65 ++++++++++++++++++- 2 files changed, 94 insertions(+), 8 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index d65e0e8f0ca..15a91c063d0 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -90,6 +90,9 @@ class QuantizationTransformPass(object): usually is not used for weight, since weights are fixed once the model is well trained. window_size (int): the window size for 'range_abs_max' quantization. + skip_pattern(str): The user-defined quantization skip pattern, which + will be presented in the name scope of an op. When the skip pattern is + detected in an op's name scope, the corresponding op will not be quantized. Examples: .. code-block:: python @@ -1163,29 +1166,31 @@ class AddQuantDequantPass(object): def __init__(self, scope=None, place=None, moving_rate=0.9, quant_bits=8): """ This pass is used to add quant_dequant op for some ops, such as the - `elementwise_add` op. + 'elementwise_add' and 'average pool2d' op. """ self._scope = scope self._place = place self._moving_rate = moving_rate self._quant_bits = quant_bits self._is_test = None - self._target_ops = ["elementwise_add"] + self._target_ops = ["elementwise_add", "pool2d"] + self._target_grad_ops = ['%s_grad' % (op) for op in self._target_ops] def apply(self, graph): """ - Add quant_dequant before some ops, such as the `elementwise_add` op. This - is required by TensorRT. + Add quant_dequant before some ops, such as the 'elementwise_add' + and 'average pool2d' op. Args: graph(IrGraph): the target graph. """ assert isinstance(graph, IrGraph), 'graph must be the instance of IrGraph.' self._is_test = graph.is_test() + dequantized_vars_map = collections.OrderedDict() ops = graph.all_op_nodes() + for op_node in ops: - name = op_node.name() - if name in self._target_ops: + if op_node.name() in self._target_ops: in_nodes_all_not_persistable = True for input_name in op_node.input_arg_names(): in_node = graph._find_node_by_name(op_node.inputs, @@ -1195,13 +1200,31 @@ class AddQuantDequantPass(object): not in_node.persistable()) if not in_nodes_all_not_persistable: continue + + if op_node.op().has_attr("pooling_type") and \ + op_node.op().attr("pooling_type") == 'max': + continue + input_names = op_node.input_arg_names() for input_name in input_names: in_node = graph._find_node_by_name(op_node.inputs, input_name) - quant_var_node, scale_var_node = self._inser_quant_dequant_moving_average_abs_max_op( + quant_var_node, scale_var_node = \ + self._inser_quant_dequant_moving_average_abs_max_op( graph, in_node, self._quant_bits) + dequantized_vars_map[input_name] = quant_var_node graph.update_input_link(in_node, quant_var_node, op_node) + + for op_node in ops: + if op_node.name() in self._target_grad_ops: + for input_name in op_node.input_arg_names(): + if input_name in dequantized_vars_map: + in_node = graph._find_node_by_name(op_node.inputs, + input_name) + dequant_var_node = dequantized_vars_map[input_name] + graph.update_input_link(in_node, dequant_var_node, + op_node) + graph.resolve_hazard() return graph diff --git a/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py b/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py index d4138684165..162048d7440 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py @@ -24,6 +24,7 @@ from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass from paddle.fluid.contrib.slim.quantization import TransformForMobilePass +from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass from paddle.fluid import core os.environ["CUDA_VISIBLE_DEVICES"] = "0" @@ -66,7 +67,9 @@ def residual_block(num): conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True) short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None) hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu') - fc = fluid.layers.fc(input=hidden, size=10) + pool = fluid.layers.pool2d( + input=hidden, pool_size=2, pool_type='avg', pool_stride=2) + fc = fluid.layers.fc(input=pool, size=10) loss = fluid.layers.cross_entropy(input=fc, label=label) loss = fluid.layers.mean(loss) return loss @@ -486,5 +489,65 @@ class TestQuantizationFreezePass(unittest.TestCase): for_ci=True) +class TestAddQuantDequantPass(unittest.TestCase): + def setUp(self): + self._target_ops = {'elementwise_add', 'pool2d'} + self._target_grad_ops = {'elementwise_add_grad', 'pool2d_grad'} + + def check_graph(self, graph): + ops = graph.all_op_nodes() + + for op_node in ops: + if op_node.name() in self._target_ops: + in_nodes_all_not_persistable = True + for input_name in op_node.input_arg_names(): + in_node = graph._find_node_by_name(op_node.inputs, + input_name) + in_nodes_all_not_persistable = ( + in_nodes_all_not_persistable and + not in_node.persistable()) + if not in_nodes_all_not_persistable: + continue + + if op_node.op().has_attr("pooling_type") and \ + op_node.op().attr("pooling_type") == 'max': + continue + + input_names = op_node.input_arg_names() + for input_name in input_names: + self.assertTrue(input_name.endswith('.quant_dequant')) + + def residual_block_quant(self, for_ci=True): + main = fluid.Program() + startup = fluid.Program() + with fluid.program_guard(main, startup): + loss = residual_block(1) + opt = fluid.optimizer.Adam(learning_rate=0.001) + opt.minimize(loss) + place = fluid.CPUPlace() + graph = IrGraph(core.Graph(main.desc), for_test=False) + add_quant_dequant_pass = AddQuantDequantPass( + scope=fluid.global_scope(), place=place) + add_quant_dequant_pass.apply(graph) + if not for_ci: + marked_nodes = set() + for op in graph.all_op_nodes(): + if op.name().find('quant') > -1: + marked_nodes.add(op) + graph.draw('.', 'add_quant_dequant_graph', marked_nodes) + self.check_graph(graph) + program = graph.to_program() + val_graph = IrGraph(core.Graph(program.desc), for_test=False) + if not for_ci: + val_marked_nodes = set() + for op in val_graph.all_op_nodes(): + if op.name().find('quant') > -1: + val_marked_nodes.add(op) + val_graph.draw('.', 'val_add_quant_dequant_graph', val_marked_nodes) + + def test_residual_block(self): + self.residual_block_quant(for_ci=True) + + if __name__ == '__main__': unittest.main() -- GitLab