提交 feea43bc 编写于 作者: M Megvii Engine Team

fix(mge/traced_module): associate name with node

GitOrigin-RevId: 8d9a59bade03b62aa8d4821dfb74cc828bf7312c
上级 a6fe7f7f
...@@ -69,6 +69,10 @@ def is_apply_def(expr): ...@@ -69,6 +69,10 @@ def is_apply_def(expr):
return isinstance(expr, Apply) return isinstance(expr, Apply)
def is_input(expr):
return isinstance(expr, Input)
class Expr: class Expr:
r"""``Expr`` represents the operations (i.e. ``CallMethod``, ``CallFunction``, ``Apply``, r"""``Expr`` represents the operations (i.e. ``CallMethod``, ``CallFunction``, ``Apply``,
``GetAttr``, ``Input``, ``Constant``) on ``Node``. ``GetAttr``, ``Input``, ``Constant``) on ``Node``.
...@@ -215,9 +219,11 @@ class Input(Expr): ...@@ -215,9 +219,11 @@ class Input(Expr):
@classmethod @classmethod
def make(cls, *args, **kwargs): def make(cls, *args, **kwargs):
assert active_module_tracer() is not None assert active_module_tracer() is not None
current_graph = active_module_tracer().current_scope()
expr = cls(*args, **kwargs) expr = cls(*args, **kwargs)
out_node = expr.outputs[0] out_node = expr.outputs[0]
active_module_tracer().current_scope()._add_input(out_node) current_graph._namespace.auto_naming_for_outputs(expr)
current_graph._add_input(out_node)
return expr.outputs[0] return expr.outputs[0]
def __repr__(self): def __repr__(self):
......
...@@ -74,8 +74,7 @@ class Node: ...@@ -74,8 +74,7 @@ class Node:
"The name(%s) is already in use. Please try a different one again." "The name(%s) is already in use. Please try a different one again."
% (new_name) % (new_name)
) )
new_name = graph._namespace.create_unique_name(new_name) self._name = graph._namespace.create_unique_name(new_name, self)
self._name = new_name
@property @property
def qualname(self): def qualname(self):
......
...@@ -68,6 +68,7 @@ from .expr import ( ...@@ -68,6 +68,7 @@ from .expr import (
is_call_tensor_method, is_call_tensor_method,
is_constant, is_constant,
is_getattr, is_getattr,
is_input,
) )
from .fake_quant import FakeQuantize as TM_FakeQuant from .fake_quant import FakeQuantize as TM_FakeQuant
from .module_tracer import ( from .module_tracer import (
...@@ -342,13 +343,19 @@ class NameSpace: ...@@ -342,13 +343,19 @@ class NameSpace:
self.qualname = qualname self.qualname = qualname
self._used_names = {} self._used_names = {}
def create_unique_name(self, name: str) -> str: def create_unique_name(self, name: str, node: Any = None) -> str:
assert isinstance(name, str), "The name must be a string" assert isinstance(name, str), "The name must be a string"
if name in self._used_names and self._used_names[name] is node:
return name
name = re.sub("[^0-9a-zA-Z_]+", "_", name) name = re.sub("[^0-9a-zA-Z_]+", "_", name)
if name[0].isdigit(): if name[0].isdigit():
name = "_{}".format(name) name = "_{}".format(name)
while name in self._used_names or _is_builtin_name(name): while (
name in self._used_names and self._used_names[name] is not None
) or _is_builtin_name(name):
match = re.match(r"(.*)_(\d+)$", name) match = re.match(r"(.*)_(\d+)$", name)
if match is None: if match is None:
name = name + "_1" name = name + "_1"
...@@ -357,6 +364,10 @@ class NameSpace: ...@@ -357,6 +364,10 @@ class NameSpace:
name = "{}_{}".format(base, int(num) + 1) name = "{}_{}".format(base, int(num) + 1)
self._used_names.setdefault(name) self._used_names.setdefault(name)
if node is not None:
self.associate_name_with_obj(name, node)
return name return name
def auto_naming_for_outputs(self, expr: Expr): def auto_naming_for_outputs(self, expr: Expr):
...@@ -384,7 +395,7 @@ class NameSpace: ...@@ -384,7 +395,7 @@ class NameSpace:
qualname = "{}.{}".format(expr.inputs[0].qualname, expr.name) qualname = "{}.{}".format(expr.inputs[0].qualname, expr.name)
name = get_suffix_name(self.qualname, qualname) name = get_suffix_name(self.qualname, qualname)
_add_suffix = lambda x: x _add_suffix = lambda x: x
elif is_constant(expr): elif is_constant(expr) or is_input(expr):
name = ( name = (
expr.name if expr.name else "const_" + type(expr.value).__name__.lower() expr.name if expr.name else "const_" + type(expr.value).__name__.lower()
) )
...@@ -392,16 +403,25 @@ class NameSpace: ...@@ -392,16 +403,25 @@ class NameSpace:
_add_suffix = lambda x: x _add_suffix = lambda x: x
for node in expr.outputs: for node in expr.outputs:
if node._name == "" or node._name in self.used_names: cur_name = node._name if node._name else _add_suffix(name)
assert _add_suffix(name) == name or isinstance(node, TensorNode) node._name = self.create_unique_name(cur_name, node)
node._name = self.create_unique_name(_add_suffix(name))
if node._qualname == "": if node._qualname == "":
node._qualname = qualname node._qualname = qualname
assert get_suffix_name(self.qualname, qualname) assert get_suffix_name(self.qualname, qualname) is not None
def merge(self, other: "NameSpace"): def merge(self, other: "NameSpace"):
self._used_names.update(other.used_names) self._used_names.update(other.used_names)
def associate_name_with_obj(self, name: str, node: Node):
assert name in self.used_names
assert self.used_names[name] is None, "The name(%s) is already in use" % (name)
self._used_names[name] = node
def unassociate_name_with_obj(self, node: Node):
assert node.name in self.used_names
assert self.used_names[node.name] is node
self._used_names[node.name] = None
@property @property
def used_names(self): def used_names(self):
return self._used_names return self._used_names
...@@ -487,7 +507,7 @@ class InternalGraph: ...@@ -487,7 +507,7 @@ class InternalGraph:
"The name(%s) is already in use. Please try a different one again." "The name(%s) is already in use. Please try a different one again."
% (new_name) % (new_name)
) )
new_name = self._namespace.create_unique_name(new_name) new_name = self._namespace.create_unique_name(new_name, self)
self._name = new_name self._name = new_name
@property @property
...@@ -726,6 +746,7 @@ class InternalGraph: ...@@ -726,6 +746,7 @@ class InternalGraph:
node = Input( node = Input(
type=TensorNode, name=name, qualname="%s.[%s]" % (self._qualname, name) type=TensorNode, name=name, qualname="%s.[%s]" % (self._qualname, name)
).outputs[0] ).outputs[0]
self._namespace.associate_name_with_obj(node.name, node)
node.shape = val.shape node.shape = val.shape
node.dtype = val.dtype node.dtype = val.dtype
return node return node
...@@ -764,9 +785,11 @@ class InternalGraph: ...@@ -764,9 +785,11 @@ class InternalGraph:
assert moudle._is_top, "add_input_node only supports top graph" assert moudle._is_top, "add_input_node only supports top graph"
def create_node(name=None): def create_node(name=None):
name = self._namespace.create_unique_name(name)
node = Input( node = Input(
type=TensorNode, name=name, qualname="%s.[%s]" % (self._qualname, name) type=TensorNode, name=name, qualname="%s.[%s]" % (self._qualname, name)
).outputs[0] ).outputs[0]
self._namespace.associate_name_with_obj(node.name, node)
node.shape = shape node.shape = shape
node.dtype = dtype node.dtype = dtype
return node return node
...@@ -774,7 +797,7 @@ class InternalGraph: ...@@ -774,7 +797,7 @@ class InternalGraph:
org_argdef = list(moudle.argdef_graph_map.keys())[0] org_argdef = list(moudle.argdef_graph_map.keys())[0]
args, kwargs = org_argdef.unflatten(self._inputs) args, kwargs = org_argdef.unflatten(self._inputs)
formal_inp_node = create_node(self._namespace.create_unique_name(name)) formal_inp_node = create_node(name)
inputs, tree_def = tree_flatten( inputs, tree_def = tree_flatten(
((*args, formal_inp_node), kwargs), ((*args, formal_inp_node), kwargs),
is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)), is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)),
...@@ -1006,6 +1029,8 @@ class InternalGraph: ...@@ -1006,6 +1029,8 @@ class InternalGraph:
for n in expr.inputs: for n in expr.inputs:
n.users.remove(expr) n.users.remove(expr)
self._exprs.remove(expr) self._exprs.remove(expr)
for n in expr.outputs:
self._namespace.unassociate_name_with_obj(n)
def _reset_ids(self): def _reset_ids(self):
for total_expr_id, expr in enumerate(self.exprs()): for total_expr_id, expr in enumerate(self.exprs()):
...@@ -1014,6 +1039,11 @@ class InternalGraph: ...@@ -1014,6 +1039,11 @@ class InternalGraph:
node._id = total_node_id node._id = total_node_id
self._total_ids = (total_node_id + 1, total_expr_id + 1) self._total_ids = (total_node_id + 1, total_expr_id + 1)
def _re_associate_name(self):
self._namespace.used_names.clear()
for node in self.nodes(False):
node._name = self._namespace.create_unique_name(node.name, node)
def interpret(self, *inputs): def interpret(self, *inputs):
node2value = {} node2value = {}
end_nodes_set = set(self._end_point) end_nodes_set = set(self._end_point)
...@@ -1108,6 +1138,8 @@ class InternalGraph: ...@@ -1108,6 +1138,8 @@ class InternalGraph:
if n._qualname: if n._qualname:
qualname = "{}.{}".format(qualname, n._qualname) qualname = "{}.{}".format(qualname, n._qualname)
n._qualname = qualname n._qualname = qualname
self._namespace = NameSpace(self._name, self._qualname)
self._re_associate_name()
def _get_meth_name(obj, func): def _get_meth_name(obj, func):
...@@ -1372,6 +1404,7 @@ class TracedModuleBuilder(NodeMixin): ...@@ -1372,6 +1404,7 @@ class TracedModuleBuilder(NodeMixin):
continue continue
for g in mod.argdef_graph_map.values(): for g in mod.argdef_graph_map.values():
replace_qualname(g) replace_qualname(g)
g._namespace.qualname = g.qualname
for n in g.nodes(False): for n in g.nodes(False):
replace_qualname(n) replace_qualname(n)
else: else:
...@@ -1383,6 +1416,7 @@ class TracedModuleBuilder(NodeMixin): ...@@ -1383,6 +1416,7 @@ class TracedModuleBuilder(NodeMixin):
name=parent_graph._namespace.create_unique_name(module_qualname), name=parent_graph._namespace.create_unique_name(module_qualname),
qualname=module_qualname, qualname=module_qualname,
) )
parent_graph._namespace.associate_name_with_obj(self._body.name, self._body)
active_module_tracer().push_scope(self._body) active_module_tracer().push_scope(self._body)
# rebind self to new input node # rebind self to new input node
...@@ -1552,6 +1586,7 @@ class _expr_iter: ...@@ -1552,6 +1586,7 @@ class _expr_iter:
def __init__(self, graph: InternalGraph, recursive: bool = True): def __init__(self, graph: InternalGraph, recursive: bool = True):
self.graph = graph self.graph = graph
self.recursive = recursive self.recursive = recursive
self._visited_graph = set()
def __iter__(self): def __iter__(self):
for inp_node in self.graph.inputs: for inp_node in self.graph.inputs:
...@@ -1559,8 +1594,13 @@ class _expr_iter: ...@@ -1559,8 +1594,13 @@ class _expr_iter:
for expr in self.graph._exprs: for expr in self.graph._exprs:
if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode): if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode):
yield expr yield expr
if self.recursive and expr.graph is not None: if (
self.recursive
and expr.graph is not None
and id(expr.graph) not in self._visited_graph
):
yield from expr.graph.exprs(self.recursive) yield from expr.graph.exprs(self.recursive)
self._visited_graph.add(id(expr.graph))
else: else:
yield expr yield expr
...@@ -1570,12 +1610,11 @@ class _node_iter: ...@@ -1570,12 +1610,11 @@ class _node_iter:
nodes = [] nodes = []
node_ids = set() node_ids = set()
for expr in graph.exprs(recursive): for expr in graph.exprs(recursive):
for n in expr.inputs + expr.outputs: for n in expr.outputs:
if id(n) in node_ids: assert id(n) not in node_ids
continue
nodes.append(n) nodes.append(n)
node_ids.add(id(n)) node_ids.add(id(n))
self.nodes = list(sorted(nodes, key=lambda x: x._id)) self.nodes = nodes
def __iter__(self): def __iter__(self):
for node in self.nodes: for node in self.nodes:
...@@ -2076,10 +2115,12 @@ class TracedModule(Module): ...@@ -2076,10 +2115,12 @@ class TracedModule(Module):
if parent_graph is not None: if parent_graph is not None:
for node in expr.outputs: for node in expr.outputs:
if node in rename_blacklist: name = node._name
continue if node not in rename_blacklist:
name = "{}_{}".format(prefix_name, node._name) name = "{}_{}".format(prefix_name, name)
node._name = parent_graph._namespace.create_unique_name(name) node._name = parent_graph._namespace.create_unique_name(
name, node
)
exprs.append(expr) exprs.append(expr)
...@@ -2092,6 +2133,7 @@ class TracedModule(Module): ...@@ -2092,6 +2133,7 @@ class TracedModule(Module):
new_module.graph._exprs = _flatten_subgraph( new_module.graph._exprs = _flatten_subgraph(
None, new_module.graph, None, new_module None, new_module.graph, None, new_module
) )
new_module.graph._re_associate_name()
new_module.graph.compile() new_module.graph.compile()
new_module.graph._reset_ids() new_module.graph._reset_ids()
return new_module return new_module
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册