diff --git a/imperative/python/megengine/traced_module/expr.py b/imperative/python/megengine/traced_module/expr.py index 6ff9da2397693bd4cb9fa72c12cabb98b7f77f88..af162ebf590141f4d1a0fda9f6f815ccc5746d60 100644 --- a/imperative/python/megengine/traced_module/expr.py +++ b/imperative/python/megengine/traced_module/expr.py @@ -11,7 +11,7 @@ import collections import copy import inspect import re -from typing import Callable, Dict, List +from typing import Callable, Dict, List, Optional, Union from ..core._imperative_rt import OpDef from ..core._imperative_rt.core2 import Tensor as RawTensor @@ -32,6 +32,43 @@ def rstrip(s: str, __chars: str): return s +def get_suffix_name(prefix: str, name: str): + if prefix == name: + return "" + matchd = re.compile("^%s\.(.*)" % prefix).match(name) + if matchd is None: + return None + return matchd.group(1) + + +def is_call_module(expr): + return ( + isinstance(expr, CallMethod) + and isinstance(expr.inputs[0], ModuleNode) + and expr.method == "__call__" + ) + + +def is_call_tensor_method(expr): + return isinstance(expr, CallMethod) and not is_call_module(expr) + + +def is_call_function(expr): + return isinstance(expr, CallFunction) + + +def is_constant(expr): + return isinstance(expr, Constant) + + +def is_getattr(expr): + return isinstance(expr, GetAttr) + + +def is_apply_def(expr): + return isinstance(expr, Apply) + + class Expr: r"""``Expr`` represents the operations (i.e. ``CallMethod``, ``CallFunction``, ``Apply``, ``GetAttr``, ``Input``, ``Constant``) on ``Node``. @@ -76,50 +113,19 @@ class Expr: self.const_val.append((idx, val)) def add_outputs(self, outputs): + assert active_module_tracer() is not None self.outputs = [] - if outputs is not None: - if not isinstance(outputs, collections.Sequence): - outputs = (outputs,) - - name = None - orig_name = None - if isinstance(self, CallMethod): - name = self.inputs[0]._name - orig_name = self.inputs[0]._orig_name - assert isinstance(name, str), "The name of ({}) must be a str".format( - self.inputs[0] - ) - assert isinstance( - orig_name, str - ), "The orig_name of ({}) must be a str".format(self.inputs[0]) - name = rstrip(name, "_out") - if self.method == "__call__": - name += "_out" - orig_name += "_out" - else: - strip_method = self.method.strip("_") - name = "%s_out" % strip_method - orig_name = name - elif isinstance(self, CallFunction): - name = self.func.__name__ + "_out" - elif isinstance(self, Apply): - name = str(self.opdef).lower() + "_out" - - for i in outputs: - assert isinstance(i, RawTensor), "The output must be a Tensor" - o_name = ( - active_module_tracer().current_scope()._create_unique_name(name) - ) - self.outputs.append( - NodeMixin.get_wrapped_type(i)( - expr=self, - name=o_name, - orig_name=orig_name if orig_name else o_name, - ) - ) - - for i, node in zip(outputs, self.outputs,): - NodeMixin.wrap_safe(i, node) + if outputs is None: + return + current_graph = active_module_tracer().current_scope() + if not isinstance(outputs, collections.Sequence): + outputs = (outputs,) + for i in outputs: + assert isinstance(i, RawTensor), "The output must be a Tensor" + node = NodeMixin.get_wrapped_type(i)(expr=self, name="", qualname="",) + NodeMixin.wrap_safe(i, node) + self.outputs.append(node) + current_graph._namespace.auto_naming_for_outputs(self) def unflatten_args(self, inputs): if self.arg_def is not None: @@ -152,9 +158,7 @@ class Expr: ), "({}) must be generated before ({})".format(repl_node, self) idx = self.inputs.index(node) self.inputs[idx] = repl_node - user_idx = 
node.users.index(self) - assert user_idx >= 0 - node.users.pop(user_idx) + node.users.remove(self) repl_node.users.append(self) @property @@ -197,26 +201,23 @@ class Input(Expr): r"""A fake Expr which is used to mark the input of graph.""" name = None - def __init__(self, name=None, type=None, orig_name=None): + def __init__(self, type: List[Node], name: str = "args", qualname: str = ""): super().__init__() + assert type in [ModuleNode, TensorNode] + assert name and qualname self.inputs = [] node_cls = type if type else Node - if orig_name is None: - orig_name = name self.outputs = [ - node_cls(self, name=name, orig_name=orig_name), + node_cls(self, name=name, qualname=qualname), ] self.name = name @classmethod def make(cls, *args, **kwargs): + assert active_module_tracer() is not None expr = cls(*args, **kwargs) - oup_node = expr.outputs[0] - name = ( - active_module_tracer().current_scope()._create_unique_name(oup_node._name) - ) - oup_node._name = name - active_module_tracer().current_scope()._add_input(oup_node) + out_node = expr.outputs[0] + active_module_tracer().current_scope()._add_input(out_node) return expr.outputs[0] def __repr__(self): @@ -230,34 +231,41 @@ class GetAttr(Expr): name = None r"""name: the qualified name of the attribute to be retrieved.""" - def __init__(self, module, name, type=None, orig_name=None): + def __init__( + self, module: ModuleNode, type: Union[Node], attr_name: str, name: str = "", + ): super().__init__() assert isinstance(module, ModuleNode) + assert type in [TensorNode, ModuleNode] self.inputs = [ module, ] module.users.append(self) - self.name = name - node_cls = type if type else Node + self.name = attr_name self.outputs = [ - node_cls(self, name=name, orig_name=orig_name), + type(self, name=name, qualname="{}.{}".format(module.qualname, attr_name)), ] @classmethod def make(cls, *args, **kwargs): + assert active_module_tracer() is not None + current_graph = active_module_tracer().current_scope() expr = cls(*args, **kwargs) - module = expr.inputs[0] - oup_name = expr.name - while module._name != "self": - oup_name = module._name + "_" + oup_name - module = module.expr.inputs[0] - oup_name = active_module_tracer().current_scope()._create_unique_name(oup_name) - expr.outputs[0]._name = oup_name - active_module_tracer().current_scope()._insert(expr) + current_graph._namespace.auto_naming_for_outputs(expr) + current_graph._insert(expr) return expr.outputs[0] def interpret(self, *inputs): - return (getattr(inputs[0], self.name),) + mod = inputs[0] + module_path, _, name = self.name.rpartition(".") + if module_path == "": + return (getattr(mod, name),) + module_names = module_path.split(".") + for item in module_names: + mod = getattr(mod, item) + if not isinstance(mod, Module): + raise AttributeError("`{}` is not an Module".format(item)) + return (getattr(mod, name),) def __repr__(self): out_type = "Tensor" @@ -297,6 +305,7 @@ class CallMethod(Expr): @classmethod def make(cls, *args, **kwargs): + assert active_module_tracer() is not None expr = cls(*args, **kwargs) active_module_tracer().current_scope()._insert(expr) return expr @@ -362,6 +371,7 @@ class Apply(Expr): @classmethod def make(cls, *args, **kwargs): + assert active_module_tracer() is not None expr = cls(*args, **kwargs) active_module_tracer().current_scope()._insert(expr) return expr @@ -435,6 +445,7 @@ class CallFunction(Expr): @classmethod def make(cls, *args, **kwargs): + assert active_module_tracer() is not None expr = cls(*args, **kwargs) 
active_module_tracer().current_scope()._insert(expr) return expr @@ -474,7 +485,7 @@ class Constant(Expr): # TODO: constant cache to reduce the size of dumped model _constant_cache = {} - def __init__(self, c, name=None): + def __init__(self, c, name: str = "", qualname: str = ""): super().__init__() assert isinstance(c, (RawTensor, Module)) if isinstance(c, Module): @@ -484,31 +495,16 @@ class Constant(Expr): self.inputs = [] node_cls = NodeMixin.get_wrapped_type(c) self.outputs = [ - node_cls(self, name=name, orig_name=name), + node_cls(self, name=name, qualname=qualname), ] - self.outputs[0]._name = name if name else "const_" + str(self._id) @classmethod def make(cls, *args, **kwargs): + assert active_module_tracer() is not None expr = cls(*args, **kwargs) - name = "const_module" if isinstance(expr.value, Module) else "const_tensor" - full_name = name - if ( - isinstance(expr.value, RawTensor) - and id(expr.value) in active_module_tracer().id2name - ): - full_name = active_module_tracer().id2name[id(expr.value)] - scope_name = active_module_tracer().current_scope()._module_name - if full_name and scope_name: - full_name = ("self." + full_name)[len(scope_name) + 1 :] - else: - full_name = name - else: - full_name = name - name = active_module_tracer().current_scope()._create_unique_name(full_name) - expr.outputs[0]._name = name - expr.outputs[0]._orig_name = full_name - active_module_tracer().current_scope()._insert(expr) + current_graph = active_module_tracer().current_scope() + current_graph._namespace.auto_naming_for_outputs(expr) + current_graph._insert(expr) return expr.outputs[0] def interpret(self, *inputs): diff --git a/imperative/python/megengine/traced_module/module_tracer.py b/imperative/python/megengine/traced_module/module_tracer.py index 96b62a6e1ceb70befd3ade85f3b92925192caa97..59eb5531719494a9052b17f9dd2cf856c2d800f4 100644 --- a/imperative/python/megengine/traced_module/module_tracer.py +++ b/imperative/python/megengine/traced_module/module_tracer.py @@ -128,10 +128,9 @@ class module_tracer: _active_scopes = None - def __init__(self, wrap_fn, id2name): + def __init__(self, wrap_fn): self._active_scopes = [] self.patcher = Patcher(wrap_fn) - self.id2name = id2name @classmethod def register_as_builtin(cls, mod): diff --git a/imperative/python/megengine/traced_module/node.py b/imperative/python/megengine/traced_module/node.py index 67a2954752bb234b86cc856748f4d384f01b8aa9..e812ff5146d232d679dc5c48626c0167bbace765 100644 --- a/imperative/python/megengine/traced_module/node.py +++ b/imperative/python/megengine/traced_module/node.py @@ -29,17 +29,15 @@ class Node: __total_id = 0 # type: int _id = None # type: int _top_graph = None # type: weakref.ReferenceType - _name = None # type: str - _orig_name = None # type: str _format_spec = "" # type: str - def __init__(self, expr, name: str, orig_name: str): + def __init__(self, expr, name: str, qualname: str): self.expr = expr self.users = [] # List[Expr] self._id = Node.__total_id Node.__total_id += 1 self._name = name - self._orig_name = orig_name + self._qualname = qualname self.actual_node = [] # type: List[Node] def __repr__(self): @@ -54,21 +52,10 @@ class Node: name = "" if format_spec in ["i", "p", "ip", "pi"]: if "p" in format_spec: - graph = self.top_graph - prefix_name = "" - if graph is not None: - prefix_name = graph._name - if graph._prefix_name: - prefix_name = "{}_{}".format( - graph._prefix_name, prefix_name.lstrip("_") - ) - if name: - name = "_" + name.lstrip("_") - name = "{}{}".format(prefix_name, name) + 
prefix_name = self.top_graph._name + name = "{}_{}".format(prefix_name, name) if "i" in format_spec: - if name: - name = "_" + name.lstrip("_") - name = "%{}{}".format(self._id, name) + name = "%{}_{}".format(self._id, name) return name else: return name if name else ("%d" % self._id) @@ -80,15 +67,62 @@ class Node: @name.setter def name(self, new_name: str): + r"""Set a new name to this Node.""" graph = self.top_graph assert graph is not None, "The parent graph of this Node cannot be None." - assert new_name not in graph._used_names, ( + assert new_name not in graph._namespace.used_names, ( "The name(%s) is already in use. Please try a different one again." % (new_name) ) - new_name = graph._create_unique_name(new_name) + new_name = graph._namespace.create_unique_name(new_name) self._name = new_name - self._orig_name = new_name + + @property + def qualname(self): + r"""Get the `qualname` of this Node. The `qualname` can be used to get the + submodule from the traced Module or Module. + + Example: + .. code-block:: + + import megengine.module as M + import megengine.functional as F + import megengine.traced_module as tm + import megengine as mge + + class block(M.Module): + def __init__(self): + super().__init__() + self.param = mge.Tensor([1.]) + self.relu = M.ReLU() + + def forward(self, x): + x = x + self.param + return self.relu(F.relu(x)) + + class module(M.Module): + def __init__(self): + super().__init__() + self.block = block() + + def forward(self, x): + x = self.block(x) + return x + + net = module() + traced_net = tm.trace_module(net, mge.Tensor([0.])) + traced_net = traced_net.flatten() + out_node = traced_net.graph.outputs[0] + + # qualname : "module.block.relu.[out]" + qualname = out_node.qualname + # qualname : "block.relu" + qualname = qualname.split(".", 1)[-1].rsplit(".", 1)[0] + + assert qualname in list(map(lambda x: x[0], net.named_modules())) + assert qualname in list(map(lambda x: x[0], traced_net.named_modules())) + """ + return self._qualname @property def top_graph(self): @@ -120,8 +154,8 @@ class ModuleNode(Node): r"""The type of the Module correspending to the ModuleNode.""" _owner = None # type: weakref.ReferenceType - def __init__(self, expr, name: str = None, orig_name: str = None): - super().__init__(expr, name, orig_name) + def __init__(self, expr, name: str = None, qualname: str = None): + super().__init__(expr, name, qualname) def __getstate__(self): return { @@ -129,10 +163,15 @@ class ModuleNode(Node): "users": self.users, "_id": self._id, "_name": self._name, - "_orig_name": self._orig_name, + "_qualname": self._qualname, "module_type": self.module_type, } + def __setstate__(self, state): + if "_orig_name" in state: + state["_qualname"] = state.pop("_orig_name") + self.__dict__.update(state) + @property def owner(self): r"""Get the ``Module`` corresponding to this ``ModuleNode``. 
@@ -161,9 +200,21 @@ class TensorNode(Node): "_dtype": self._dtype, "_device": self._device, "_name": self._name, - "_orig_name": self._orig_name, + "_qualname": self._qualname, } + def __setstate__(self, state): + if "_orig_name" in state: + qualname = state.pop("_orig_name") + modulepath, comma, qualname = qualname.rpartition(".") + expr_name = state["expr"].__class__.__name__ + if expr_name not in ["GetAttr"]: + qualname = "[{}]".format(qualname) + if comma: + qualname = "{}.{}".format(modulepath, qualname) + state["_qualname"] = qualname + self.__dict__.update(state) + @property def shape(self): r"""Get the shape of this Node.""" diff --git a/imperative/python/megengine/traced_module/traced_module.py b/imperative/python/megengine/traced_module/traced_module.py index f05d5ea6e2bde285dcb353efeb049bde830a91ff..4cfe079e927c700064c96f329d598479d6620e02 100644 --- a/imperative/python/megengine/traced_module/traced_module.py +++ b/imperative/python/megengine/traced_module/traced_module.py @@ -8,7 +8,6 @@ import builtins import collections import copy -import ctypes import fnmatch import functools import inspect @@ -31,8 +30,6 @@ from typing import ( Union, ) -from megengine import tensor - from .. import functional as F from .. import get_logger from .. import module as M @@ -43,7 +40,6 @@ from ..core._imperative_rt.core2 import ( unset_module_tracing, ) from ..core._trace_option import set_symbolic_shape -from ..core.tensor.array_method import ArrayMethodMixin from ..module import Module from ..module.qat import QATModule from ..quantization.fake_quant import LSQ, TQT, FakeQuantize, _FakeQuantize @@ -57,7 +53,22 @@ from ..quantization.observer import ( SyncMinMaxObserver, ) from ..tensor import Tensor -from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input +from .expr import ( + Apply, + CallFunction, + CallMethod, + Constant, + Expr, + GetAttr, + Input, + get_suffix_name, + is_apply_def, + is_call_function, + is_call_module, + is_call_tensor_method, + is_constant, + is_getattr, +) from .fake_quant import FakeQuantize as TM_FakeQuant from .module_tracer import ( PatchedFn, @@ -192,36 +203,24 @@ def _wrap_mnode_getattr(orig_getattr): @functools.wraps(orig_getattr) def wraped_fn(self, name): obj = self.owner + current_graph = active_module_tracer().current_scope() if self.top_graph is not None: - active_module_tracer().current_scope()._add_input(self) + current_graph._add_input(self) attr = getattr(obj, name) node = attr - full_name = None - if id(attr) in active_module_tracer().id2name: - full_name = active_module_tracer().id2name[id(attr)] - if not isinstance(attr, TracedModuleBuilder): if isinstance(attr, Module): attr = TracedModuleBuilder(attr) setattr(obj, name, attr) - active_module_tracer().id2name[id(attr)] = full_name if isinstance(attr, (NodeMixin, RawTensor)): - if full_name: - scope_name = active_module_tracer().current_scope()._module_name - if scope_name: - full_name = full_name[len(scope_name) + 1 :] - else: - full_name = name - else: - full_name = name NodeMixin.wrap( attr, lambda: GetAttr.make( self, - name, type=NodeMixin.get_wrapped_type(attr), - orig_name=full_name, + attr_name=name, + name="", ), ) if isinstance(attr, (NodeMixin, RawTensor)): @@ -245,16 +244,6 @@ def _wrap_mnode_call(orig_call): return wraped_fn -def _init_id2name(mod: Module, prefix: str = ""): - id2name = { - id(m): "%s.%s" % (prefix, key) - for key, m in chain( - mod.named_modules(), mod.named_parameters(), mod.named_buffers() - ) - } - return id2name - - class _InsertExprs: 
def __init__(self, graph, expr: Optional[Expr] = None): self.graph = graph @@ -262,10 +251,8 @@ class _InsertExprs: graph = graph.top_graph assert graph.inputs[0].owner._is_top self.root_graph = graph - self.global_scope = InternalGraph( - graph._name, graph._prefix_name, graph._module_name - ) - self.global_scope._used_names.update(graph._used_names) + self.global_scope = InternalGraph(self.graph._name, self.graph._qualname) + self.global_scope._namespace.merge(self.graph._namespace) self.expr = expr self._tensor_method_patch = None @@ -277,10 +264,8 @@ class _InsertExprs: set_module_tracing() _set_convert_node_flag(True) assert active_module_tracer() is None - module = self.graph.inputs[0].owner - _wrap_func = lambda x: _convert_node_and_tensor(_wrapped_function(x)) set_active_module_tracer( - module_tracer(_wrap_func, _init_id2name(module, self.graph._module_name)) + module_tracer(lambda x: _convert_node_and_tensor(_wrapped_function(x))) ) active_module_tracer().patcher.__enter__() for cls, name, func in [ @@ -296,9 +281,10 @@ class _InsertExprs: if va is not None: return False set_symbolic_shape(self.use_sym_shape) - unset_module_tracing() active_module_tracer().patcher.__exit__(ty, va, tr) _set_convert_node_flag(False) + set_active_module_tracer(None) + unset_module_tracing() while self._tensor_method_patch: pf = self._tensor_method_patch.pop() @@ -310,13 +296,14 @@ class _InsertExprs: name = mod._name if isinstance(mod, TracedModuleBuilder): mod = mod.build() - if hasattr(mod, "graph"): - for node in mod.graph.nodes(): - node.value = None + if hasattr(mod, "argdef_graph_map"): + for g in mod.argdef_graph_map.values(): + for n in g.nodes(False): + if isinstance(n, TensorNode): + n.value = None setattr(parent, name, mod) - set_active_module_tracer(None) - for node in self.global_scope.nodes(): + for node in self.global_scope.nodes(False): node.value = None extra_inp_nodes = set(self.global_scope.inputs) @@ -339,24 +326,87 @@ class _InsertExprs: if insert_index < max_inp_expr_idx: insert_index = max_inp_expr_idx - anchor_index = insert_index - 1 - if anchor_index >= 0: - logger.info( - "The new expr will be inserted after ( {} )".format( - self.graph._exprs[anchor_index] - ) - ) - for expr in self.global_scope._exprs: self.graph._exprs.insert(insert_index, expr) insert_index += 1 - self.graph._used_names.update(self.global_scope._used_names) + self.graph._namespace.merge(self.global_scope._namespace) self.root_graph._total_ids = (Node._get_next_id(), Expr._get_next_id()) self.root_graph.inputs[0].owner._update_ref() return True +class NameSpace: + def __init__(self, name, qualname): + self.name = name + self.qualname = qualname + self._used_names = {} + + def create_unique_name(self, name: str) -> str: + assert isinstance(name, str), "The name must be a string" + name = re.sub("[^0-9a-zA-Z_]+", "_", name) + if name[0].isdigit(): + name = "_{}".format(name) + + while name in self._used_names or _is_builtin_name(name): + match = re.match(r"(.*)_(\d+)$", name) + if match is None: + name = name + "_1" + else: + base, num = match.group(1, 2) + name = "{}_{}".format(base, int(num) + 1) + + self._used_names.setdefault(name) + return name + + def auto_naming_for_outputs(self, expr: Expr): + _add_suffix = lambda x: x + "_out" + if is_call_module(expr): + call_node = expr.inputs[0] + qualname = "%s.[out]" % (call_node.qualname) + name = call_node.name + elif is_call_tensor_method(expr): + name = expr.method.strip("_") + qualname = "{}.[{}]".format( + self.qualname, 
self.create_unique_name("method_%s" % (name)), + ) + elif is_call_function(expr): + name = expr.func.__name__ + qualname = "{}.[{}]".format( + self.qualname, self.create_unique_name("func_%s" % name), + ) + elif is_apply_def(expr): + name = str(expr.opdef).lower() + qualname = "{}.[{}]".format( + self.qualname, self.create_unique_name("def_%s" % name), + ) + elif is_getattr(expr): + qualname = "{}.{}".format(expr.inputs[0].qualname, expr.name) + name = get_suffix_name(self.qualname, qualname) + _add_suffix = lambda x: x + elif is_constant(expr): + name = ( + expr.name if expr.name else "const_" + type(expr.value).__name__.lower() + ) + qualname = "{}.[{}]".format(self.qualname, name) + _add_suffix = lambda x: x + + for node in expr.outputs: + if node._name == "" or node._name in self.used_names: + assert _add_suffix(name) == name or isinstance(node, TensorNode) + node._name = self.create_unique_name(_add_suffix(name)) + if node._qualname == "": + node._qualname = qualname + assert get_suffix_name(self.qualname, qualname) + + def merge(self, other: "NameSpace"): + self._used_names.update(other.used_names) + + @property + def used_names(self): + return self._used_names + + class InternalGraph: r"""``InternalGraph`` is the main data structure used in the TracedModule. It is used to represent the execution procedure of Module's forward method. @@ -407,37 +457,78 @@ class InternalGraph: _top_graph = None # type: InternalGraph _total_ids = None # type: List[int] - def __init__(self, name: str = None, prefix_name: str = "", module_name: str = ""): + def __init__(self, name: str, qualname: str): self._exprs = [] self._inputs = [] self._outputs = [] self._watch_point = [] self._end_point = [] - self._used_names = {} + self._namespace = NameSpace(name, qualname) self._rst = collections.defaultdict(list) self._name = name - self._prefix_name = prefix_name - self._module_name = module_name + self._qualname = qualname def _insert(self, expr): self._exprs.append(expr) - def _create_unique_name(self, name: str) -> str: - assert isinstance(name, str), "The name must be a str" - name = re.sub("[^0-9a-zA-Z_]+", "_", name) - if name[0].isdigit(): - name = "_{}".format(name) + @property + def name(self) -> str: + r"""Get the name of this graph.""" + return self._name + + @name.setter + def name(self, new_name: str): + r"""Set a new name to this graph.""" + mod = self.inputs[0].owner + graph = self.top_graph + assert graph is not None or mod._is_top, "The parent graph cannot be None." + if graph is not None: + assert new_name not in self._namespace.used_names, ( + "The name(%s) is already in use. Please try a different one again." + % (new_name) + ) + new_name = self._namespace.create_unique_name(new_name) + self._name = new_name - while name in self._used_names or _is_builtin_name(name): - match = re.match(r"(.*)_(\d+)$", name) - if match is None: - name = name + "_1" - else: - base, num = match.group(1, 2) - name = "{}_{}".format(base, int(num) + 1) + @property + def qualname(self) -> str: + r"""Get the `qualname` of this graph. The `qualname` can be used to get the + submodule from the traced Module or Module. - self._used_names.setdefault(name) - return name + Example: + .. 
code-block:: + + import megengine.module as M + import megengine.traced_module as tm + import megengine as mge + + class block(M.Module): + def __init__(self): + super().__init__() + self.relu = M.ReLU() + + def forward(self, x): + return self.relu(x) + + class module(M.Module): + def __init__(self): + super().__init__() + self.block = block() + + def forward(self, x): + x = self.block(x) + return x + + net = module() + traced_net = tm.trace_module(net, mge.Tensor([0.])) + + qualname = traced_net.block.graph.qualname # qualname = "module.block" + qualname = qualname.split(".", 1)[-1] # qualname = "block" + + assert qualname in list(map(lambda x: x[0], net.named_modules())) + assert qualname in list(map(lambda x: x[0], traced_net.named_modules())) + """ + return self._qualname @property def inputs(self) -> List[Node]: @@ -596,55 +687,6 @@ class InternalGraph: def _add_output(self, o): self._outputs.append(o) - def _replace_inputs_outputs(self, repl_dict, prefix_name="", module_name=""): - for node, repl_node in repl_dict.items(): - assert node in self._inputs or node in self._outputs - for i in node.users: - if i not in repl_node.users: - repl_node.users.append(i) - - for idx, i in enumerate(self._inputs): - if i in repl_dict: - self._inputs[idx] = repl_dict[i] - - for idx, o in enumerate(self._outputs): - if o in repl_dict: - repl_dict[o]._orig_name = "{}{}".format(module_name, o._orig_name) - self._outputs[idx] = repl_dict[o] - - for expr in self._exprs: - - for idx, i in enumerate(expr.inputs): - assert isinstance( - i._name, str - ), "The node ({}) name must be a str".format(i) - if i in repl_dict: - expr.inputs[idx] = repl_dict[i] - elif isinstance(i, TensorNode) and prefix_name not in i._name: - if i.top_graph != active_module_tracer().current_scope(): - i._name = ( - active_module_tracer() - .current_scope() - ._create_unique_name(prefix_name + i._name.lstrip("_")) - ) - i._orig_name = "{}{}".format(module_name, i._orig_name) - - for idx, o in enumerate(expr.outputs): - assert isinstance( - o._name, str - ), "The node ({}) name must be a str".format(i) - if o in repl_dict: - expr.outputs[idx] = repl_dict[o] - expr.outputs[idx].expr = expr - elif isinstance(o, TensorNode) and prefix_name not in i._name: - if o.top_graph != active_module_tracer().current_scope(): - o._name = ( - active_module_tracer() - .current_scope() - ._create_unique_name(prefix_name + o._name.lstrip("_")) - ) - o._orig_name = "{}{}".format(module_name, o._orig_name) - def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]: r"""Get the dependent Exprs of the ``nodes``. 
@@ -674,20 +716,16 @@ class InternalGraph: def reset_inputs(self, *args, **kwargs): forma_mnode = self.inputs[0] - actual_mnodes = forma_mnode.actual_node - call_nodes = [] - for n in actual_mnodes: - for c_expr in n.users: - if isinstance(c_expr, CallMethod) and c_expr.method == "__call__": - call_nodes.append((c_expr, n)) - moudle = forma_mnode.owner - assert moudle._is_top, "reset_inputs only support the top-level graph" + assert moudle._is_top, "reset_inputs only supports top graph" inputs, tree_def = tree_flatten(((moudle, *args), kwargs)) def create_node(val: Tensor): - node = Input(type=TensorNode).outputs[0] + name = self._namespace.create_unique_name("args") + node = Input( + type=TensorNode, name=name, qualname="%s.[%s]" % (self._qualname, name) + ).outputs[0] node.shape = val.shape node.dtype = val.dtype return node @@ -697,29 +735,11 @@ class InternalGraph: ] org_argdef = list(moudle.argdef_graph_map.keys())[0] - if call_nodes: - org_argdef = call_nodes[0][0].arg_def for v in inputs[1:]: assert isinstance(v, RawTensor) formal_node_inputs.append(create_node(v)) - actual_nodes = [] - for e, n in call_nodes: - e.arg_def = tree_def - actual_node_inputs = [ - n, - ] - for v in inputs[1:]: - actual_node_inputs.append(create_node(v)) - - for org_n in e.inputs: - org_n.users.pop(e) - - e.inputs[:] = actual_node_inputs - e.const_val = [] - actual_nodes.append(actual_node_inputs[1:]) - self._inputs[:] = formal_node_inputs moudle.argdef_graph_map[tree_def] = moudle.argdef_graph_map.pop(org_argdef) moudle.argdef_outdef_map[tree_def] = moudle.argdef_outdef_map.pop(org_argdef) @@ -740,51 +760,27 @@ class InternalGraph: a suffix will be added to it. """ forma_mnode = self.inputs[0] - actual_mnodes = forma_mnode.actual_node - moudle = forma_mnode.owner - assert moudle._is_top, "add_input_node only support the top-level graph" - - call_nodes = [] - for n in actual_mnodes: - for c_expr in n.users: - if isinstance(c_expr, CallMethod) and c_expr.method == "__call__": - call_nodes.append(c_expr) + assert moudle._is_top, "add_input_node only supports top graph" - def create_node(name=None, is_input: bool = True): - if is_input: - node = Input(type=TensorNode, name=name).outputs[0] - else: - node = TensorNode(expr=None, name=None) + def create_node(name=None): + node = Input( + type=TensorNode, name=name, qualname="%s.[%s]" % (self._qualname, name) + ).outputs[0] node.shape = shape node.dtype = dtype return node org_argdef = list(moudle.argdef_graph_map.keys())[0] - if call_nodes: - org_argdef = call_nodes[0].arg_def - args, kwargs = org_argdef.unflatten(self._inputs) - formal_inp_node = create_node(self._create_unique_name(name), True) + formal_inp_node = create_node(self._namespace.create_unique_name(name)) inputs, tree_def = tree_flatten( ((*args, formal_inp_node), kwargs), is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)), ) self._inputs[:] = inputs[:] - actual_inp_nodes = [] - for e in call_nodes: - args, kwargs = e.unflatten_args(e.inputs) - args = args + (create_node(False),) - inputs, tree_def = tree_flatten( - (args, kwargs), - is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)), - ) - e.inputs[:] = inputs[:] - e.arg_def = tree_def - actual_inp_nodes.append(args[-1]) - moudle.argdef_graph_map[tree_def] = moudle.argdef_graph_map.pop(org_argdef) moudle.argdef_outdef_map[tree_def] = moudle.argdef_outdef_map.pop(org_argdef) return formal_inp_node @@ -841,39 +837,13 @@ class InternalGraph: outputs, is_leaf=lambda x: isinstance(x, TensorNode), ) forma_mnode = 
self.inputs[0] - moudle = forma_mnode.owner - assert moudle._is_top, "reset_outputs only support the top graph" - - actual_mnodes = forma_mnode.actual_node - call_nodes = [] - for n in actual_mnodes: - for c_expr in n.users: - if isinstance(c_expr, CallMethod) and c_expr.method == "__call__": - call_nodes.append((c_expr)) - - def create_node(val: TensorNode, expr: Expr): - node = TensorNode(expr) - node.shape = val.shape - node.dtype = val.dtype - return node + assert moudle._is_top, "reset_outputs only supports top graph" tree_def = list(moudle.argdef_graph_map.keys())[0] - if call_nodes: - tree_def = call_nodes[0].arg_def - - actual_nodes = [] - for e in call_nodes: - actual_node_outputs = [] - for v in outputs: - actual_node_outputs.append(create_node(v, e)) - e.outputs[:] = actual_node_outputs - e.out_def = out_def - actual_nodes.append(actual_node_outputs) self._outputs[:] = outputs moudle.argdef_outdef_map[tree_def] = out_def - return actual_nodes def add_output_node(self, node: TensorNode): r"""Add an output node to the Graph. @@ -920,27 +890,10 @@ class InternalGraph: ((Tensor([1.], device=xpux:0), Tensor([0.], device=xpux:0)), Tensor([1.], device=xpux:0)) """ forma_mnode = self.inputs[0] - moudle = forma_mnode.owner - assert moudle._is_top, "add_output_node only support the top graph" - - actual_mnodes = forma_mnode.actual_node - call_nodes = [] - - for n in actual_mnodes: - for c_expr in n.users: - if isinstance(c_expr, CallMethod) and c_expr.method == "__call__": - call_nodes.append((c_expr)) - - def create_node(val: TensorNode, expr: Expr): - node = TensorNode(expr) - node.shape = val.shape - node.dtype = val.dtype - return node + assert moudle._is_top, "add_output_node only supports top graph" tree_def = list(moudle.argdef_graph_map.keys())[0] - if call_nodes: - tree_def = call_nodes[0].arg_def org_out_def = moudle.argdef_outdef_map[tree_def] org_outs = org_out_def.unflatten(self._outputs) @@ -948,22 +901,8 @@ class InternalGraph: (org_outs, node), is_leaf=lambda x: isinstance(x, TensorNode), ) self._outputs[:] = outputs - - actual_out_nodes = [] - for e in call_nodes: - actual_node = create_node(node, e) - org_outs = org_out_def.unflatten(e.outputs) - outputs, out_def = tree_flatten( - (org_outs, actual_node), is_leaf=lambda x: isinstance(x, TensorNode), - ) - e.outputs[:] = outputs - e.out_def = out_def - actual_out_nodes.append(actual_node) - moudle.argdef_outdef_map[tree_def] = out_def - return actual_out_nodes - def insert_exprs(self, expr: Optional[Expr] = None): r"""Initialize the trace mode and insertion position. 
@@ -1029,8 +968,34 @@ class InternalGraph: idx = n.inputs.index(node) n.inputs[idx] = repl_node + def _merge_getattr_expr(self): + getattr_nodes_map = dict() + for expr in self._exprs: + if not isinstance(expr, GetAttr): + continue + attr_name = get_suffix_name(self.qualname, expr.outputs[0].qualname) + assert attr_name, '"{}" is not a prefix of "{}"'.format( + self.qualname, expr.outputs[0].qualname + ) + if attr_name in getattr_nodes_map: + base_node = getattr_nodes_map[attr_name] + repl_node = expr.outputs[0] + for expr in repl_node.users: + base_node.users.append(expr) + idx = expr.inputs.index(repl_node) + expr.inputs[idx] = base_node + repl_node.users = [] + else: + if attr_name != expr.name: + expr.name = attr_name + expr.inputs[0].users.remove(expr) + self.inputs[0].users.append(expr) + expr.inputs[0] = self.inputs[0] + getattr_nodes_map[attr_name] = expr.outputs[0] + def compile(self): r"""Delete unused expr.""" + self._merge_getattr_expr() dep_exprs = self.get_dep_exprs(self.outputs) i = 0 while i < len(self._exprs): @@ -1121,6 +1086,29 @@ class InternalGraph: state.pop("_top_graph") return state + def __setstate__(self, state): + old_version = False + if "_module_name" in state: + old_version = True + state["_qualname"] = state.pop("_module_name") + prefix_name = state.pop("_prefix_name") + if prefix_name: + state["_name"] = "{}_{}".format(prefix_name, state["_name"]) + + self.__dict__.update(state) + + if old_version: + for n in self.nodes(False): + qualname = self._qualname + if isinstance(n.expr, CallMethod) and isinstance( + n.expr.inputs[0], ModuleNode + ): + n._qualname = n.expr.inputs[0]._qualname + ".[out]" + continue + if n._qualname: + qualname = "{}.{}".format(qualname, n._qualname) + n._qualname = qualname + def _get_meth_name(obj, func): tp = obj if isinstance(obj, type) else type(obj) @@ -1158,7 +1146,8 @@ def _wrapped_function(orig_func): if isinstance(args[1], RawTensor): node = NodeMixin.get(inputs[1]) inputs[1] = copy.copy(inputs[1]) - # copy inputs[1] to avoid tensor and Tensor(tensor) share same m_tensor, which will cause they have same _NodeMixin__node in tracing. + # copy inputs[1] to avoid tensor and Tensor(tensor) share same m_tensor, + # which will cause they have same _NodeMixin__node in tracing. 
NodeMixin.wrap_safe(inputs[1], node) args, kwargs = tree_def.unflatten(inputs) call_node = CallMethod.make(self, meth_name) @@ -1363,21 +1352,36 @@ class TracedModuleBuilder(NodeMixin): self._body = self._argdef_graph_map[tree_def] else: self._mod._is_top = False - self._body = self._mod.graph + self._body = self._mod.argdef_graph_map[tree_def] + module_qualname = NodeMixin.get(self).qualname + if module_qualname != self._body.qualname: + src_name, dst_name = self._body.qualname, module_qualname + + def replace_qualname(g): + attr_name = get_suffix_name(src_name, g.qualname) + if attr_name is not None: + g._qualname = ( + ("%s.%s" % (dst_name, attr_name)) + if attr_name + else dst_name + ) + assert get_suffix_name(dst_name, g.qualname) is not None + + for mod in self._mod.modules(): + if not hasattr(mod, "argdef_graph_map"): + continue + for g in mod.argdef_graph_map.values(): + replace_qualname(g) + for n in g.nodes(False): + replace_qualname(n) else: self_node = None orig_self = NodeMixin.get(self) - top_graph = active_module_tracer().current_scope() - graph_prefix_name = top_graph._name - if top_graph._prefix_name: - graph_prefix_name = "{}_{}".format( - top_graph._prefix_name, graph_prefix_name.lstrip("_") - ) - module_name = orig_self._orig_name - if top_graph._module_name: - module_name = "{}.{}".format(top_graph._module_name, module_name) + parent_graph = active_module_tracer().current_scope() + module_qualname = orig_self._qualname self._body = InternalGraph( - orig_self._name, prefix_name=graph_prefix_name, module_name=module_name + name=parent_graph._namespace.create_unique_name(module_qualname), + qualname=module_qualname, ) active_module_tracer().push_scope(self._body) # rebind self to new input node @@ -1390,8 +1394,13 @@ class TracedModuleBuilder(NodeMixin): self, self_node if self_node - else Input.make("self", NodeMixin.get_wrapped_type(self), ""), + else Input.make( + name="self", + qualname=module_qualname, + type=NodeMixin.get_wrapped_type(self), + ), ) + origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]] # prepare args and kwargs for inner graph index_args, index_kwargs = tree_def.unflatten( @@ -1416,7 +1425,9 @@ class TracedModuleBuilder(NodeMixin): NodeMixin.wrap( x, lambda: Input.make( - type=NodeMixin.get_wrapped_type(x), name=name + type=NodeMixin.get_wrapped_type(x), + name=name, + qualname="%s.[%s]" % (module_qualname, name), ), ) return x @@ -1430,13 +1441,18 @@ class TracedModuleBuilder(NodeMixin): getattr(getattr(self._mod, "forward", self._mod), "__globals__", {}) ) rst = type(self._mod).forward(*args, **kwargs) + if _convert_node_flag(): rst = _node_to_tensor(rst)[0][0] + outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf) + for i in ( outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,) ): + mark_constant(i) active_module_tracer().current_scope()._add_output(NodeMixin.get(i)) + NodeMixin.wrap_safe(self, orig_self) for arg, node in zip(inputs[1:], origin_inp_node): if node: @@ -1461,15 +1477,13 @@ class TracedModuleBuilder(NodeMixin): attr = getattr(type(self._mod), name).__get__(self, type(self)) else: attr = getattr(self._mod, name) - full_name = None + if ( isinstance(attr, FunctionType) and id(attr) in active_module_tracer().patcher.patched_fn_ids ): return active_module_tracer().patcher.wrap_fn(attr) - if id(attr) in active_module_tracer().id2name: - full_name = active_module_tracer().id2name[id(attr)] if isinstance(attr, (List, Dict)): unset_module_tracing() has_module, m_container = 
replace_container_with_module_container(attr) @@ -1477,31 +1491,24 @@ class TracedModuleBuilder(NodeMixin): attr = m_container if has_module and not m_container: raise ValueError( - "Can not trace the module that uses the same container to store Module and Non-Module objects " + "Can not trace the module that uses the same container to store" + " Module and Non-Module objects." ) set_module_tracing() + if isinstance(attr, Module): attr = TracedModuleBuilder(attr) if isinstance(attr, (Module, RawTensor)): setattr(self, name, attr) - active_module_tracer().id2name[id(attr)] = full_name - if full_name: - scope_name = active_module_tracer().current_scope()._module_name - if scope_name: - full_name = full_name[len(scope_name) + 1 :] - else: - full_name = name - else: - full_name = name NodeMixin.wrap( attr, lambda: GetAttr.make( NodeMixin.get(self), - name, type=NodeMixin.get_wrapped_type(attr), - orig_name=full_name, + attr_name=name, + name="", ), ) return attr @@ -1527,25 +1534,14 @@ class TracedModuleBuilder(NodeMixin): else: assert mod_attr is wrapped - full_name = None - if id(mod_attr) in active_module_tracer().id2name: - full_name = active_module_tracer().id2name[id(mod_attr)] - scope_name = active_module_tracer().current_scope()._module_name - if full_name and scope_name: - full_name = full_name[len(scope_name) + 1 :] - else: - full_name = name - else: - full_name = name - # assert not self._is_builtin if isinstance(wrapped, (NodeMixin, RawTensor)): NodeMixin.wrap( wrapped, lambda: GetAttr.make( NodeMixin.get(self), - name, type=NodeMixin.get_wrapped_type(wrapped), - orig_name=full_name, + attr_name=name, + name="", ), ) @@ -1731,9 +1727,7 @@ class NodeFilterName(NodeFilter): def __iter__(self): for i in self._iter: graph = i.top_graph - name = "{}_{}".format(graph._name, i._name.lstrip("_")) - if graph._prefix_name: - name = "{}_{}".format(graph._prefix_name, name.lstrip("_")) + name = "{}_{}".format(graph._name, i._name) if self.pattern == name or self._re.match(name): yield i @@ -1967,7 +1961,7 @@ class TracedModule(Module): if isinstance(expr, GetAttr) and isinstance( expr.outputs[0], ModuleNode ): - obj = getattr(node2obj[expr.inputs[0]], expr.name) + obj = expr.interpret(node2obj[expr.inputs[0]])[0] expr.outputs[0]._owner = weakref.ref(obj) node2obj[expr.outputs[0]] = obj if isinstance(expr, Constant) and isinstance( @@ -2001,47 +1995,29 @@ class TracedModule(Module): A new :class:`TracedModule`. """ new_module = copy.deepcopy(self) - assert active_module_tracer() is None - id2name = _init_id2name(new_module, "self") - set_active_module_tracer(module_tracer(lambda x: x, {})) - active_module_tracer().push_scope(new_module.graph) + + def _replace_inputs_and_outputs(expr: Expr, repl_dict: Dict[Node, Node]): + inputs, outputs = expr.inputs, expr.outputs + for i, node in enumerate(inputs): + if node in repl_dict: + inputs[i] = repl_dict[node] + for i, node in enumerate(outputs): + if node in repl_dict: + outputs[i] = repl_dict[node] + outputs[i].expr = expr def _flatten_subgraph( parent_graph: InternalGraph, graph: InternalGraph, + call: CallMethod, module: Module, - call=None, - prefix_name="", - module_name="", ): - if isinstance(prefix_name, str) and prefix_name and prefix_name[-1] != "_": - prefix_name += "_" - if isinstance(module_name, str) and module_name: - module_name += "." 
- if graph is None or module.is_qat: - assert not isinstance(module, TracedModule) or module.is_qat - const = Constant(module, id2name[id(module)]) - m_node = call.inputs[0] - if m_node.top_graph != active_module_tracer().current_scope(): - m_node._name = ( - active_module_tracer() - .current_scope() - ._create_unique_name(prefix_name) - ) - m_node._orig_name = id2name[id(module)][5:] - const.outputs[0] = m_node - const.outputs[0].expr = const - return [const, call] + repl_dict, node2obj, rename_blacklist = {}, {}, [] + if call is not None: graph = copy.deepcopy(graph) - exprs = [] - node2obj = {} - node2obj[graph._inputs[0]] = module - if call: node2obj[call.inputs[0]] = module - # replace inputs for submodule's exprx - if call: repl_dict = dict(zip(graph._inputs, call.inputs)) for ind, out in enumerate(graph.outputs): if isinstance(out.expr, Input): @@ -2058,60 +2034,65 @@ class TracedModule(Module): parent_graph._outputs[index] = repl_dict[out] continue repl_dict[out] = call.outputs[ind] + if isinstance(out, TensorNode): + call.outputs[ind]._qualname = out._qualname + + for node, repl_node in repl_dict.items(): + assert node in graph._inputs or node in graph._outputs + for i in node.users: + if i not in repl_node.users: + repl_node.users.append(i) - graph._replace_inputs_outputs(repl_dict, prefix_name, module_name) + rename_blacklist = list(chain(call.inputs, call.outputs)) + node2obj[graph._inputs[0]] = module + prefix_name = call.inputs[0]._name if call else "" + exprs = [] for expr in graph._exprs: + + if call is not None: + _replace_inputs_and_outputs(expr, repl_dict) + if isinstance(expr, GetAttr): - # replace GetAttr with Constant - if isinstance(expr.outputs[0], TensorNode): - const = Constant(getattr(node2obj[expr.inputs[0]], expr.name)) - const.outputs = expr.outputs - const.outputs[0].expr = const - exprs.append(const) - elif isinstance(expr.outputs[0], ModuleNode): - node2obj[expr.outputs[0]] = getattr( - node2obj[expr.inputs[0]], expr.name - ) + mnode = expr.inputs[0] + node2obj[expr.outputs[0]] = expr.interpret(node2obj[mnode])[0] - elif isinstance(expr, CallMethod): + if isinstance(expr, CallMethod): obj_node = expr.inputs[0] - if isinstance(obj_node, ModuleNode): - pre_expr = expr.inputs[0].expr - if isinstance(pre_expr, GetAttr): - (obj,) = pre_expr.interpret(node2obj[pre_expr.inputs[0]]) - expr_graph = ( - obj.argdef_graph_map[expr.arg_def] - if hasattr(obj, "argdef_graph_map") - else None - ) + if isinstance(obj_node, ModuleNode) and isinstance( + obj_node.expr, GetAttr + ): + obj = node2obj[obj_node] + expr_graph = ( + obj.argdef_graph_map[expr.arg_def] + if hasattr(obj, "argdef_graph_map") + else None + ) + if expr_graph is not None: exprs.extend( - _flatten_subgraph( - graph, - expr_graph, - obj, - expr, - prefix_name + obj_node._name.lstrip("_"), - module_name + obj_node._orig_name, - ) + _flatten_subgraph(graph, expr_graph, expr, obj) ) - else: - # module has been replaced. 
- assert isinstance(pre_expr, Constant) - exprs.append(expr) - else: - exprs.append(expr) - else: - exprs.append(expr) + continue + + if parent_graph is not None: + for node in expr.outputs: + if node in rename_blacklist: + continue + name = "{}_{}".format(prefix_name, node._name) + node._name = parent_graph._namespace.create_unique_name(name) + + exprs.append(expr) if call is not None: for i in call.inputs: i.users.remove(call) + return exprs - new_module.graph._exprs = _flatten_subgraph(None, new_module.graph, new_module) + new_module.graph._exprs = _flatten_subgraph( + None, new_module.graph, None, new_module + ) new_module.graph.compile() - set_active_module_tracer(None) new_module.graph._reset_ids() return new_module @@ -2208,25 +2189,31 @@ def trace_module( assert active_module_tracer() is None assert isinstance(mod, Module) try: + net_name = mod._name if mod._name else mod.__class__.__name__ use_sym_shape = set_symbolic_shape(True) set_module_tracing() - set_active_module_tracer( - module_tracer(_wrapped_function, _init_id2name(mod, "self")) - ) + set_active_module_tracer(module_tracer(_wrapped_function)) for cls in [Expr, Node]: cls._set_next_id(0) with active_module_tracer().patcher: - global_scope = InternalGraph(name="") + global_scope = InternalGraph(name="top", qualname=net_name) active_module_tracer().push_scope(global_scope) builder = TracedModuleBuilder(mod, True) - name = mod._name if mod._name else mod.__class__.__name__ - NodeMixin.wrap_safe(builder, Input.make(name, ModuleNode, orig_name="self")) + + NodeMixin.wrap_safe( + builder, Input.make(name="top", type=ModuleNode, qualname=net_name) + ) inputs, _ = tree_flatten((args, kwargs)) for _, i in enumerate(inputs): # assert isinstance(i, Tensor), "not support " if isinstance(i, RawTensor): NodeMixin.wrap_safe( - i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i)) + i, + Input.make( + name="arg_{}".format(_), + type=NodeMixin.get_wrapped_type(i), + qualname="{}.[{}]".format(net_name, "arg_{}".format(_)), + ), ) builder(*args, **kwargs) active_module_tracer().pop_scope() diff --git a/imperative/python/test/unit/traced_module/test_modification.py b/imperative/python/test/unit/traced_module/test_modification.py index edf48348ad4fdf7eee69b31a9c8bbe10286b43de..820987e3b3a278b5abe4c51707d99b883fb2c5ca 100644 --- a/imperative/python/test/unit/traced_module/test_modification.py +++ b/imperative/python/test/unit/traced_module/test_modification.py @@ -6,6 +6,7 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
import pickle +from itertools import chain import numpy as np @@ -13,8 +14,8 @@ import megengine.functional as F import megengine.module as M from megengine.module.identity import Identity from megengine.traced_module import trace_module -from megengine.traced_module.expr import CallFunction, Expr, GetAttr -from megengine.traced_module.node import Node +from megengine.traced_module.expr import CallFunction, CallMethod, Expr, GetAttr, Input +from megengine.traced_module.node import ModuleNode, Node class IdentityMod(M.Module): @@ -85,6 +86,34 @@ def test_search(): relu_expr = graph.get_function_by_type(F.relu).as_unique() assert isinstance(relu_expr, CallFunction) and relu_expr.func == F.relu + conv_node = graph.get_module_by_type(M.Conv2d).as_unique() + assert isinstance(conv_node, ModuleNode) and conv_node.module_type == M.Conv2d + + add_expr = graph.get_method_by_type("__add__").as_unique() + assert isinstance(add_expr, CallMethod) and add_expr.method == "__add__" + + conv_node = graph.get_node_by_name("MyBlock_conv1").as_unique() + assert isinstance(conv_node, ModuleNode) and conv_node.module_type == M.Conv2d + + +def test_producer_and_users(): + traced_module, *_ = _init_module() + + def _check(exprs): + for expr in exprs: + for n in chain(expr.inputs, expr.outputs): + if not isinstance(n.expr, Input): + assert n.expr in exprs + for e in n.users: + assert e in exprs + assert n in e.inputs + + for mod in traced_module.modules(): + if not hasattr(mod, "argdef_graph_map"): + continue + for g in mod.argdef_graph_map.values(): + _check(g._exprs) + def test_insert(): traced_module, x, expect = _init_block() @@ -97,6 +126,54 @@ def test_insert(): np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6) +def test_insert_module(): + class Neg(M.Module): + def forward(self, x): + return F.neg(x) + + traced_module, x, expect = _init_block() + graph = traced_module.graph + relu_out = graph.get_function_by_type(F.relu).as_unique().outputs[0] + self = graph.inputs[0] + setattr(traced_module, "neg", Neg()) + with graph.insert_exprs(): + neg_out = self.neg(relu_out) + graph.replace_node({relu_out: neg_out}) + graph.compile() + np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6) + assert traced_module.neg.graph is not None + assert len(traced_module.neg.graph._exprs) == 1 + + +def test_add_input_and_output(): + traced_module, x, y = _init_module() + + data_node = traced_module.graph.add_input_node(shape=(1, 3, 224, 224), name="data") + traced_module.graph.add_output_node(data_node) + + assert data_node.name == "data" + assert traced_module.graph.inputs[-1] == data_node + assert len(traced_module.graph.inputs) == 3 + assert len(traced_module.graph.outputs) == 2 + + y1, y2 = traced_module(x, x) + np.testing.assert_equal(y1.numpy(), y.numpy()) + np.testing.assert_equal(y2.numpy(), x.numpy()) + + y1, y2 = traced_module(x, y) + np.testing.assert_equal(y2.numpy(), y.numpy()) + + traced_module.graph.reset_outputs( + ({"orig_out": traced_module.graph.outputs[0]}, traced_module.graph.outputs[1]) + ) + + out = traced_module(x, x) + assert isinstance(out, tuple) + assert isinstance(out[0], dict) + np.testing.assert_equal(out[0]["orig_out"].numpy(), y.numpy()) + np.testing.assert_equal(out[1].numpy(), x.numpy()) + + def test_delete(): traced_module, x, expect = _init_block() graph = traced_module.graph @@ -117,8 +194,10 @@ def test_delete(): def test_flatten(): traced_module, x, expect = _init_module() traced_module = traced_module.flatten() - traced_module.graph.compile() - 
assert all(not isinstance(i, GetAttr) for i in traced_module.graph._exprs) + assert len(traced_module.graph._exprs) == 12 + np.testing.assert_equal(expect.numpy(), traced_module(x).numpy()) + + traced_module = traced_module.flatten() assert len(traced_module.graph._exprs) == 12 np.testing.assert_equal(expect.numpy(), traced_module(x).numpy()) @@ -128,7 +207,7 @@ def test_id_and_name(): _total_ids = traced_module.graph._total_ids node_ids = [n._id for n in traced_module.graph.nodes().as_list()] assert len(set(node_ids)) == len(node_ids) - assert max(node_ids) + 1 == len(node_ids) + assert max(node_ids) + 1 == _total_ids[0] expr_ids = [n._id for n in traced_module.graph.exprs().as_list()] assert len(set(expr_ids)) == len(expr_ids) @@ -177,7 +256,7 @@ def test_id_and_name(): _check_name(flattened_module) -def test_set_name(): +def test_set_node_name(): traced_module, x, expect = _init_module() graph = traced_module.graph output_node = graph.outputs[0] @@ -190,6 +269,18 @@ def test_set_name(): np.testing.assert_equal(str(graph.outputs[0]), "output") +def test_set_graph_name(): + traced_module, x, expect = _init_module() + graph = traced_module.graph + output_node = graph.outputs[0] + + node_name = output_node.name + + graph.name = "Top" + node = graph.get_node_by_name("{}_{}".format("Top", node_name)).as_unique() + assert node is output_node + + def test_extra_block(): class PostProcess(M.Module): def forward(self, x): diff --git a/imperative/python/test/unit/traced_module/test_qat_module.py b/imperative/python/test/unit/traced_module/test_qat_module.py new file mode 100644 index 0000000000000000000000000000000000000000..721fdcd44b0021f058295b49c78fd3e25195950e --- /dev/null +++ b/imperative/python/test/unit/traced_module/test_qat_module.py @@ -0,0 +1,195 @@ +import io +from functools import partial +from itertools import chain +from typing import Callable + +import numpy as np + +import megengine as mge +import megengine.functional as F +import megengine.module as M +import megengine.quantization as Q +from megengine import Tensor +from megengine.module.qat.module import QATModule +from megengine.traced_module import TracedModule, trace_module + + +def get_subattr(self: M.Module, name: str): + if name == "": + return self + module_path, _, name = name.rpartition(".") + if module_path == "": + return getattr(self, name) + module_names = module_path.split(".") + for item in module_names: + self = getattr(self, item) + if not isinstance(self, M.Module): + raise AttributeError("`{}` is not an Module".format(item)) + return getattr(self, name) + + +class Myblcok(M.Module): + def __init__(self,): + super().__init__() + self.conv0 = M.ConvBnRelu2d(3, 3, 3, 1, 1) + self.conv1 = M.ConvBn2d(3, 3, 1, 1, 0) + self.conv2 = M.ConvBn2d(3, 3, 1, 1, 0) + self.add = M.Elemwise("FUSE_ADD_RELU") + + def forward(self, x): + x = self.conv0(x) + x0 = self.conv1(x) + x1 = self.conv2(x) + o = self.add(x0, x1) + return o + + +class MyModule(M.Module): + def __init__(self): + super().__init__() + self.block0 = Myblcok() + self.block1 = Myblcok() + + def forward(self, x): + x = self.block0(x) + x = self.block1(x) + return x + + +class MyMinMaxObserver(Q.MinMaxObserver): + pass + + +class MyTQT(Q.TQT): + pass + + +def get_lsq_config(lsq_cls): + return Q.QConfig( + weight_observer=None, + act_observer=None, + weight_fake_quant=partial(lsq_cls, dtype="qint8_narrow"), + act_fake_quant=partial(lsq_cls, dtype="qint8"), + ) + + +def get_observer_config(observer_cls): + return Q.QConfig( + weight_observer=partial(observer_cls, 
dtype="qint8_narrow"), + act_observer=partial(observer_cls, dtype="qint8"), + weight_fake_quant=None, + act_fake_quant=None, + ) + + +def get_qparams(mod: QATModule): + weight_qparams, act_qparams = None, None + if mod.act_observer is not None: + act_qparams = mod.act_observer.get_qparams() + if mod.act_fake_quant: + act_qparams = mod.act_fake_quant.get_qparams() + + if mod.weight_observer is not None: + weight_qparams = mod.weight_observer.get_qparams() + if mod.weight_fake_quant: + weight_qparams = mod.weight_fake_quant.get_qparams() + + return weight_qparams, act_qparams + + +def check_qparams(qparmsa: Q.QParams, qparmsb: Q.QParams): + assert qparmsa.dtype_meta == qparmsb.dtype_meta + assert qparmsa.mode == qparmsb.mode + np.testing.assert_equal(qparmsa.scale.numpy(), qparmsb.scale.numpy()) + if qparmsa.zero_point is not None: + np.testing.assert_equal(qparmsa.zero_point.numpy(), qparmsb.zero_point.numpy()) + + +def build_observered_net(net: M.Module, observer_cls): + qat_net = Q.quantize_qat(net, qconfig=get_observer_config(observer_cls)) + Q.enable_observer(qat_net) + for _ in range(5): + inp = Tensor(np.random.random(size=(5, 3, 32, 32))) + qat_net(inp) + Q.disable_observer(qat_net) + return qat_net + + +def build_fakequanted_net(net: QATModule, fakequant_cls): + qat_net = Q.reset_qconfig(net, get_lsq_config(fakequant_cls)) + return qat_net + + +def test_trace_qat(): + def _check_qat_module(qat_net: QATModule): + inp = Tensor(np.random.random(size=(5, 3, 32, 32))) + traced_net = trace_module(qat_net, inp) + + for name, qat_module in qat_net.named_modules(): + if not isinstance(qat_module, QATModule): + continue + traced_qat_module = get_subattr(traced_net, name) + weight_qparams, act_qparams = get_qparams(qat_module) + traced_weight_qparams, traced_act_qparams = get_qparams(traced_qat_module) + if weight_qparams: + check_qparams(weight_qparams, traced_weight_qparams) + if act_qparams: + check_qparams(act_qparams, traced_act_qparams) + + _check_qat_module(build_observered_net(MyModule(), Q.MinMaxObserver)) + _check_qat_module(build_observered_net(MyModule(), MyMinMaxObserver)) + _check_qat_module( + build_fakequanted_net(build_observered_net(MyModule(), Q.MinMaxObserver), Q.TQT) + ) + _check_qat_module( + build_fakequanted_net(build_observered_net(MyModule(), Q.MinMaxObserver), MyTQT) + ) + + +def test_load_param(): + def _check_param(moda: M.Module, modb: M.Module): + for name, attr in chain(moda.named_parameters(), moda.named_buffers()): + traced_attr = get_subattr(modb, name) + np.testing.assert_equal(attr.numpy(), traced_attr.numpy()) + + def _check_module(build_func: Callable): + net = build_func() + buffer = io.BytesIO() + mge.save(net.state_dict(), buffer) + buffer.seek(0) + + inp = Tensor(np.random.random(size=(5, 3, 32, 32))) + traced_net = trace_module(build_func(), inp) + traced_net.load_state_dict(mge.load(buffer)) + + _check_param(net, traced_net) + + buffer.seek(0) + traced_net = trace_module(build_func(), inp).flatten() + traced_net.load_state_dict(mge.load(buffer)) + + _check_param(net, traced_net) + + _check_module(lambda: MyModule()) + _check_module(lambda: build_observered_net(MyModule(), Q.MinMaxObserver)) + + +def test_qualname(): + def _check_qualname(net): + inp = Tensor(np.random.random(size=(5, 3, 32, 32))) + traced_net = trace_module(net, inp) + base_qualname = traced_net.graph.qualname + for node in traced_net.graph.nodes(): + qualname = node.qualname + qualname = qualname[len(base_qualname) + 1 :] + if qualname.endswith("]"): + qualname = 
qualname.rsplit(".", 1)[0] + if qualname.startswith("["): + qualname = "" + traced_attr = get_subattr(traced_net, qualname) + orig_attr = get_subattr(net, qualname) + assert traced_attr is not None + assert orig_attr is not None + + _check_qualname(MyModule()) + _check_qualname(build_observered_net(MyModule(), Q.MinMaxObserver))
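
The sketch below illustrates the contract of the new get_suffix_name helper and the qualname attribute introduced by this patch, mirroring the added test_qualname test. It assumes MegEngine with these changes applied and importable as shown; the Block module is an illustrative example, not code from the patch.

import megengine as mge
import megengine.module as M
import megengine.traced_module as tm
from megengine.traced_module.expr import get_suffix_name

# Tri-state contract of get_suffix_name as implemented in expr.py above:
assert get_suffix_name("module.block", "module.block") == ""            # identical names -> ""
assert get_suffix_name("module.block", "module.block.relu") == "relu"   # proper prefix -> dotted suffix
assert get_suffix_name("module.block", "module.conv") is None           # not a prefix -> None

# qualname records where a Node was produced, relative to the traced root module.
class Block(M.Module):                      # illustrative module, not part of the patch
    def __init__(self):
        super().__init__()
        self.relu = M.ReLU()

    def forward(self, x):
        return self.relu(x)

net = Block()
traced = tm.trace_module(net, mge.Tensor([0.0]))
out_node = traced.graph.outputs[0]
print(out_node.qualname)                    # e.g. "Block.relu.[out]"

# Dropping the root name and the trailing "[...]" element yields a key accepted by
# named_modules(), which is what the new test_qualname test relies on.
key = out_node.qualname.split(".", 1)[-1].rsplit(".", 1)[0]   # -> "relu"
assert key in dict(traced.named_modules())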