提交 edfd38be 编写于 作者: M Megvii Engine Team

fix(mge/traced_module): fix node naming in the flattened graph

GitOrigin-RevId: aa7c516725f5c0f97dc176381d85e4b0f91981dc
上级 b8776574
......@@ -1122,16 +1122,25 @@ class InternalGraph:
self.__dict__.update(state)
if old_version:
self.inputs[0]._qualname = self._qualname
for e in self.exprs(False):
if isinstance(e, GetAttr):
e.outputs[0]._qualname = "{}.{}".format(
e.inputs[0]._qualname, e.name
)
for n in self.nodes(False):
qualname = self._qualname
if isinstance(n.expr, CallMethod) and isinstance(
n.expr.inputs[0], ModuleNode
):
n._qualname = n.expr.inputs[0]._qualname + ".[out]"
continue
if n._qualname:
qualname = "{}.{}".format(qualname, n._qualname)
n._qualname = qualname
if (
not isinstance(n.expr, GetAttr)
and isinstance(n, TensorNode)
and n._qualname
):
n._qualname = "{}.{}".format(self._qualname, n._qualname)
self._namespace = NameSpace(self._name, self._qualname)
self._re_associate_name()
......@@ -2080,8 +2089,10 @@ class TracedModule(Module):
node2obj[graph._inputs[0]] = module
prefix_name = call.inputs[0]._name if call else ""
exprs = []
flattened_exprs = []
for expr in graph._exprs:
exprs = [expr]
if call is not None:
_replace_inputs_and_outputs(expr, repl_dict)
......@@ -2102,10 +2113,7 @@ class TracedModule(Module):
else None
)
if expr_graph is not None:
exprs.extend(
_flatten_subgraph(graph, expr_graph, expr, obj)
)
continue
exprs = _flatten_subgraph(graph, expr_graph, expr, obj)
if parent_graph is not None:
for node in expr.outputs:
......@@ -2116,13 +2124,13 @@ class TracedModule(Module):
name, node
)
exprs.append(expr)
flattened_exprs.extend(exprs)
if call is not None:
for i in call.inputs:
i.users.remove(call)
return exprs
return flattened_exprs
new_module.graph._exprs = _flatten_subgraph(
None, new_module.graph, None, new_module
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册