提交 2d20f937 编写于 作者: M Megvii Engine Team

fix(mge/traced_module): fix merge GetAttr failure when qualname is incorrect

GitOrigin-RevId: 46241e5b4f02ee09dd361c1fe4ea1d85acf49ef9
上级 e6c271ae
......@@ -74,7 +74,7 @@ class Node:
r"""Set a new name to this Node."""
graph = self.top_graph
assert graph is not None, "The parent graph of this Node cannot be None."
assert new_name not in graph._namespace.used_names, (
assert graph._namespace.used_names.get(new_name, None) is None, (
"The name(%s) is already in use. Please try a different one again."
% (new_name)
)
......
......@@ -544,11 +544,11 @@ class InternalGraph:
graph = self.top_graph
assert graph is not None or mod._is_top, "The parent graph cannot be None."
if graph is not None:
assert new_name not in self._namespace.used_names, (
assert graph._namespace.used_names.get(new_name, None) is None, (
"The name(%s) is already in use. Please try a different one again."
% (new_name)
)
new_name = self._namespace.create_unique_name(new_name, self)
new_name = graph._namespace.create_unique_name(new_name, self)
self._name = new_name
@property
......@@ -1032,21 +1032,33 @@ class InternalGraph:
n.inputs[idx] = repl_node
def _merge_getattr_expr(self):
getattr_nodes_map = dict()
for expr in self._exprs:
if not isinstance(expr, GetAttr):
continue
attr_name = get_suffix_name(self.qualname, expr.outputs[0].qualname)
assert attr_name, '"{}" is not a prefix of "{}"'.format(
self.qualname, expr.outputs[0].qualname
)
if attr_name in getattr_nodes_map:
base_node = getattr_nodes_map[attr_name]
getattr_nodes_map = dict() # Dcit[(Node, str), Node]
node_to_attrname = dict() # Dict[Node, (Node, Str)]
for expr in filter(lambda x: isinstance(x, GetAttr), self._exprs):
base_node, attr_name = expr.inputs[0], expr.name
if expr.inputs[0] in node_to_attrname:
base_node, base_name = node_to_attrname[expr.inputs[0]]
attr_name = "{}.{}".format(base_name, expr.name)
if get_suffix_name(self.qualname, expr.outputs[0].qualname) != attr_name:
expected_qualname = base_node.qualname + "." + attr_name
logger.warning(
"{}.qualname expects {}, got {} actually. You can re-trace this "
"TracedModel to make the name correct.".format(
expr.outputs[0], expected_qualname, expr.outputs[0].qualname
)
)
expr.outputs[0]._qualname = expected_qualname
key = (base_node, attr_name)
node_to_attrname[expr.outputs[0]] = key
if key in getattr_nodes_map:
existed_node = getattr_nodes_map[key]
repl_node = expr.outputs[0]
for expr in repl_node.users:
base_node.users.append(expr)
existed_node.users.append(expr)
idx = expr.inputs.index(repl_node)
expr.inputs[idx] = base_node
expr.inputs[idx] = existed_node
repl_node.users = []
else:
if attr_name != expr.name:
......@@ -1054,7 +1066,7 @@ class InternalGraph:
expr.inputs[0].users.remove(expr)
self.inputs[0].users.append(expr)
expr.inputs[0] = self.inputs[0]
getattr_nodes_map[attr_name] = expr.outputs[0]
getattr_nodes_map[key] = expr.outputs[0]
def compile(self):
r"""Delete unused expr."""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册