From feea43bc7457d762dada0a0b52030ada34d6fd7b Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 26 Oct 2021 21:24:02 +0800 Subject: [PATCH] fix(mge/traced_module): associate name with node GitOrigin-RevId: 8d9a59bade03b62aa8d4821dfb74cc828bf7312c --- .../python/megengine/traced_module/expr.py | 8 +- .../python/megengine/traced_module/node.py | 3 +- .../megengine/traced_module/traced_module.py | 78 ++++++++++++++----- 3 files changed, 68 insertions(+), 21 deletions(-) diff --git a/imperative/python/megengine/traced_module/expr.py b/imperative/python/megengine/traced_module/expr.py index af162ebf..23fe73d7 100644 --- a/imperative/python/megengine/traced_module/expr.py +++ b/imperative/python/megengine/traced_module/expr.py @@ -69,6 +69,10 @@ def is_apply_def(expr): return isinstance(expr, Apply) +def is_input(expr): + return isinstance(expr, Input) + + class Expr: r"""``Expr`` represents the operations (i.e. ``CallMethod``, ``CallFunction``, ``Apply``, ``GetAttr``, ``Input``, ``Constant``) on ``Node``. @@ -215,9 +219,11 @@ class Input(Expr): @classmethod def make(cls, *args, **kwargs): assert active_module_tracer() is not None + current_graph = active_module_tracer().current_scope() expr = cls(*args, **kwargs) out_node = expr.outputs[0] - active_module_tracer().current_scope()._add_input(out_node) + current_graph._namespace.auto_naming_for_outputs(expr) + current_graph._add_input(out_node) return expr.outputs[0] def __repr__(self): diff --git a/imperative/python/megengine/traced_module/node.py b/imperative/python/megengine/traced_module/node.py index e812ff51..5aa571ad 100644 --- a/imperative/python/megengine/traced_module/node.py +++ b/imperative/python/megengine/traced_module/node.py @@ -74,8 +74,7 @@ class Node: "The name(%s) is already in use. Please try a different one again." % (new_name) ) - new_name = graph._namespace.create_unique_name(new_name) - self._name = new_name + self._name = graph._namespace.create_unique_name(new_name, self) @property def qualname(self): diff --git a/imperative/python/megengine/traced_module/traced_module.py b/imperative/python/megengine/traced_module/traced_module.py index 4cfe079e..c3328e1b 100644 --- a/imperative/python/megengine/traced_module/traced_module.py +++ b/imperative/python/megengine/traced_module/traced_module.py @@ -68,6 +68,7 @@ from .expr import ( is_call_tensor_method, is_constant, is_getattr, + is_input, ) from .fake_quant import FakeQuantize as TM_FakeQuant from .module_tracer import ( @@ -342,13 +343,19 @@ class NameSpace: self.qualname = qualname self._used_names = {} - def create_unique_name(self, name: str) -> str: + def create_unique_name(self, name: str, node: Any = None) -> str: assert isinstance(name, str), "The name must be a string" + + if name in self._used_names and self._used_names[name] is node: + return name + name = re.sub("[^0-9a-zA-Z_]+", "_", name) if name[0].isdigit(): name = "_{}".format(name) - while name in self._used_names or _is_builtin_name(name): + while ( + name in self._used_names and self._used_names[name] is not None + ) or _is_builtin_name(name): match = re.match(r"(.*)_(\d+)$", name) if match is None: name = name + "_1" @@ -357,6 +364,10 @@ class NameSpace: name = "{}_{}".format(base, int(num) + 1) self._used_names.setdefault(name) + + if node is not None: + self.associate_name_with_obj(name, node) + return name def auto_naming_for_outputs(self, expr: Expr): @@ -384,7 +395,7 @@ class NameSpace: qualname = "{}.{}".format(expr.inputs[0].qualname, expr.name) name = get_suffix_name(self.qualname, qualname) _add_suffix = lambda x: x - elif is_constant(expr): + elif is_constant(expr) or is_input(expr): name = ( expr.name if expr.name else "const_" + type(expr.value).__name__.lower() ) @@ -392,16 +403,25 @@ class NameSpace: _add_suffix = lambda x: x for node in expr.outputs: - if node._name == "" or node._name in self.used_names: - assert _add_suffix(name) == name or isinstance(node, TensorNode) - node._name = self.create_unique_name(_add_suffix(name)) + cur_name = node._name if node._name else _add_suffix(name) + node._name = self.create_unique_name(cur_name, node) if node._qualname == "": node._qualname = qualname - assert get_suffix_name(self.qualname, qualname) + assert get_suffix_name(self.qualname, qualname) is not None def merge(self, other: "NameSpace"): self._used_names.update(other.used_names) + def associate_name_with_obj(self, name: str, node: Node): + assert name in self.used_names + assert self.used_names[name] is None, "The name(%s) is already in use" % (name) + self._used_names[name] = node + + def unassociate_name_with_obj(self, node: Node): + assert node.name in self.used_names + assert self.used_names[node.name] is node + self._used_names[node.name] = None + @property def used_names(self): return self._used_names @@ -487,7 +507,7 @@ class InternalGraph: "The name(%s) is already in use. Please try a different one again." % (new_name) ) - new_name = self._namespace.create_unique_name(new_name) + new_name = self._namespace.create_unique_name(new_name, self) self._name = new_name @property @@ -726,6 +746,7 @@ class InternalGraph: node = Input( type=TensorNode, name=name, qualname="%s.[%s]" % (self._qualname, name) ).outputs[0] + self._namespace.associate_name_with_obj(node.name, node) node.shape = val.shape node.dtype = val.dtype return node @@ -764,9 +785,11 @@ class InternalGraph: assert moudle._is_top, "add_input_node only supports top graph" def create_node(name=None): + name = self._namespace.create_unique_name(name) node = Input( type=TensorNode, name=name, qualname="%s.[%s]" % (self._qualname, name) ).outputs[0] + self._namespace.associate_name_with_obj(node.name, node) node.shape = shape node.dtype = dtype return node @@ -774,7 +797,7 @@ class InternalGraph: org_argdef = list(moudle.argdef_graph_map.keys())[0] args, kwargs = org_argdef.unflatten(self._inputs) - formal_inp_node = create_node(self._namespace.create_unique_name(name)) + formal_inp_node = create_node(name) inputs, tree_def = tree_flatten( ((*args, formal_inp_node), kwargs), is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)), @@ -1006,6 +1029,8 @@ class InternalGraph: for n in expr.inputs: n.users.remove(expr) self._exprs.remove(expr) + for n in expr.outputs: + self._namespace.unassociate_name_with_obj(n) def _reset_ids(self): for total_expr_id, expr in enumerate(self.exprs()): @@ -1014,6 +1039,11 @@ class InternalGraph: node._id = total_node_id self._total_ids = (total_node_id + 1, total_expr_id + 1) + def _re_associate_name(self): + self._namespace.used_names.clear() + for node in self.nodes(False): + node._name = self._namespace.create_unique_name(node.name, node) + def interpret(self, *inputs): node2value = {} end_nodes_set = set(self._end_point) @@ -1108,6 +1138,8 @@ class InternalGraph: if n._qualname: qualname = "{}.{}".format(qualname, n._qualname) n._qualname = qualname + self._namespace = NameSpace(self._name, self._qualname) + self._re_associate_name() def _get_meth_name(obj, func): @@ -1372,6 +1404,7 @@ class TracedModuleBuilder(NodeMixin): continue for g in mod.argdef_graph_map.values(): replace_qualname(g) + g._namespace.qualname = g.qualname for n in g.nodes(False): replace_qualname(n) else: @@ -1383,6 +1416,7 @@ class TracedModuleBuilder(NodeMixin): name=parent_graph._namespace.create_unique_name(module_qualname), qualname=module_qualname, ) + parent_graph._namespace.associate_name_with_obj(self._body.name, self._body) active_module_tracer().push_scope(self._body) # rebind self to new input node @@ -1552,6 +1586,7 @@ class _expr_iter: def __init__(self, graph: InternalGraph, recursive: bool = True): self.graph = graph self.recursive = recursive + self._visited_graph = set() def __iter__(self): for inp_node in self.graph.inputs: @@ -1559,8 +1594,13 @@ class _expr_iter: for expr in self.graph._exprs: if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode): yield expr - if self.recursive and expr.graph is not None: + if ( + self.recursive + and expr.graph is not None + and id(expr.graph) not in self._visited_graph + ): yield from expr.graph.exprs(self.recursive) + self._visited_graph.add(id(expr.graph)) else: yield expr @@ -1570,12 +1610,11 @@ class _node_iter: nodes = [] node_ids = set() for expr in graph.exprs(recursive): - for n in expr.inputs + expr.outputs: - if id(n) in node_ids: - continue + for n in expr.outputs: + assert id(n) not in node_ids nodes.append(n) node_ids.add(id(n)) - self.nodes = list(sorted(nodes, key=lambda x: x._id)) + self.nodes = nodes def __iter__(self): for node in self.nodes: @@ -2076,10 +2115,12 @@ class TracedModule(Module): if parent_graph is not None: for node in expr.outputs: - if node in rename_blacklist: - continue - name = "{}_{}".format(prefix_name, node._name) - node._name = parent_graph._namespace.create_unique_name(name) + name = node._name + if node not in rename_blacklist: + name = "{}_{}".format(prefix_name, name) + node._name = parent_graph._namespace.create_unique_name( + name, node + ) exprs.append(expr) @@ -2092,6 +2133,7 @@ class TracedModule(Module): new_module.graph._exprs = _flatten_subgraph( None, new_module.graph, None, new_module ) + new_module.graph._re_associate_name() new_module.graph.compile() new_module.graph._reset_ids() return new_module -- GitLab