提交 c7a8d945 编写于 作者: M Megvii Engine Team

fix(mge/traced_module): let graph record total_id

GitOrigin-RevId: f99178f3ac45b2fd12828fc65ef995ee369d308c
上级 8b40f577
...@@ -167,6 +167,15 @@ class Expr: ...@@ -167,6 +167,15 @@ class Expr:
state.pop("_top_graph") state.pop("_top_graph")
return state return state
@classmethod
def get_total_id(cls):
return cls.__total_id
@classmethod
def set_total_id(cls, id: int = 0):
assert isinstance(id, int)
cls.__total_id = id
# expr: None (i.e. fake expression which is used to mark input) # expr: None (i.e. fake expression which is used to mark input)
class Input(Expr): class Input(Expr):
......
...@@ -42,10 +42,6 @@ class Node: ...@@ -42,10 +42,6 @@ class Node:
self._orig_name = orig_name self._orig_name = orig_name
self.actual_node = [] # type: List[Node] self.actual_node = [] # type: List[Node]
def __setstate__(self, d):
self.__dict__ = d
Node.__total_id = max(Node.__total_id, self._id) + 1
def __repr__(self): def __repr__(self):
format_spec = Node._format_spec format_spec = Node._format_spec
return self.__format__(format_spec) return self.__format__(format_spec)
...@@ -89,6 +85,15 @@ class Node: ...@@ -89,6 +85,15 @@ class Node:
cls._format_spec = str cls._format_spec = str
return old_format_spec return old_format_spec
@classmethod
def get_total_id(cls):
return cls.__total_id
@classmethod
def set_total_id(cls, id: int = 0):
assert isinstance(id, int)
cls.__total_id = id
class ModuleNode(Node): class ModuleNode(Node):
r"""``ModuleNode`` represents the Module objects.""" r"""``ModuleNode`` represents the Module objects."""
......
...@@ -247,6 +247,10 @@ def _init_id2name(mod: Module, prefix: str = ""): ...@@ -247,6 +247,10 @@ def _init_id2name(mod: Module, prefix: str = ""):
class _InsertExprs: class _InsertExprs:
def __init__(self, graph, expr: Optional[Expr] = None): def __init__(self, graph, expr: Optional[Expr] = None):
self.graph = graph self.graph = graph
while graph.top_graph is not None:
graph = graph.top_graph
assert graph.inputs[0].owner._is_top
self.root_graph = graph
self.global_scope = InternalGraph( self.global_scope = InternalGraph(
graph._name, graph._prefix_name, graph._module_name graph._name, graph._prefix_name, graph._module_name
) )
...@@ -256,6 +260,9 @@ class _InsertExprs: ...@@ -256,6 +260,9 @@ class _InsertExprs:
def __enter__(self): def __enter__(self):
self.use_sym_shape = set_symbolic_shape(True) self.use_sym_shape = set_symbolic_shape(True)
node_id, expr_id = self.root_graph._total_ids
Node.set_total_id(node_id)
Expr.set_total_id(expr_id)
set_module_tracing() set_module_tracing()
_set_convert_node_flag(True) _set_convert_node_flag(True)
assert active_module_tracer() is None assert active_module_tracer() is None
...@@ -334,10 +341,8 @@ class _InsertExprs: ...@@ -334,10 +341,8 @@ class _InsertExprs:
insert_index += 1 insert_index += 1
self.graph._used_names.update(self.global_scope._used_names) self.graph._used_names.update(self.global_scope._used_names)
graph = self.graph self.root_graph._total_ids = (Node.get_total_id(), Expr.get_total_id())
while graph.top_graph is not None: self.root_graph.inputs[0].owner._update_ref()
graph = graph.top_graph
graph.inputs[0].owner._update_ref()
return True return True
...@@ -353,7 +358,8 @@ class InternalGraph: ...@@ -353,7 +358,8 @@ class InternalGraph:
_exprs = None # type: List[Expr] _exprs = None # type: List[Expr]
_inputs = None # type: List[Node] _inputs = None # type: List[Node]
_outputs = None # type: List[Node] _outputs = None # type: List[Node]
_top_graph = None _top_graph = None # type: InternalGraph
_total_ids = None # type: List[int]
def __init__(self, name: str = None, prefix_name: str = "", module_name: str = ""): def __init__(self, name: str = None, prefix_name: str = "", module_name: str = ""):
self._exprs = [] self._exprs = []
...@@ -704,8 +710,12 @@ class InternalGraph: ...@@ -704,8 +710,12 @@ class InternalGraph:
def replace_node(self, repl_dict: Dict[Node, Node]): def replace_node(self, repl_dict: Dict[Node, Node]):
while repl_dict: while repl_dict:
node, repl_node = repl_dict.popitem() node, repl_node = repl_dict.popitem()
assert type(node) == type(
repl_node
), "The type of {}({}) and {}({}) are not the same".format(
node, type(node).__name__, repl_node, type(repl_node).__name__
)
# check graph inputs and outputs # check graph inputs and outputs
# assert node not in self.inputs, "Cannot replace inputs"
for i, n in enumerate(self.outputs): for i, n in enumerate(self.outputs):
if n is node: if n is node:
self.outputs[i] = repl_node self.outputs[i] = repl_node
...@@ -713,7 +723,10 @@ class InternalGraph: ...@@ -713,7 +723,10 @@ class InternalGraph:
# update inputs of expr in node.users # update inputs of expr in node.users
graph = repl_node.top_graph graph = repl_node.top_graph
assert graph is not None assert graph is not None
index = graph._exprs.index(repl_node.expr) assert graph is self
index = -1
if not isinstance(repl_node.expr, Input):
index = graph._exprs.index(repl_node.expr)
dep_exprs = self.get_dep_exprs(repl_node) dep_exprs = self.get_dep_exprs(repl_node)
i = 0 i = 0
while i < len(node.users): while i < len(node.users):
...@@ -745,6 +758,13 @@ class InternalGraph: ...@@ -745,6 +758,13 @@ class InternalGraph:
n.users.remove(expr) n.users.remove(expr)
self._exprs.remove(expr) self._exprs.remove(expr)
def _reset_ids(self):
for total_expr_id, expr in enumerate(self.exprs()):
expr._id = total_expr_id
for total_node_id, node in enumerate(self.nodes()):
node._id = total_node_id
self._total_ids = (total_node_id + 1, total_expr_id + 1)
def interpret(self, *inputs): def interpret(self, *inputs):
node2value = {} node2value = {}
end_nodes_set = set(self._end_point) end_nodes_set = set(self._end_point)
...@@ -989,6 +1009,8 @@ class TracedModuleBuilder(NodeMixin): ...@@ -989,6 +1009,8 @@ class TracedModuleBuilder(NodeMixin):
) )
for _, g in self._argdef_graph_map.items(): for _, g in self._argdef_graph_map.items():
g.compile() g.compile()
if self._is_top:
g._total_ids = (Node.get_total_id(), Expr.get_total_id())
for k, v in self.__dict__.items(): for k, v in self.__dict__.items():
if k not in TracedModuleBuilder.__builder_attributes__: if k not in TracedModuleBuilder.__builder_attributes__:
...@@ -1247,6 +1269,8 @@ class _expr_iter: ...@@ -1247,6 +1269,8 @@ class _expr_iter:
self.recursive = recursive self.recursive = recursive
def __iter__(self): def __iter__(self):
for inp_node in self.graph.inputs:
yield inp_node.expr
for expr in self.graph._exprs: for expr in self.graph._exprs:
if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode): if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode):
yield expr yield expr
...@@ -1262,10 +1286,10 @@ class _node_iter: ...@@ -1262,10 +1286,10 @@ class _node_iter:
node_ids = set() node_ids = set()
for expr in graph.exprs(recursive): for expr in graph.exprs(recursive):
for n in expr.inputs + expr.outputs: for n in expr.inputs + expr.outputs:
if n._id in node_ids: if id(n) in node_ids:
continue continue
nodes.append(n) nodes.append(n)
node_ids.add(n._id) node_ids.add(id(n))
self.nodes = list(sorted(nodes, key=lambda x: x._id)) self.nodes = list(sorted(nodes, key=lambda x: x._id))
def __iter__(self): def __iter__(self):
...@@ -1546,6 +1570,7 @@ class TracedModule(Module): ...@@ -1546,6 +1570,7 @@ class TracedModule(Module):
active_module_tracer().push_scope(new_module.graph) active_module_tracer().push_scope(new_module.graph)
def _flatten_subgraph( def _flatten_subgraph(
parent_graph: InternalGraph,
graph: InternalGraph, graph: InternalGraph,
module: Module, module: Module,
call=None, call=None,
...@@ -1590,7 +1615,10 @@ class TracedModule(Module): ...@@ -1590,7 +1615,10 @@ class TracedModule(Module):
if inp is call_out: if inp is call_out:
expr.inputs[index] = repl_dict[out] expr.inputs[index] = repl_dict[out]
repl_dict[out].users.append(expr) repl_dict[out].users.append(expr)
if parent_graph is not None:
for index, parent_out in enumerate(parent_graph._outputs):
if parent_out is call_out:
parent_graph._outputs[index] = repl_dict[out]
continue continue
repl_dict[out] = call.outputs[ind] repl_dict[out] = call.outputs[ind]
...@@ -1622,6 +1650,7 @@ class TracedModule(Module): ...@@ -1622,6 +1650,7 @@ class TracedModule(Module):
) )
exprs.extend( exprs.extend(
_flatten_subgraph( _flatten_subgraph(
graph,
expr_graph, expr_graph,
obj, obj,
expr, expr,
...@@ -1643,19 +1672,10 @@ class TracedModule(Module): ...@@ -1643,19 +1672,10 @@ class TracedModule(Module):
i.users.remove(call) i.users.remove(call)
return exprs return exprs
new_module.graph._exprs = _flatten_subgraph(new_module.graph, new_module) new_module.graph._exprs = _flatten_subgraph(None, new_module.graph, new_module)
new_module.graph.compile() new_module.graph.compile()
set_active_module_tracer(None) set_active_module_tracer(None)
for _id, expr in enumerate(new_module.graph._exprs): new_module.graph._reset_ids()
expr._id = _id
total_node_id = 0
for i in new_module.graph._inputs:
i._id = total_node_id
total_node_id += 1
for expr in new_module.graph._exprs:
for o in expr.outputs:
o._id = total_node_id
total_node_id += 1
return new_module return new_module
def __getstate__(self): def __getstate__(self):
...@@ -1735,6 +1755,8 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule: ...@@ -1735,6 +1755,8 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
set_active_module_tracer( set_active_module_tracer(
module_tracer(_wrapped_function, _init_id2name(mod, "self")) module_tracer(_wrapped_function, _init_id2name(mod, "self"))
) )
for cls in [Expr, Node]:
cls.set_total_id(0)
with active_module_tracer().patcher: with active_module_tracer().patcher:
global_scope = InternalGraph(name="") global_scope = InternalGraph(name="")
active_module_tracer().push_scope(global_scope) active_module_tracer().push_scope(global_scope)
...@@ -1750,7 +1772,9 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule: ...@@ -1750,7 +1772,9 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
) )
builder(*args, **kwargs) builder(*args, **kwargs)
active_module_tracer().pop_scope() active_module_tracer().pop_scope()
return builder.build() traced_mod = builder.build()
traced_mod.graph._reset_ids()
return traced_mod
finally: finally:
set_symbolic_shape(use_sym_shape) set_symbolic_shape(use_sym_shape)
set_active_module_tracer(None) set_active_module_tracer(None)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册