提交 1c11f817 编写于 作者: Z Zhen Wang

Use the resolve hazard method.

上级 2ccbfd5e
...@@ -26,35 +26,6 @@ __all__ = [ ...@@ -26,35 +26,6 @@ __all__ = [
] ]
def _resolve_hazard(graph):
def _to_node(nodes, node_name):
target_node = None
for n in nodes:
if n.name() == node_name:
target_node = n.node
assert target_node is not None, "Cannot find the target node in the giving set."
return target_node
ordered_nodes = graph.topology_sort()
var_nodes = dict()
for node in ordered_nodes:
if node.is_op() and node.op() is not None:
for each_var_name in node.op().input_arg_names():
if each_var_name not in var_nodes:
var_nodes[each_var_name] = [
_to_node(node.inputs, each_var_name)
]
for each_var_name in node.op().output_arg_names():
if each_var_name not in var_nodes:
var_nodes[each_var_name] = [
_to_node(node.outputs, each_var_name)
]
else:
var_nodes[each_var_name].append(
_to_node(node.outputs, each_var_name))
graph.graph.resolve_hazard(var_nodes)
class QuantizationTransformPass(object): class QuantizationTransformPass(object):
def __init__(self, def __init__(self,
scope=None, scope=None,
...@@ -150,8 +121,8 @@ class QuantizationTransformPass(object): ...@@ -150,8 +121,8 @@ class QuantizationTransformPass(object):
""" """
assert isinstance(graph, assert isinstance(graph,
IrGraph), 'graph must be the instance of IrGraph.' IrGraph), 'graph must be the instance of IrGraph.'
sequential_execution = core.get_pass('sequential_execution_pass') #sequential_execution = core.get_pass('sequential_execution_pass')
sequential_execution.apply(graph.graph) #sequential_execution.apply(graph.graph)
self._is_test = graph.is_test() self._is_test = graph.is_test()
# marked the variable which has been dequantized. # marked the variable which has been dequantized.
dequantized_vars = collections.OrderedDict() dequantized_vars = collections.OrderedDict()
...@@ -216,7 +187,7 @@ class QuantizationTransformPass(object): ...@@ -216,7 +187,7 @@ class QuantizationTransformPass(object):
for op in ops: for op in ops:
if op.name() in self._quantizable_grad_ops: if op.name() in self._quantizable_grad_ops:
_transform_backward(graph, op) _transform_backward(graph, op)
_resolve_hazard(graph) graph.resolve_hazard()
return graph return graph
def _create_global_step(self, graph): def _create_global_step(self, graph):
...@@ -652,6 +623,7 @@ class QuantizationFreezePass(object): ...@@ -652,6 +623,7 @@ class QuantizationFreezePass(object):
# remove the unused var node in the graph # remove the unused var node in the graph
self._remove_unused_var_nodes(graph) self._remove_unused_var_nodes(graph)
graph.resolve_hazard()
return graph return graph
def _remove_fake_quant_and_dequant_op(self, graph, op_node): def _remove_fake_quant_and_dequant_op(self, graph, op_node):
...@@ -895,6 +867,7 @@ class ConvertToInt8Pass(object): ...@@ -895,6 +867,7 @@ class ConvertToInt8Pass(object):
# remove the unused var node in the graph # remove the unused var node in the graph
self._remove_unused_var_nodes(graph) self._remove_unused_var_nodes(graph)
graph.resolve_hazard()
return graph return graph
def _convert_to_int8(self, graph, var_node): def _convert_to_int8(self, graph, var_node):
...@@ -977,5 +950,5 @@ class TransformForMobilePass(object): ...@@ -977,5 +950,5 @@ class TransformForMobilePass(object):
for output_node in op_node.outputs: for output_node in op_node.outputs:
graph.link_to(dequant_node, output_node) graph.link_to(dequant_node, output_node)
graph.safe_remove_nodes(op_node) graph.safe_remove_nodes(op_node)
graph.resolve_hazard()
return graph return graph
...@@ -2253,6 +2253,34 @@ class IrGraph(object): ...@@ -2253,6 +2253,34 @@ class IrGraph(object):
original_nodes = {n.node for n in remove_nodes} original_nodes = {n.node for n in remove_nodes}
core.graph_safe_remove_nodes(self.graph, original_nodes) core.graph_safe_remove_nodes(self.graph, original_nodes)
def resolve_hazard(self):
def _to_node(nodes, node_name):
target_node = None
for n in nodes:
if n.name() == node_name:
target_node = n
assert target_node is not None, "Cannot find the target node in the giving set."
return target_node
ordered_nodes = core.topology_sort(self.graph)
var_nodes = dict()
for node in ordered_nodes:
if node.is_op() and node.op() is not None:
for each_var_name in node.op().input_arg_names():
if each_var_name not in var_nodes:
var_nodes[each_var_name] = [
_to_node(node.inputs, each_var_name)
]
for each_var_name in node.op().output_arg_names():
if each_var_name not in var_nodes:
var_nodes[each_var_name] = [
_to_node(node.outputs, each_var_name)
]
else:
var_nodes[each_var_name].append(
_to_node(node.outputs, each_var_name))
self.graph.resolve_hazard(var_nodes)
def has_circle(self): def has_circle(self):
""" """
Check if the graph has a circle. Check if the graph has a circle.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册