提交 edfd38be 编写于 作者: M Megvii Engine Team

fix(mge/traced_module): fix node naming in the flattened graph

GitOrigin-RevId: aa7c516725f5c0f97dc176381d85e4b0f91981dc
上级 b8776574
...@@ -1122,16 +1122,25 @@ class InternalGraph: ...@@ -1122,16 +1122,25 @@ class InternalGraph:
self.__dict__.update(state) self.__dict__.update(state)
if old_version: if old_version:
self.inputs[0]._qualname = self._qualname
for e in self.exprs(False):
if isinstance(e, GetAttr):
e.outputs[0]._qualname = "{}.{}".format(
e.inputs[0]._qualname, e.name
)
for n in self.nodes(False): for n in self.nodes(False):
qualname = self._qualname
if isinstance(n.expr, CallMethod) and isinstance( if isinstance(n.expr, CallMethod) and isinstance(
n.expr.inputs[0], ModuleNode n.expr.inputs[0], ModuleNode
): ):
n._qualname = n.expr.inputs[0]._qualname + ".[out]" n._qualname = n.expr.inputs[0]._qualname + ".[out]"
continue continue
if n._qualname: if (
qualname = "{}.{}".format(qualname, n._qualname) not isinstance(n.expr, GetAttr)
n._qualname = qualname and isinstance(n, TensorNode)
and n._qualname
):
n._qualname = "{}.{}".format(self._qualname, n._qualname)
self._namespace = NameSpace(self._name, self._qualname) self._namespace = NameSpace(self._name, self._qualname)
self._re_associate_name() self._re_associate_name()
...@@ -2080,8 +2089,10 @@ class TracedModule(Module): ...@@ -2080,8 +2089,10 @@ class TracedModule(Module):
node2obj[graph._inputs[0]] = module node2obj[graph._inputs[0]] = module
prefix_name = call.inputs[0]._name if call else "" prefix_name = call.inputs[0]._name if call else ""
exprs = [] flattened_exprs = []
for expr in graph._exprs: for expr in graph._exprs:
exprs = [expr]
if call is not None: if call is not None:
_replace_inputs_and_outputs(expr, repl_dict) _replace_inputs_and_outputs(expr, repl_dict)
...@@ -2102,10 +2113,7 @@ class TracedModule(Module): ...@@ -2102,10 +2113,7 @@ class TracedModule(Module):
else None else None
) )
if expr_graph is not None: if expr_graph is not None:
exprs.extend( exprs = _flatten_subgraph(graph, expr_graph, expr, obj)
_flatten_subgraph(graph, expr_graph, expr, obj)
)
continue
if parent_graph is not None: if parent_graph is not None:
for node in expr.outputs: for node in expr.outputs:
...@@ -2116,13 +2124,13 @@ class TracedModule(Module): ...@@ -2116,13 +2124,13 @@ class TracedModule(Module):
name, node name, node
) )
exprs.append(expr) flattened_exprs.extend(exprs)
if call is not None: if call is not None:
for i in call.inputs: for i in call.inputs:
i.users.remove(call) i.users.remove(call)
return exprs return flattened_exprs
new_module.graph._exprs = _flatten_subgraph( new_module.graph._exprs = _flatten_subgraph(
None, new_module.graph, None, new_module None, new_module.graph, None, new_module
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册