提交 c7a8d945 编写于 作者: M Megvii Engine Team

fix(mge/traced_module): let graph record total_id

GitOrigin-RevId: f99178f3ac45b2fd12828fc65ef995ee369d308c
上级 8b40f577
......@@ -167,6 +167,15 @@ class Expr:
state.pop("_top_graph")
return state
@classmethod
def get_total_id(cls):
return cls.__total_id
@classmethod
def set_total_id(cls, id: int = 0):
assert isinstance(id, int)
cls.__total_id = id
# expr: None (i.e. fake expression which is used to mark input)
class Input(Expr):
......
......@@ -42,10 +42,6 @@ class Node:
self._orig_name = orig_name
self.actual_node = [] # type: List[Node]
def __setstate__(self, d):
self.__dict__ = d
Node.__total_id = max(Node.__total_id, self._id) + 1
def __repr__(self):
format_spec = Node._format_spec
return self.__format__(format_spec)
......@@ -89,6 +85,15 @@ class Node:
cls._format_spec = str
return old_format_spec
@classmethod
def get_total_id(cls):
return cls.__total_id
@classmethod
def set_total_id(cls, id: int = 0):
assert isinstance(id, int)
cls.__total_id = id
class ModuleNode(Node):
r"""``ModuleNode`` represents the Module objects."""
......
......@@ -247,6 +247,10 @@ def _init_id2name(mod: Module, prefix: str = ""):
class _InsertExprs:
def __init__(self, graph, expr: Optional[Expr] = None):
self.graph = graph
while graph.top_graph is not None:
graph = graph.top_graph
assert graph.inputs[0].owner._is_top
self.root_graph = graph
self.global_scope = InternalGraph(
graph._name, graph._prefix_name, graph._module_name
)
......@@ -256,6 +260,9 @@ class _InsertExprs:
def __enter__(self):
self.use_sym_shape = set_symbolic_shape(True)
node_id, expr_id = self.root_graph._total_ids
Node.set_total_id(node_id)
Expr.set_total_id(expr_id)
set_module_tracing()
_set_convert_node_flag(True)
assert active_module_tracer() is None
......@@ -334,10 +341,8 @@ class _InsertExprs:
insert_index += 1
self.graph._used_names.update(self.global_scope._used_names)
graph = self.graph
while graph.top_graph is not None:
graph = graph.top_graph
graph.inputs[0].owner._update_ref()
self.root_graph._total_ids = (Node.get_total_id(), Expr.get_total_id())
self.root_graph.inputs[0].owner._update_ref()
return True
......@@ -353,7 +358,8 @@ class InternalGraph:
_exprs = None # type: List[Expr]
_inputs = None # type: List[Node]
_outputs = None # type: List[Node]
_top_graph = None
_top_graph = None # type: InternalGraph
_total_ids = None # type: List[int]
def __init__(self, name: str = None, prefix_name: str = "", module_name: str = ""):
self._exprs = []
......@@ -704,8 +710,12 @@ class InternalGraph:
def replace_node(self, repl_dict: Dict[Node, Node]):
while repl_dict:
node, repl_node = repl_dict.popitem()
assert type(node) == type(
repl_node
), "The type of {}({}) and {}({}) are not the same".format(
node, type(node).__name__, repl_node, type(repl_node).__name__
)
# check graph inputs and outputs
# assert node not in self.inputs, "Cannot replace inputs"
for i, n in enumerate(self.outputs):
if n is node:
self.outputs[i] = repl_node
......@@ -713,7 +723,10 @@ class InternalGraph:
# update inputs of expr in node.users
graph = repl_node.top_graph
assert graph is not None
index = graph._exprs.index(repl_node.expr)
assert graph is self
index = -1
if not isinstance(repl_node.expr, Input):
index = graph._exprs.index(repl_node.expr)
dep_exprs = self.get_dep_exprs(repl_node)
i = 0
while i < len(node.users):
......@@ -745,6 +758,13 @@ class InternalGraph:
n.users.remove(expr)
self._exprs.remove(expr)
def _reset_ids(self):
for total_expr_id, expr in enumerate(self.exprs()):
expr._id = total_expr_id
for total_node_id, node in enumerate(self.nodes()):
node._id = total_node_id
self._total_ids = (total_node_id + 1, total_expr_id + 1)
def interpret(self, *inputs):
node2value = {}
end_nodes_set = set(self._end_point)
......@@ -989,6 +1009,8 @@ class TracedModuleBuilder(NodeMixin):
)
for _, g in self._argdef_graph_map.items():
g.compile()
if self._is_top:
g._total_ids = (Node.get_total_id(), Expr.get_total_id())
for k, v in self.__dict__.items():
if k not in TracedModuleBuilder.__builder_attributes__:
......@@ -1247,6 +1269,8 @@ class _expr_iter:
self.recursive = recursive
def __iter__(self):
for inp_node in self.graph.inputs:
yield inp_node.expr
for expr in self.graph._exprs:
if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode):
yield expr
......@@ -1262,10 +1286,10 @@ class _node_iter:
node_ids = set()
for expr in graph.exprs(recursive):
for n in expr.inputs + expr.outputs:
if n._id in node_ids:
if id(n) in node_ids:
continue
nodes.append(n)
node_ids.add(n._id)
node_ids.add(id(n))
self.nodes = list(sorted(nodes, key=lambda x: x._id))
def __iter__(self):
......@@ -1546,6 +1570,7 @@ class TracedModule(Module):
active_module_tracer().push_scope(new_module.graph)
def _flatten_subgraph(
parent_graph: InternalGraph,
graph: InternalGraph,
module: Module,
call=None,
......@@ -1590,7 +1615,10 @@ class TracedModule(Module):
if inp is call_out:
expr.inputs[index] = repl_dict[out]
repl_dict[out].users.append(expr)
if parent_graph is not None:
for index, parent_out in enumerate(parent_graph._outputs):
if parent_out is call_out:
parent_graph._outputs[index] = repl_dict[out]
continue
repl_dict[out] = call.outputs[ind]
......@@ -1622,6 +1650,7 @@ class TracedModule(Module):
)
exprs.extend(
_flatten_subgraph(
graph,
expr_graph,
obj,
expr,
......@@ -1643,19 +1672,10 @@ class TracedModule(Module):
i.users.remove(call)
return exprs
new_module.graph._exprs = _flatten_subgraph(new_module.graph, new_module)
new_module.graph._exprs = _flatten_subgraph(None, new_module.graph, new_module)
new_module.graph.compile()
set_active_module_tracer(None)
for _id, expr in enumerate(new_module.graph._exprs):
expr._id = _id
total_node_id = 0
for i in new_module.graph._inputs:
i._id = total_node_id
total_node_id += 1
for expr in new_module.graph._exprs:
for o in expr.outputs:
o._id = total_node_id
total_node_id += 1
new_module.graph._reset_ids()
return new_module
def __getstate__(self):
......@@ -1735,6 +1755,8 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
set_active_module_tracer(
module_tracer(_wrapped_function, _init_id2name(mod, "self"))
)
for cls in [Expr, Node]:
cls.set_total_id(0)
with active_module_tracer().patcher:
global_scope = InternalGraph(name="")
active_module_tracer().push_scope(global_scope)
......@@ -1750,7 +1772,9 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
)
builder(*args, **kwargs)
active_module_tracer().pop_scope()
return builder.build()
traced_mod = builder.build()
traced_mod.graph._reset_ids()
return traced_mod
finally:
set_symbolic_shape(use_sym_shape)
set_active_module_tracer(None)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册