From 15712807b9866f3675cfc42315123d4186364636 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 6 Aug 2021 17:55:16 +0800 Subject: [PATCH] feat(traced_module): add name to Node GitOrigin-RevId: 39c28090678d0da23c313d594103405896e872ec --- .../experimental/traced_module/expr.py | 99 +++++-- .../experimental/traced_module/node.py | 53 ++-- .../experimental/traced_module/pytree.py | 17 +- .../traced_module/traced_module.py | 242 ++++++++++++++---- 4 files changed, 331 insertions(+), 80 deletions(-) diff --git a/imperative/python/megengine/experimental/traced_module/expr.py b/imperative/python/megengine/experimental/traced_module/expr.py index 1f8ff6859..5b6b1ed63 100644 --- a/imperative/python/megengine/experimental/traced_module/expr.py +++ b/imperative/python/megengine/experimental/traced_module/expr.py @@ -11,6 +11,7 @@ import builtins import collections import copy import inspect +import re from typing import Callable, Dict, List from ...core._imperative_rt import OpDef @@ -21,7 +22,24 @@ from ...module import Module from ...tensor import Parameter, Tensor from .module_tracer import active_module_tracer, module_tracer from .node import ModuleNode, Node, NodeMixin, TensorNode -from .pytree import TreeDef, tree_flatten +from .pytree import ArgsIndex, TreeDef, tree_flatten + + +def rstrip(s: str, __chars: str): + __chars = re.escape(__chars) + s = re.sub(r"^(?P.*?)(?:%s)+$" % __chars, "\g", s) + return s + + +def lstrip(s: str, __chars: str): + __chars = re.escape(__chars) + s = re.sub(r"^(?:%s)+(?P.*)$" % __chars, "\g", s) + return s + + +def strip(s: str, __chars: str): + s = lstrip(rstrip(s, __chars), __chars) + return s class Expr: @@ -67,9 +85,29 @@ class Expr: if not isinstance(outputs, collections.Sequence): outputs = (outputs,) + name = None + if isinstance(self, CallMethod): + name = self.inputs[0]._name + assert name is not None + name = rstrip(name, "_out") + if self.method == "__call__": + name += "_out" + else: + strip_method = strip(self.method, "_") + name = "%s_out" % strip_method + elif isinstance(self, CallFunction): + name = self.func.__name__ + "_out" + elif isinstance(self, Apply): + name = str(self.opdef).lower() + "_out" + for i in outputs: assert isinstance(i, RawTensor) - self.outputs.append(NodeMixin.get_wrapped_type(i)(self)) + o_name = ( + active_module_tracer().current_scope()._create_unique_name(name) + ) + self.outputs.append( + NodeMixin.get_wrapped_type(i)(expr=self, name=o_name) + ) for i, node in zip(outputs, self.outputs,): NodeMixin.wrap_safe(i, node) @@ -133,11 +171,16 @@ class Input(Expr): @classmethod def make(cls, *args, **kwargs): expr = cls(*args, **kwargs) - active_module_tracer().current_scope().add_input(expr.outputs[0]) + oup_node = expr.outputs[0] + name = ( + active_module_tracer().current_scope()._create_unique_name(oup_node._name) + ) + oup_node._name = name + active_module_tracer().current_scope().add_input(oup_node) return expr.outputs[0] def __repr__(self): - return "%{}: {} = Input({})".format(self._id, self.outputs[0], self.name) + return "%{}:\t{} = Input({})".format(self._id, self.outputs[0], self.name) # expr: outputs = getattr(inputs[0], self.name) @@ -154,22 +197,31 @@ class GetAttr(Expr): self.name = name node_cls = type if type else Node self.outputs = [ - node_cls(self), + node_cls(self, name=name), ] @classmethod def make(cls, *args, **kwargs): expr = cls(*args, **kwargs) + module = expr.inputs[0] + oup_name = expr.name + while module._name != "self": + oup_name = module._name + "_" + oup_name + module = module.expr.inputs[0] + oup_name = active_module_tracer().current_scope()._create_unique_name(oup_name) + expr.outputs[0]._name = oup_name active_module_tracer().current_scope().insert(expr) - expr.outputs[0]._name = expr.name return expr.outputs[0] def interpret(self, *inputs): return (getattr(inputs[0], self.name),) def __repr__(self): - return '%{}: {} = GetAttr({}, "{}")'.format( - self._id, self.outputs[0], self.inputs[0], self.name + out_type = "Tensor" + if isinstance(self.outputs[0], ModuleNode): + out_type = self.outputs[0].module_type.__name__ + return '%{}:\t{} = getattr({}, "{}") -> ({})'.format( + self._id, self.outputs[0], self.inputs[0], self.name, out_type ) @@ -230,11 +282,14 @@ class CallMethod(Expr): outputs = self.outputs if self.out_def: outputs = self.out_def.unflatten(outputs) - return "%{}: {}{}.{}({})".format( + method = ".%s" % self.method + if method == ".__call__": + method = "" + return "%{}:\t{}{}{}({})".format( self._id, str(outputs) + " = " if outputs else "", self.args[0], - self.method, + method, ", ".join([args, kwargs]), ) @@ -259,7 +314,7 @@ class Apply(Expr): return apply(self.opdef, *inputs) def __repr__(self): - return "%{}: {} = {}({})".format( + return "%{}:\t{} = {}({})".format( self._id, ", ".join(str(i) for i in self.outputs), self.opdef, @@ -314,10 +369,10 @@ class CallFunction(Expr): outputs = self.outputs if self.out_def: outputs = self.out_def.unflatten(outputs) - return "%{}: {}{}({})".format( + return "%{}:\t{}{}({})".format( self._id, str(outputs) + " = " if outputs else "", - self.func.__module__ + "." + self.func.__name__, + self.func.__module__.rsplit(".")[-1] + "." + self.func.__name__, ", ".join([args, kwargs]), ) @@ -328,21 +383,25 @@ class Constant(Expr): # TODO: constant cache to reduce the size of dumped model _constant_cache = {} - def __init__(self, c): + def __init__(self, c, name=None): super().__init__() assert isinstance(c, (RawTensor, Module)) if isinstance(c, Module): assert module_tracer.is_builtin(c) self.value = c + self.name = name self.inputs = [] node_cls = NodeMixin.get_wrapped_type(c) self.outputs = [ - node_cls(self), + node_cls(self, name=name), ] @classmethod def make(cls, *args, **kwargs): expr = cls(*args, **kwargs) + name = "const_module" if isinstance(expr.value, Module) else "const_tensor" + name = active_module_tracer().current_scope()._create_unique_name(name) + expr.outputs[0]._name = name active_module_tracer().current_scope().insert(expr) return expr.outputs[0] @@ -352,8 +411,14 @@ class Constant(Expr): return (self.value,) def __repr__(self): - return "%{}: {} = Constant({})".format( - self._id, self.outputs[0], type(self.value) + name = self.name + if name is None: + name = type(self.value) + node_type = "Module" + if isinstance(self.outputs[0], TensorNode): + node_type = "Tensor" + return "%{}:\t{} = Constant({}) -> ({})".format( + self._id, self.outputs[0], name, node_type ) def __getstate__(self): diff --git a/imperative/python/megengine/experimental/traced_module/node.py b/imperative/python/megengine/experimental/traced_module/node.py index ae7160ce9..2c43c8958 100644 --- a/imperative/python/megengine/experimental/traced_module/node.py +++ b/imperative/python/megengine/experimental/traced_module/node.py @@ -28,8 +28,9 @@ class Node: expr = None __total_id = 0 _id = None - _name = None _top_graph = None # type: weakref.ReferenceType + _name = None + _format_spec = "" def __init__(self, expr: "Expr", name: str = None): self.expr = expr @@ -43,10 +44,35 @@ class Node: Node.__total_id = max(Node.__total_id, self._id) + 1 def __repr__(self): - if self._name is None: - return "%{}".format(self._id) + format_spec = Node._format_spec + return self.__format__(format_spec) + + def __format__(self, format_spec: str) -> str: + if format_spec == "" or format_spec is None: + format_spec = Node._format_spec + name = self._name + if name is None: + name = "" + if format_spec in ["i", "p", "ip", "pi"]: + if "p" in format_spec: + graph = self.top_graph + prefix_name = "" + if graph is not None: + prefix_name = graph._name + if graph._prefix_name: + prefix_name = "{}_{}".format( + graph._prefix_name, prefix_name.lstrip("_") + ) + if name: + name = "_" + name.lstrip("_") + name = "{}{}".format(prefix_name, name) + if "i" in format_spec: + if name: + name = "_" + name.lstrip("_") + name = "%{}{}".format(self._id, name) + return name else: - return "%{}".format(self._name) + return name if name else ("%d" % self._id) @property def top_graph(self): @@ -54,6 +80,12 @@ class Node: return self._top_graph() return None + @classmethod + def set_format_spec(cls, str): + old_format_spec = cls._format_spec + cls._format_spec = str + return old_format_spec + class ModuleNode(Node): """ @@ -72,12 +104,6 @@ class ModuleNode(Node): super().__init__(expr, name) self.actual_mnode = [] - def __repr__(self): - if self._name is None: - return "%{}_({})".format(self._id, self.module_type.__name__) - else: - return "%{}_{}({})".format(self._id, self._name, self.module_type.__name__) - def __getstate__(self): return { "expr": self.expr, @@ -104,12 +130,6 @@ class TensorNode(Node): qparam = None device = None - def __repr__(self): - if self._name is None: - return "%{}_(Tensor)".format(self._id) - else: - return "%{}_{}(Tensor)".format(self._id, self._name) - def __getstate__(self): return { "expr": self.expr, @@ -119,6 +139,7 @@ class TensorNode(Node): "shape": self.shape, "dtype": self.dtype, "device": self.device, + "_name": self._name, } diff --git a/imperative/python/megengine/experimental/traced_module/pytree.py b/imperative/python/megengine/experimental/traced_module/pytree.py index 5abb92be6..8a526a9a7 100644 --- a/imperative/python/megengine/experimental/traced_module/pytree.py +++ b/imperative/python/megengine/experimental/traced_module/pytree.py @@ -22,6 +22,16 @@ from ...quantization.utils import LSQParams, QParams, QuantMode from ...tensor import Parameter, Tensor from .node import ModuleNode, Node, NodeMixin, TensorNode + +class ArgsIndex: + def __init__(self, index=0, name="") -> None: + self.index = index + self.name = name + + def __repr__(self) -> str: + return self.name + + SUPPORTED_TYPE = {} # if type(object) or obj in SUPPORTED_LEAF_TYPE, the object could be treated as leaf node of pytree @@ -39,6 +49,7 @@ SUPPORTED_LEAF_TYPE = { type(None), type(Ellipsis), QuantMode, + ArgsIndex, } # if isinstance(object, SUPPORTED_LEAF_CLS) or issubclass(obj, SUPPORTED_LEAF_CLS) is True, the object could be threated as leaf node of pytree @@ -121,11 +132,11 @@ def _is_leaf(obj): def _leaf_type(node): if isinstance(node, (RawTensor, TensorNode)): - return (Tensor, TensorNode) + return (Tensor, TensorNode, ArgsIndex) elif isinstance(node, (NodeMixin, Module)): - return (Module, ModuleNode, NodeMixin) + return (Module, ModuleNode, NodeMixin, ArgsIndex) else: - return type(node) + return (type(node), ArgsIndex) def _is_const_leaf(node): diff --git a/imperative/python/megengine/experimental/traced_module/traced_module.py b/imperative/python/megengine/experimental/traced_module/traced_module.py index 030894008..bb55c001d 100644 --- a/imperative/python/megengine/experimental/traced_module/traced_module.py +++ b/imperative/python/megengine/experimental/traced_module/traced_module.py @@ -6,12 +6,15 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import builtins import collections import copy +import fnmatch import functools -import inspect +import keyword +import re import weakref -from inspect import getmembers, isclass, ismethod +from inspect import getcallargs, getmembers, isclass, ismethod from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union from ... import functional as F @@ -41,11 +44,19 @@ from .module_tracer import ( set_active_module_tracer, ) from .node import ModuleNode, Node, NodeMixin, TensorNode -from .pytree import tree_flatten +from .pytree import ArgsIndex, tree_flatten logger = get_logger(__name__) +def _is_builtin_name(name: str) -> bool: + return ( + name in builtins.__dict__ + or name in keyword.kwlist + or name in {"inf", "nan", "NoneType"} + ) + + def _is_leaf(node): assert isinstance(node, RawTensor), "doesn't support {} in return values".format( type(node) @@ -67,6 +78,7 @@ class _InsertExprs: def __init__(self, graph, expr: Optional[Expr] = None, after: bool = True): self.graph = graph self.global_scope = InternalGraph() + self.global_scope._used_names.update(graph._used_names) self.expr = expr self.after = after @@ -91,6 +103,7 @@ class _InsertExprs: for expr in self.global_scope._exprs: self.graph._exprs.insert(index, expr) index += 1 + self.graph._used_names.update(self.global_scope._used_names) class InternalGraph: @@ -107,17 +120,37 @@ class InternalGraph: _inputs = None # type: List[Node] _outputs = None # type: List[Node] - def __init__(self): + def __init__(self, name: str = None, prefix_name: str = ""): self._exprs = [] self._inputs = [] self._outputs = [] self._watch_point = [] self._end_point = [] + self._used_names = {} self._rst = collections.defaultdict(list) + self._name = name + self._prefix_name = prefix_name def insert(self, expr): self._exprs.append(expr) + def _create_unique_name(self, name: str) -> str: + assert isinstance(name, str) + name = re.sub("[^0-9a-zA-Z_]+", "_", name) + if name[0].isdigit(): + name = "_{}".format(name) + + while name in self._used_names or _is_builtin_name(name): + match = re.match(r"(.*)_(\d+)$", name) + if match is None: + name = name + "_1" + else: + base, num = match.group(1, 2) + name = "{}_{}".format(base, int(num) + 1) + + self._used_names.setdefault(name) + return name + @property def inputs(self): return self._inputs @@ -150,13 +183,16 @@ class InternalGraph: def get_node_by_id(self, node_id: List[int] = None): return self.node_filter.node_id(node_id) + def get_node_by_name(self, name: str = None, ignorecase: bool = True): + return self.node_filter.name(name, ignorecase) + def add_input(self, i): self._inputs.append(i) def add_output(self, o): self._outputs.append(o) - def _replace_inputs_outputs(self, repl_dict): + def _replace_inputs_outputs_and_add_prefixname(self, repl_dict, prefix_name=""): for node, repl_node in repl_dict.items(): assert node in self._inputs or node in self._outputs @@ -175,13 +211,29 @@ class InternalGraph: for expr in self._exprs: for idx, i in enumerate(expr.inputs): + assert i._name is not None if i in repl_dict: expr.inputs[idx] = repl_dict[i] + elif isinstance(i, TensorNode) and prefix_name not in i._name: + if i.top_graph != active_module_tracer().current_scope(): + i._name = ( + active_module_tracer() + .current_scope() + ._create_unique_name(prefix_name + i._name.lstrip("_")) + ) for idx, o in enumerate(expr.outputs): + assert o._name is not None if o in repl_dict: expr.outputs[idx] = repl_dict[o] expr.outputs[idx].expr = expr + elif isinstance(o, TensorNode) and prefix_name not in i._name: + if o.top_graph != active_module_tracer().current_scope(): + o._name = ( + active_module_tracer() + .current_scope() + ._create_unique_name(prefix_name + o._name.lstrip("_")) + ) def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]: if not isinstance(nodes, Sequence): @@ -258,7 +310,7 @@ class InternalGraph: # return formal_node_inputs[1:], actual_nodes return formal_node_inputs[1:] - def add_input_node(self, shape, dtype="float32"): + def add_input_node(self, shape, dtype="float32", name="args"): forma_mnode = self.inputs[0] actual_mnodes = forma_mnode.actual_mnode @@ -271,11 +323,11 @@ class InternalGraph: if isinstance(c_expr, CallMethod) and c_expr.method == "__call__": call_nodes.append(c_expr) - def create_node(is_input: bool = True): + def create_node(name=None, is_input: bool = True): if is_input: - node = Input(type=TensorNode).outputs[0] + node = Input(type=TensorNode, name=name).outputs[0] else: - node = TensorNode(expr=None) + node = TensorNode(expr=None, name=None) node.shape = shape node.dtype = dtype return node @@ -286,7 +338,7 @@ class InternalGraph: org_argdef = call_nodes[0].arg_def args, kwargs = org_argdef.unflatten(self._inputs) - formal_inp_node = create_node(True) + formal_inp_node = create_node(self._create_unique_name(name), True) inputs, tree_def = tree_flatten( ((*args, formal_inp_node), kwargs), is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)), @@ -524,11 +576,21 @@ class InternalGraph: return self.interpret(*inp) def __repr__(self): - return "InternalGraph ({}) {{\n\t{}\n\treturn {}\n}}".format( + return self.__format__() + + def __format__(self, format_spec: str = "") -> str: + saved_format_spec = Node.set_format_spec(format_spec) + name = "" + if self._name: + name = "%s.Graph" % self._name + res = "{} ({}) {{\n\t{}\n\treturn {}\n}}".format( + name, ", ".join(str(i) for i in self._inputs), "\n\t".join("{}".format(str(i)) for i in self._exprs), ", ".join(str(i) for i in self._outputs), ) + Node.set_format_spec(saved_format_spec) + return res def _get_meth_name(obj, func): @@ -621,6 +683,7 @@ class TracedModuleBuilder(NodeMixin): self._is_builtin = module_tracer.is_builtin(mod) self._argdef_graph_map = {} self._argdef_outdef_map = {} + self.nodes = set() # The builder will be passed to self._mod.forward as 'self' argument. If the 'forward' uses super().xxx to call method of its base classes, the trace procedure will throw exceprion, because the builder doesn't inherit from self._mod.__bases__. # modify self.__class__ and let the builder inherit from TracedModuleBuilder and mod.__class__. @@ -631,7 +694,7 @@ class TracedModuleBuilder(NodeMixin): ) def build(self): - if self._is_builtin: + if self._is_builtin or isinstance(self._mod, TracedModule): for node in self.nodes: node.module_type = type(self._mod) # node._owner = weakref.ref(self._mod) @@ -671,21 +734,38 @@ class TracedModuleBuilder(NodeMixin): callnode.arg_def = tree_def - if self._is_builtin: + if ( + self._is_builtin + or tree_def in self._argdef_graph_map + or isinstance(self._mod, TracedModule) + ): unset_module_tracing() rst = self._mod(*args, **kwargs) outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf) set_module_tracing() if self._is_builtin: self._body = None + elif tree_def in self._argdef_graph_map: + self._body = self._argdef_graph_map[tree_def] + else: + self._mod._is_top = False + self._body = self._mod.graph + name = NodeMixin.get(self)._name + if name: + self._body._name = name else: self_node = None - if tree_def in self._argdef_graph_map: - self_node = self._argdef_graph_map[tree_def].inputs[0] - self._body = InternalGraph() + orig_self = NodeMixin.get(self) + top_graph = active_module_tracer().current_scope() + graph_prefix_name = top_graph._name + if top_graph._prefix_name: + graph_prefix_name = "{}_{}".format( + top_graph._prefix_name, graph_prefix_name.lstrip("_") + ) + self._body = InternalGraph(orig_self._name, prefix_name=graph_prefix_name) active_module_tracer().push_scope(self._body) # rebind self to new input node - orig_self = NodeMixin.get(self) + if self_node: NodeMixin.wrap_safe(self, self_node) active_module_tracer().current_scope().add_input(self_node) @@ -698,16 +778,37 @@ class TracedModuleBuilder(NodeMixin): ) origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]] # prepare args and kwargs for inner graph - def wrap(x): + index_args, index_kwargs = tree_def.unflatten( + [ + ArgsIndex(0), + *list(ArgsIndex(i + 1) for i in range(len(origin_inp_node))), + ] + ) + key2idx = getcallargs(type(self._mod).forward, *index_args, **index_kwargs) + idx2key = {} + for k, v in key2idx.items(): + if isinstance(v, ArgsIndex): + idx2key[v.index] = k + else: + flatten_argidx, _ = tree_flatten(v) + for _i, v in enumerate(flatten_argidx): + if isinstance(v, ArgsIndex): + idx2key[v.index] = k + "_%d" % _i + + def wrap(x, name): if isinstance(x, (RawTensor, NodeMixin)): NodeMixin.wrap( - x, lambda: Input.make(type=NodeMixin.get_wrapped_type(x)), + x, + lambda: Input.make( + type=NodeMixin.get_wrapped_type(x), name=name + ), ) return x args = [self] - for i in inputs[1:]: - args.append(wrap(i)) + for i, v in enumerate(inputs[1:]): + args.append(wrap(v, idx2key[i + 1])) + args, kwargs = tree_def.unflatten(args) active_module_tracer().patcher.auto_patch( getattr(getattr(self._mod, "forward", self._mod), "__globals__", {}) @@ -857,6 +958,9 @@ class NodeFilter(BaseFilter): def node_id(self, node_id: List[int]): return NodeFilterNodeId(self, node_id) + def name(self, name: str, ignorecase: bool = True): + return NodeFilterName(self, name, ignorecase) + class NodeFilterType(NodeFilter): def __init__(self, expr_iter, owner_type, node_type): @@ -887,6 +991,33 @@ class NodeFilterNodeId(NodeFilter): yield node +class NodeFilterName(NodeFilter): + _re = None + + def __init__(self, node_iter, pattern, ignorecase): + super().__init__(node_iter) + self.pattern = pattern + self._re = self.make_re(pattern, ignorecase) + + @classmethod + def make_re(cls, pattern, ignorecase=True): + assert isinstance(pattern, str), "bad pattern: {!r}".format(pattern) + assert isinstance(ignorecase, bool) + flags = 0 + if ignorecase: + flags |= re.IGNORECASE + return re.compile(fnmatch.translate(pattern), flags=flags) + + def __iter__(self): + for i in self._iter: + graph = i.top_graph + name = "{}_{}".format(graph._name, i._name.lstrip("_")) + if graph._prefix_name: + name = "{}_{}".format(graph._prefix_name, name.lstrip("_")) + if self.pattern == name or self._re.match(name): + yield i + + class ExprFilterCallFunction(ExprFilter): def __init__(self, expr_iter, func: Callable = None): super().__init__(expr_iter) @@ -1052,12 +1183,29 @@ class TracedModule(Module): :return: :class:`TracedModule` """ new_module = copy.deepcopy(self) - - def _flatten_subgraph(graph, module, call=None): + module2name = {} + assert active_module_tracer() is None + set_active_module_tracer(module_tracer(lambda x: x)) + active_module_tracer().push_scope(new_module.graph) + for n, m in new_module.named_modules(): + module2name[id(m)] = n + + def _flatten_subgraph( + graph: InternalGraph, module: Module, call=None, prefix_name="" + ): + if graph is not None and prefix_name and prefix_name[-1] != "_": + prefix_name += "_" if graph is None: assert not isinstance(module, TracedModule) - const = Constant(module) - const.outputs[0] = call.inputs[0] + const = Constant(module, "self.%s" % module2name[id(module)]) + m_node = call.inputs[0] + if m_node.top_graph != active_module_tracer().current_scope(): + m_node._name = ( + active_module_tracer() + .current_scope() + ._create_unique_name(prefix_name) + ) + const.outputs[0] = m_node const.outputs[0].expr = const return [const, call] if call is not None: @@ -1083,7 +1231,7 @@ class TracedModule(Module): continue repl_dict[out] = call.outputs[ind] - graph._replace_inputs_outputs(repl_dict) + graph._replace_inputs_outputs_and_add_prefixname(repl_dict, prefix_name) for expr in graph._exprs: if isinstance(expr, GetAttr): @@ -1109,7 +1257,14 @@ class TracedModule(Module): if hasattr(obj, "argdef_graph_map") else None ) - exprs.extend(_flatten_subgraph(expr_graph, obj, expr)) + exprs.extend( + _flatten_subgraph( + expr_graph, + obj, + expr, + prefix_name + obj_node._name.lstrip("_"), + ) + ) else: # module has been replaced. assert isinstance(pre_expr, Constant) @@ -1126,7 +1281,18 @@ class TracedModule(Module): return exprs new_module.graph._exprs = _flatten_subgraph(new_module.graph, new_module) - + new_module.graph.compile() + set_active_module_tracer(None) + for _id, expr in enumerate(new_module.graph._exprs): + expr._id = _id + total_node_id = 0 + for i in new_module.graph._inputs: + i._id = total_node_id + total_node_id += 1 + for expr in new_module.graph._exprs: + for o in expr.outputs: + o._id = total_node_id + total_node_id += 1 return new_module def __getstate__(self): @@ -1149,19 +1315,7 @@ def register_as_builtin(mod_cls: Type[Module]) -> None: module_tracer.register_as_builtin(mod_cls) -def wrap(func: Union[Callable]): - assert callable(func) - if hasattr(func, "__code__"): - assert not isinstance(func, str) - fn_name = func.__code__.co_name - currentframe = inspect.currentframe() - assert currentframe is not None - f = currentframe.f_back - assert f is not None - if f.f_code.co_name != "": - raise NotImplementedError("wrap must be called at the top level of a module") - Patcher._builtin_functions.append((f.f_globals, fn_name)) - return func +wrap = _wrapped_function def _register_all_builtin_module(): @@ -1192,11 +1346,11 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule: set_active_module_tracer(module_tracer(_wrapped_function)) with active_module_tracer().patcher: - global_scope = InternalGraph() + global_scope = InternalGraph(name="") active_module_tracer().push_scope(global_scope) - builder = TracedModuleBuilder(mod, True) - NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode)) + name = mod._name if mod._name else mod.__class__.__name__ + NodeMixin.wrap_safe(builder, Input.make(name, ModuleNode)) inputs, _ = tree_flatten((args, kwargs)) for _, i in enumerate(inputs): # assert isinstance(i, Tensor), "not support " -- GitLab