提交 7e71a785 编写于 作者: W wanghaoshuang

Add conv bn fusion pass to quanter

上级 b1e04ba7
...@@ -173,6 +173,42 @@ def _parse_configs(user_config): ...@@ -173,6 +173,42 @@ def _parse_configs(user_config):
return configs return configs
def _remove_unused_var_nodes(graph):
all_used_vars = set()
ops = graph.all_op_nodes()
for op_node in ops:
for input_node in op_node.inputs:
all_used_vars.add(input_node)
for output_node in op_node.outputs:
all_used_vars.add(output_node)
all_used_vars = {n.node for n in all_used_vars}
all_unused_vars = {
n
for n in filter(lambda node: node.node not in all_used_vars,
graph.all_var_nodes())
}
graph.safe_remove_nodes(all_unused_vars)
return graph
def _apply_pass(scope, graph, pass_name, attrs=None,
attr_values=None, debug=False):
ir_pass = core.get_pass(pass_name)
cpp_graph = graph.graph
if not cpp_graph.has('__param_scope__'):
cpp_graph.set_not_owned('__param_scope__', scope)
if attrs:
assert attr_values and len(attrs) == len(
attr_values
), "Different number of pass attributes and their values."
for attr, value in zip(attrs, attr_values):
ir_pass.set(attr, value)
ir_pass.apply(cpp_graph)
if debug:
graph.draw('.', 'qat_fp32_{}'.format(pass_name),
graph.all_op_nodes())
_remove_unused_var_nodes(graph)
return graph
def quant_aware(program, def quant_aware(program,
place, place,
...@@ -242,6 +278,12 @@ def quant_aware(program, ...@@ -242,6 +278,12 @@ def quant_aware(program,
main_graph = IrGraph(core.Graph(program.desc), for_test=for_test) main_graph = IrGraph(core.Graph(program.desc), for_test=for_test)
graph = _apply_pass(scope, main_graph, 'conv_bn_fuse_pass')
graph = _apply_pass(scope, main_graph, 'depthwise_conv_bn_fuse_pass')
graph = _apply_pass(scope, main_graph, 'conv_eltwiseadd_bn_fuse_pass')
graph = _apply_pass(scope, main_graph, 'depthwise_conv_eltwiseadd_bn_fuse_pass')
transform_pass_ops = [] transform_pass_ops = []
quant_dequant_ops = [] quant_dequant_ops = []
for op_type in config['quantize_op_types']: for op_type in config['quantize_op_types']:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册