提交 1c11f817 编写于 作者: Z Zhen Wang

Use the resolve hazard method.

上级 2ccbfd5e
......@@ -26,35 +26,6 @@ __all__ = [
]
def _resolve_hazard(graph):
def _to_node(nodes, node_name):
target_node = None
for n in nodes:
if n.name() == node_name:
target_node = n.node
assert target_node is not None, "Cannot find the target node in the giving set."
return target_node
ordered_nodes = graph.topology_sort()
var_nodes = dict()
for node in ordered_nodes:
if node.is_op() and node.op() is not None:
for each_var_name in node.op().input_arg_names():
if each_var_name not in var_nodes:
var_nodes[each_var_name] = [
_to_node(node.inputs, each_var_name)
]
for each_var_name in node.op().output_arg_names():
if each_var_name not in var_nodes:
var_nodes[each_var_name] = [
_to_node(node.outputs, each_var_name)
]
else:
var_nodes[each_var_name].append(
_to_node(node.outputs, each_var_name))
graph.graph.resolve_hazard(var_nodes)
class QuantizationTransformPass(object):
def __init__(self,
scope=None,
......@@ -150,8 +121,8 @@ class QuantizationTransformPass(object):
"""
assert isinstance(graph,
IrGraph), 'graph must be the instance of IrGraph.'
sequential_execution = core.get_pass('sequential_execution_pass')
sequential_execution.apply(graph.graph)
#sequential_execution = core.get_pass('sequential_execution_pass')
#sequential_execution.apply(graph.graph)
self._is_test = graph.is_test()
# marked the variable which has been dequantized.
dequantized_vars = collections.OrderedDict()
......@@ -216,7 +187,7 @@ class QuantizationTransformPass(object):
for op in ops:
if op.name() in self._quantizable_grad_ops:
_transform_backward(graph, op)
_resolve_hazard(graph)
graph.resolve_hazard()
return graph
def _create_global_step(self, graph):
......@@ -652,6 +623,7 @@ class QuantizationFreezePass(object):
# remove the unused var node in the graph
self._remove_unused_var_nodes(graph)
graph.resolve_hazard()
return graph
def _remove_fake_quant_and_dequant_op(self, graph, op_node):
......@@ -895,6 +867,7 @@ class ConvertToInt8Pass(object):
# remove the unused var node in the graph
self._remove_unused_var_nodes(graph)
graph.resolve_hazard()
return graph
def _convert_to_int8(self, graph, var_node):
......@@ -977,5 +950,5 @@ class TransformForMobilePass(object):
for output_node in op_node.outputs:
graph.link_to(dequant_node, output_node)
graph.safe_remove_nodes(op_node)
graph.resolve_hazard()
return graph
......@@ -2253,6 +2253,34 @@ class IrGraph(object):
original_nodes = {n.node for n in remove_nodes}
core.graph_safe_remove_nodes(self.graph, original_nodes)
def resolve_hazard(self):
def _to_node(nodes, node_name):
target_node = None
for n in nodes:
if n.name() == node_name:
target_node = n
assert target_node is not None, "Cannot find the target node in the giving set."
return target_node
ordered_nodes = core.topology_sort(self.graph)
var_nodes = dict()
for node in ordered_nodes:
if node.is_op() and node.op() is not None:
for each_var_name in node.op().input_arg_names():
if each_var_name not in var_nodes:
var_nodes[each_var_name] = [
_to_node(node.inputs, each_var_name)
]
for each_var_name in node.op().output_arg_names():
if each_var_name not in var_nodes:
var_nodes[each_var_name] = [
_to_node(node.outputs, each_var_name)
]
else:
var_nodes[each_var_name].append(
_to_node(node.outputs, each_var_name))
self.graph.resolve_hazard(var_nodes)
def has_circle(self):
"""
Check if the graph has a circle.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册