From 1c11f817e94e52bec84fbd423bb98b367cf7d4ec Mon Sep 17 00:00:00 2001 From: Zhen Wang Date: Tue, 26 Mar 2019 14:08:12 +0800 Subject: [PATCH] Use the resolve hazard method. --- .../slim/quantization/quantization_pass.py | 39 +++---------------- python/paddle/fluid/framework.py | 28 +++++++++++++ 2 files changed, 34 insertions(+), 33 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index bbf3d17d19c..ab3bd8bd182 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -26,35 +26,6 @@ __all__ = [ ] -def _resolve_hazard(graph): - def _to_node(nodes, node_name): - target_node = None - for n in nodes: - if n.name() == node_name: - target_node = n.node - assert target_node is not None, "Cannot find the target node in the giving set." - return target_node - - ordered_nodes = graph.topology_sort() - var_nodes = dict() - for node in ordered_nodes: - if node.is_op() and node.op() is not None: - for each_var_name in node.op().input_arg_names(): - if each_var_name not in var_nodes: - var_nodes[each_var_name] = [ - _to_node(node.inputs, each_var_name) - ] - for each_var_name in node.op().output_arg_names(): - if each_var_name not in var_nodes: - var_nodes[each_var_name] = [ - _to_node(node.outputs, each_var_name) - ] - else: - var_nodes[each_var_name].append( - _to_node(node.outputs, each_var_name)) - graph.graph.resolve_hazard(var_nodes) - - class QuantizationTransformPass(object): def __init__(self, scope=None, @@ -150,8 +121,8 @@ class QuantizationTransformPass(object): """ assert isinstance(graph, IrGraph), 'graph must be the instance of IrGraph.' - sequential_execution = core.get_pass('sequential_execution_pass') - sequential_execution.apply(graph.graph) + #sequential_execution = core.get_pass('sequential_execution_pass') + #sequential_execution.apply(graph.graph) self._is_test = graph.is_test() # marked the variable which has been dequantized. dequantized_vars = collections.OrderedDict() @@ -216,7 +187,7 @@ class QuantizationTransformPass(object): for op in ops: if op.name() in self._quantizable_grad_ops: _transform_backward(graph, op) - _resolve_hazard(graph) + graph.resolve_hazard() return graph def _create_global_step(self, graph): @@ -652,6 +623,7 @@ class QuantizationFreezePass(object): # remove the unused var node in the graph self._remove_unused_var_nodes(graph) + graph.resolve_hazard() return graph def _remove_fake_quant_and_dequant_op(self, graph, op_node): @@ -895,6 +867,7 @@ class ConvertToInt8Pass(object): # remove the unused var node in the graph self._remove_unused_var_nodes(graph) + graph.resolve_hazard() return graph def _convert_to_int8(self, graph, var_node): @@ -977,5 +950,5 @@ class TransformForMobilePass(object): for output_node in op_node.outputs: graph.link_to(dequant_node, output_node) graph.safe_remove_nodes(op_node) - + graph.resolve_hazard() return graph diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 8097495f5bd..99929300012 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -2253,6 +2253,34 @@ class IrGraph(object): original_nodes = {n.node for n in remove_nodes} core.graph_safe_remove_nodes(self.graph, original_nodes) + def resolve_hazard(self): + def _to_node(nodes, node_name): + target_node = None + for n in nodes: + if n.name() == node_name: + target_node = n + assert target_node is not None, "Cannot find the target node in the giving set." + return target_node + + ordered_nodes = core.topology_sort(self.graph) + var_nodes = dict() + for node in ordered_nodes: + if node.is_op() and node.op() is not None: + for each_var_name in node.op().input_arg_names(): + if each_var_name not in var_nodes: + var_nodes[each_var_name] = [ + _to_node(node.inputs, each_var_name) + ] + for each_var_name in node.op().output_arg_names(): + if each_var_name not in var_nodes: + var_nodes[each_var_name] = [ + _to_node(node.outputs, each_var_name) + ] + else: + var_nodes[each_var_name].append( + _to_node(node.outputs, each_var_name)) + self.graph.resolve_hazard(var_nodes) + def has_circle(self): """ Check if the graph has a circle. -- GitLab