提交 2d20f937 编写于 作者: M Megvii Engine Team

fix(mge/traced_module): fix merge GetAttr failure when qualname is incorrect

GitOrigin-RevId: 46241e5b4f02ee09dd361c1fe4ea1d85acf49ef9
上级 e6c271ae
...@@ -74,7 +74,7 @@ class Node: ...@@ -74,7 +74,7 @@ class Node:
r"""Set a new name to this Node.""" r"""Set a new name to this Node."""
graph = self.top_graph graph = self.top_graph
assert graph is not None, "The parent graph of this Node cannot be None." assert graph is not None, "The parent graph of this Node cannot be None."
assert new_name not in graph._namespace.used_names, ( assert graph._namespace.used_names.get(new_name, None) is None, (
"The name(%s) is already in use. Please try a different one again." "The name(%s) is already in use. Please try a different one again."
% (new_name) % (new_name)
) )
......
...@@ -544,11 +544,11 @@ class InternalGraph: ...@@ -544,11 +544,11 @@ class InternalGraph:
graph = self.top_graph graph = self.top_graph
assert graph is not None or mod._is_top, "The parent graph cannot be None." assert graph is not None or mod._is_top, "The parent graph cannot be None."
if graph is not None: if graph is not None:
assert new_name not in self._namespace.used_names, ( assert graph._namespace.used_names.get(new_name, None) is None, (
"The name(%s) is already in use. Please try a different one again." "The name(%s) is already in use. Please try a different one again."
% (new_name) % (new_name)
) )
new_name = self._namespace.create_unique_name(new_name, self) new_name = graph._namespace.create_unique_name(new_name, self)
self._name = new_name self._name = new_name
@property @property
...@@ -1032,21 +1032,33 @@ class InternalGraph: ...@@ -1032,21 +1032,33 @@ class InternalGraph:
n.inputs[idx] = repl_node n.inputs[idx] = repl_node
def _merge_getattr_expr(self): def _merge_getattr_expr(self):
getattr_nodes_map = dict() getattr_nodes_map = dict() # Dcit[(Node, str), Node]
for expr in self._exprs: node_to_attrname = dict() # Dict[Node, (Node, Str)]
if not isinstance(expr, GetAttr): for expr in filter(lambda x: isinstance(x, GetAttr), self._exprs):
continue base_node, attr_name = expr.inputs[0], expr.name
attr_name = get_suffix_name(self.qualname, expr.outputs[0].qualname) if expr.inputs[0] in node_to_attrname:
assert attr_name, '"{}" is not a prefix of "{}"'.format( base_node, base_name = node_to_attrname[expr.inputs[0]]
self.qualname, expr.outputs[0].qualname attr_name = "{}.{}".format(base_name, expr.name)
)
if attr_name in getattr_nodes_map: if get_suffix_name(self.qualname, expr.outputs[0].qualname) != attr_name:
base_node = getattr_nodes_map[attr_name] expected_qualname = base_node.qualname + "." + attr_name
logger.warning(
"{}.qualname expects {}, got {} actually. You can re-trace this "
"TracedModel to make the name correct.".format(
expr.outputs[0], expected_qualname, expr.outputs[0].qualname
)
)
expr.outputs[0]._qualname = expected_qualname
key = (base_node, attr_name)
node_to_attrname[expr.outputs[0]] = key
if key in getattr_nodes_map:
existed_node = getattr_nodes_map[key]
repl_node = expr.outputs[0] repl_node = expr.outputs[0]
for expr in repl_node.users: for expr in repl_node.users:
base_node.users.append(expr) existed_node.users.append(expr)
idx = expr.inputs.index(repl_node) idx = expr.inputs.index(repl_node)
expr.inputs[idx] = base_node expr.inputs[idx] = existed_node
repl_node.users = [] repl_node.users = []
else: else:
if attr_name != expr.name: if attr_name != expr.name:
...@@ -1054,7 +1066,7 @@ class InternalGraph: ...@@ -1054,7 +1066,7 @@ class InternalGraph:
expr.inputs[0].users.remove(expr) expr.inputs[0].users.remove(expr)
self.inputs[0].users.append(expr) self.inputs[0].users.append(expr)
expr.inputs[0] = self.inputs[0] expr.inputs[0] = self.inputs[0]
getattr_nodes_map[attr_name] = expr.outputs[0] getattr_nodes_map[key] = expr.outputs[0]
def compile(self): def compile(self):
r"""Delete unused expr.""" r"""Delete unused expr."""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册