From 7e71a785c7bae8bc43665607ceb39b66ce48a1f3 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Sat, 13 Nov 2021 10:50:43 +0000 Subject: [PATCH] Add conv bn fusion pass to quanter --- paddleslim/quant/quanter.py | 42 +++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/paddleslim/quant/quanter.py b/paddleslim/quant/quanter.py index 2522fed7..7e18421c 100755 --- a/paddleslim/quant/quanter.py +++ b/paddleslim/quant/quanter.py @@ -173,6 +173,42 @@ def _parse_configs(user_config): return configs +def _remove_unused_var_nodes(graph): + all_used_vars = set() + ops = graph.all_op_nodes() + for op_node in ops: + for input_node in op_node.inputs: + all_used_vars.add(input_node) + for output_node in op_node.outputs: + all_used_vars.add(output_node) + + all_used_vars = {n.node for n in all_used_vars} + all_unused_vars = { + n + for n in filter(lambda node: node.node not in all_used_vars, + graph.all_var_nodes()) + } + graph.safe_remove_nodes(all_unused_vars) + return graph + +def _apply_pass(scope, graph, pass_name, attrs=None, + attr_values=None, debug=False): + ir_pass = core.get_pass(pass_name) + cpp_graph = graph.graph + if not cpp_graph.has('__param_scope__'): + cpp_graph.set_not_owned('__param_scope__', scope) + if attrs: + assert attr_values and len(attrs) == len( + attr_values + ), "Different number of pass attributes and their values." + for attr, value in zip(attrs, attr_values): + ir_pass.set(attr, value) + ir_pass.apply(cpp_graph) + if debug: + graph.draw('.', 'qat_fp32_{}'.format(pass_name), + graph.all_op_nodes()) + _remove_unused_var_nodes(graph) + return graph def quant_aware(program, place, @@ -242,6 +278,12 @@ def quant_aware(program, main_graph = IrGraph(core.Graph(program.desc), for_test=for_test) + + graph = _apply_pass(scope, main_graph, 'conv_bn_fuse_pass') + graph = _apply_pass(scope, main_graph, 'depthwise_conv_bn_fuse_pass') + graph = _apply_pass(scope, main_graph, 'conv_eltwiseadd_bn_fuse_pass') + graph = _apply_pass(scope, main_graph, 'depthwise_conv_eltwiseadd_bn_fuse_pass') + transform_pass_ops = [] quant_dequant_ops = [] for op_type in config['quantize_op_types']: -- GitLab