From edfd38befdd15d2108209ef44b1f92c840fd0cee Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 4 Nov 2021 16:33:08 +0800 Subject: [PATCH] fix(mge/traced_module): fix node naming in the flattened graph GitOrigin-RevId: aa7c516725f5c0f97dc176381d85e4b0f91981dc --- .../megengine/traced_module/traced_module.py | 30 ++++++++++++------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/imperative/python/megengine/traced_module/traced_module.py b/imperative/python/megengine/traced_module/traced_module.py index 591b60d3f..8f5b3ce8d 100644 --- a/imperative/python/megengine/traced_module/traced_module.py +++ b/imperative/python/megengine/traced_module/traced_module.py @@ -1122,16 +1122,25 @@ class InternalGraph: self.__dict__.update(state) if old_version: + self.inputs[0]._qualname = self._qualname + for e in self.exprs(False): + if isinstance(e, GetAttr): + e.outputs[0]._qualname = "{}.{}".format( + e.inputs[0]._qualname, e.name + ) + for n in self.nodes(False): - qualname = self._qualname if isinstance(n.expr, CallMethod) and isinstance( n.expr.inputs[0], ModuleNode ): n._qualname = n.expr.inputs[0]._qualname + ".[out]" continue - if n._qualname: - qualname = "{}.{}".format(qualname, n._qualname) - n._qualname = qualname + if ( + not isinstance(n.expr, GetAttr) + and isinstance(n, TensorNode) + and n._qualname + ): + n._qualname = "{}.{}".format(self._qualname, n._qualname) self._namespace = NameSpace(self._name, self._qualname) self._re_associate_name() @@ -2080,8 +2089,10 @@ class TracedModule(Module): node2obj[graph._inputs[0]] = module prefix_name = call.inputs[0]._name if call else "" - exprs = [] + flattened_exprs = [] + for expr in graph._exprs: + exprs = [expr] if call is not None: _replace_inputs_and_outputs(expr, repl_dict) @@ -2102,10 +2113,7 @@ class TracedModule(Module): else None ) if expr_graph is not None: - exprs.extend( - _flatten_subgraph(graph, expr_graph, expr, obj) - ) - continue + exprs = _flatten_subgraph(graph, expr_graph, expr, obj) if parent_graph is not None: for node in expr.outputs: @@ -2116,13 +2124,13 @@ class TracedModule(Module): name, node ) - exprs.append(expr) + flattened_exprs.extend(exprs) if call is not None: for i in call.inputs: i.users.remove(call) - return exprs + return flattened_exprs new_module.graph._exprs = _flatten_subgraph( None, new_module.graph, None, new_module -- GitLab