提交 feea43bc 编写于 作者: M Megvii Engine Team

fix(mge/traced_module): associate name with node

GitOrigin-RevId: 8d9a59bade03b62aa8d4821dfb74cc828bf7312c
上级 a6fe7f7f
......@@ -69,6 +69,10 @@ def is_apply_def(expr):
return isinstance(expr, Apply)
def is_input(expr):
return isinstance(expr, Input)
class Expr:
r"""``Expr`` represents the operations (i.e. ``CallMethod``, ``CallFunction``, ``Apply``,
``GetAttr``, ``Input``, ``Constant``) on ``Node``.
......@@ -215,9 +219,11 @@ class Input(Expr):
@classmethod
def make(cls, *args, **kwargs):
assert active_module_tracer() is not None
current_graph = active_module_tracer().current_scope()
expr = cls(*args, **kwargs)
out_node = expr.outputs[0]
active_module_tracer().current_scope()._add_input(out_node)
current_graph._namespace.auto_naming_for_outputs(expr)
current_graph._add_input(out_node)
return expr.outputs[0]
def __repr__(self):
......
......@@ -74,8 +74,7 @@ class Node:
"The name(%s) is already in use. Please try a different one again."
% (new_name)
)
new_name = graph._namespace.create_unique_name(new_name)
self._name = new_name
self._name = graph._namespace.create_unique_name(new_name, self)
@property
def qualname(self):
......
......@@ -68,6 +68,7 @@ from .expr import (
is_call_tensor_method,
is_constant,
is_getattr,
is_input,
)
from .fake_quant import FakeQuantize as TM_FakeQuant
from .module_tracer import (
......@@ -342,13 +343,19 @@ class NameSpace:
self.qualname = qualname
self._used_names = {}
def create_unique_name(self, name: str) -> str:
def create_unique_name(self, name: str, node: Any = None) -> str:
assert isinstance(name, str), "The name must be a string"
if name in self._used_names and self._used_names[name] is node:
return name
name = re.sub("[^0-9a-zA-Z_]+", "_", name)
if name[0].isdigit():
name = "_{}".format(name)
while name in self._used_names or _is_builtin_name(name):
while (
name in self._used_names and self._used_names[name] is not None
) or _is_builtin_name(name):
match = re.match(r"(.*)_(\d+)$", name)
if match is None:
name = name + "_1"
......@@ -357,6 +364,10 @@ class NameSpace:
name = "{}_{}".format(base, int(num) + 1)
self._used_names.setdefault(name)
if node is not None:
self.associate_name_with_obj(name, node)
return name
def auto_naming_for_outputs(self, expr: Expr):
......@@ -384,7 +395,7 @@ class NameSpace:
qualname = "{}.{}".format(expr.inputs[0].qualname, expr.name)
name = get_suffix_name(self.qualname, qualname)
_add_suffix = lambda x: x
elif is_constant(expr):
elif is_constant(expr) or is_input(expr):
name = (
expr.name if expr.name else "const_" + type(expr.value).__name__.lower()
)
......@@ -392,16 +403,25 @@ class NameSpace:
_add_suffix = lambda x: x
for node in expr.outputs:
if node._name == "" or node._name in self.used_names:
assert _add_suffix(name) == name or isinstance(node, TensorNode)
node._name = self.create_unique_name(_add_suffix(name))
cur_name = node._name if node._name else _add_suffix(name)
node._name = self.create_unique_name(cur_name, node)
if node._qualname == "":
node._qualname = qualname
assert get_suffix_name(self.qualname, qualname)
assert get_suffix_name(self.qualname, qualname) is not None
def merge(self, other: "NameSpace"):
self._used_names.update(other.used_names)
def associate_name_with_obj(self, name: str, node: Node):
assert name in self.used_names
assert self.used_names[name] is None, "The name(%s) is already in use" % (name)
self._used_names[name] = node
def unassociate_name_with_obj(self, node: Node):
assert node.name in self.used_names
assert self.used_names[node.name] is node
self._used_names[node.name] = None
@property
def used_names(self):
return self._used_names
......@@ -487,7 +507,7 @@ class InternalGraph:
"The name(%s) is already in use. Please try a different one again."
% (new_name)
)
new_name = self._namespace.create_unique_name(new_name)
new_name = self._namespace.create_unique_name(new_name, self)
self._name = new_name
@property
......@@ -726,6 +746,7 @@ class InternalGraph:
node = Input(
type=TensorNode, name=name, qualname="%s.[%s]" % (self._qualname, name)
).outputs[0]
self._namespace.associate_name_with_obj(node.name, node)
node.shape = val.shape
node.dtype = val.dtype
return node
......@@ -764,9 +785,11 @@ class InternalGraph:
assert moudle._is_top, "add_input_node only supports top graph"
def create_node(name=None):
name = self._namespace.create_unique_name(name)
node = Input(
type=TensorNode, name=name, qualname="%s.[%s]" % (self._qualname, name)
).outputs[0]
self._namespace.associate_name_with_obj(node.name, node)
node.shape = shape
node.dtype = dtype
return node
......@@ -774,7 +797,7 @@ class InternalGraph:
org_argdef = list(moudle.argdef_graph_map.keys())[0]
args, kwargs = org_argdef.unflatten(self._inputs)
formal_inp_node = create_node(self._namespace.create_unique_name(name))
formal_inp_node = create_node(name)
inputs, tree_def = tree_flatten(
((*args, formal_inp_node), kwargs),
is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)),
......@@ -1006,6 +1029,8 @@ class InternalGraph:
for n in expr.inputs:
n.users.remove(expr)
self._exprs.remove(expr)
for n in expr.outputs:
self._namespace.unassociate_name_with_obj(n)
def _reset_ids(self):
for total_expr_id, expr in enumerate(self.exprs()):
......@@ -1014,6 +1039,11 @@ class InternalGraph:
node._id = total_node_id
self._total_ids = (total_node_id + 1, total_expr_id + 1)
def _re_associate_name(self):
self._namespace.used_names.clear()
for node in self.nodes(False):
node._name = self._namespace.create_unique_name(node.name, node)
def interpret(self, *inputs):
node2value = {}
end_nodes_set = set(self._end_point)
......@@ -1108,6 +1138,8 @@ class InternalGraph:
if n._qualname:
qualname = "{}.{}".format(qualname, n._qualname)
n._qualname = qualname
self._namespace = NameSpace(self._name, self._qualname)
self._re_associate_name()
def _get_meth_name(obj, func):
......@@ -1372,6 +1404,7 @@ class TracedModuleBuilder(NodeMixin):
continue
for g in mod.argdef_graph_map.values():
replace_qualname(g)
g._namespace.qualname = g.qualname
for n in g.nodes(False):
replace_qualname(n)
else:
......@@ -1383,6 +1416,7 @@ class TracedModuleBuilder(NodeMixin):
name=parent_graph._namespace.create_unique_name(module_qualname),
qualname=module_qualname,
)
parent_graph._namespace.associate_name_with_obj(self._body.name, self._body)
active_module_tracer().push_scope(self._body)
# rebind self to new input node
......@@ -1552,6 +1586,7 @@ class _expr_iter:
def __init__(self, graph: InternalGraph, recursive: bool = True):
self.graph = graph
self.recursive = recursive
self._visited_graph = set()
def __iter__(self):
for inp_node in self.graph.inputs:
......@@ -1559,8 +1594,13 @@ class _expr_iter:
for expr in self.graph._exprs:
if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode):
yield expr
if self.recursive and expr.graph is not None:
if (
self.recursive
and expr.graph is not None
and id(expr.graph) not in self._visited_graph
):
yield from expr.graph.exprs(self.recursive)
self._visited_graph.add(id(expr.graph))
else:
yield expr
......@@ -1570,12 +1610,11 @@ class _node_iter:
nodes = []
node_ids = set()
for expr in graph.exprs(recursive):
for n in expr.inputs + expr.outputs:
if id(n) in node_ids:
continue
for n in expr.outputs:
assert id(n) not in node_ids
nodes.append(n)
node_ids.add(id(n))
self.nodes = list(sorted(nodes, key=lambda x: x._id))
self.nodes = nodes
def __iter__(self):
for node in self.nodes:
......@@ -2076,10 +2115,12 @@ class TracedModule(Module):
if parent_graph is not None:
for node in expr.outputs:
if node in rename_blacklist:
continue
name = "{}_{}".format(prefix_name, node._name)
node._name = parent_graph._namespace.create_unique_name(name)
name = node._name
if node not in rename_blacklist:
name = "{}_{}".format(prefix_name, name)
node._name = parent_graph._namespace.create_unique_name(
name, node
)
exprs.append(expr)
......@@ -2092,6 +2133,7 @@ class TracedModule(Module):
new_module.graph._exprs = _flatten_subgraph(
None, new_module.graph, None, new_module
)
new_module.graph._re_associate_name()
new_module.graph.compile()
new_module.graph._reset_ids()
return new_module
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册