Commit 15712807 authored by Megvii Engine Team

feat(traced_module): add name to Node

GitOrigin-RevId: 39c28090678d0da23c313d594103405896e872ec
Parent e918f0aa
......@@ -11,6 +11,7 @@ import builtins
import collections
import copy
import inspect
import re
from typing import Callable, Dict, List
from ...core._imperative_rt import OpDef
......@@ -21,7 +22,24 @@ from ...module import Module
from ...tensor import Parameter, Tensor
from .module_tracer import active_module_tracer, module_tracer
from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import TreeDef, tree_flatten
from .pytree import ArgsIndex, TreeDef, tree_flatten
def rstrip(s: str, __chars: str):
__chars = re.escape(__chars)
    s = re.sub(r"^(?P<left>.*?)(?:%s)+$" % __chars, r"\g<left>", s)
return s
def lstrip(s: str, __chars: str):
__chars = re.escape(__chars)
    s = re.sub(r"^(?:%s)+(?P<right>.*)$" % __chars, r"\g<right>", s)
return s
def strip(s: str, __chars: str):
s = lstrip(rstrip(s, __chars), __chars)
return s
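
# Note: unlike str.rstrip, which treats its argument as a *set of characters*,
# the helpers above strip whole-substring occurrences from the ends. A minimal
# standalone sketch of the distinction (re-implemented here so it runs alone):
#
#   import re
#
#   def _rstrip_substr(s, sub):
#       # drop one or more trailing occurrences of the whole substring `sub`
#       return re.sub(r"^(?P<left>.*?)(?:%s)+$" % re.escape(sub), r"\g<left>", s)
#
#   assert _rstrip_substr("relu_out", "_out") == "relu"  # whole-substring semantics
#   assert "relu_out".rstrip("_out") == "rel"            # str.rstrip also eats the 'u'
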
class Expr:
......@@ -67,9 +85,29 @@ class Expr:
if not isinstance(outputs, collections.Sequence):
outputs = (outputs,)
name = None
if isinstance(self, CallMethod):
name = self.inputs[0]._name
assert name is not None
name = rstrip(name, "_out")
if self.method == "__call__":
name += "_out"
else:
strip_method = strip(self.method, "_")
name = "%s_out" % strip_method
elif isinstance(self, CallFunction):
name = self.func.__name__ + "_out"
elif isinstance(self, Apply):
name = str(self.opdef).lower() + "_out"
for i in outputs:
assert isinstance(i, RawTensor)
self.outputs.append(NodeMixin.get_wrapped_type(i)(self))
o_name = (
active_module_tracer().current_scope()._create_unique_name(name)
)
self.outputs.append(
NodeMixin.get_wrapped_type(i)(expr=self, name=o_name)
)
for i, node in zip(outputs, self.outputs,):
NodeMixin.wrap_safe(i, node)
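
# A minimal standalone sketch (not the library API) of the output-naming rule
# implemented above: output nodes are named "<producer>_out". For "__call__"
# the real code first strips an existing "_out" suffix from the producer's
# name so repeated calls do not yield "conv_out_out".
def _output_name(kind, base):
    if kind == "call":                   # CallMethod with method == "__call__"
        return base + "_out"             # e.g. node "conv"   -> "conv_out"
    if kind == "method":                 # any other CallMethod
        return base.strip("_") + "_out"  # e.g. "__add__"     -> "add_out"
    return base.lower() + "_out"         # CallFunction/Apply, e.g. "Reshape" -> "reshape_out"

assert _output_name("call", "conv") == "conv_out"
assert _output_name("method", "__add__") == "add_out"
assert _output_name("apply", "Reshape") == "reshape_out"
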
......@@ -133,11 +171,16 @@ class Input(Expr):
@classmethod
def make(cls, *args, **kwargs):
expr = cls(*args, **kwargs)
active_module_tracer().current_scope().add_input(expr.outputs[0])
oup_node = expr.outputs[0]
name = (
active_module_tracer().current_scope()._create_unique_name(oup_node._name)
)
oup_node._name = name
active_module_tracer().current_scope().add_input(oup_node)
return expr.outputs[0]
def __repr__(self):
return "%{}: {} = Input({})".format(self._id, self.outputs[0], self.name)
return "%{}:\t{} = Input({})".format(self._id, self.outputs[0], self.name)
# expr: outputs = getattr(inputs[0], self.name)
......@@ -154,22 +197,31 @@ class GetAttr(Expr):
self.name = name
node_cls = type if type else Node
self.outputs = [
node_cls(self),
node_cls(self, name=name),
]
@classmethod
def make(cls, *args, **kwargs):
expr = cls(*args, **kwargs)
module = expr.inputs[0]
oup_name = expr.name
while module._name != "self":
oup_name = module._name + "_" + oup_name
module = module.expr.inputs[0]
oup_name = active_module_tracer().current_scope()._create_unique_name(oup_name)
expr.outputs[0]._name = oup_name
active_module_tracer().current_scope().insert(expr)
expr.outputs[0]._name = expr.name
return expr.outputs[0]
def interpret(self, *inputs):
return (getattr(inputs[0], self.name),)
def __repr__(self):
return '%{}: {} = GetAttr({}, "{}")'.format(
self._id, self.outputs[0], self.inputs[0], self.name
out_type = "Tensor"
if isinstance(self.outputs[0], ModuleNode):
out_type = self.outputs[0].module_type.__name__
return '%{}:\t{} = getattr({}, "{}") -> ({})'.format(
self._id, self.outputs[0], self.inputs[0], self.name, out_type
)
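
# GetAttr.make above qualifies an attribute node's name with its owner chain,
# so accessing `weight` on `self.block.conv` yields "block_conv_weight" (then
# made unique in the current scope). A standalone sketch with invented names:
def _qualified_name(owner_names, attr):
    # owner_names: module node names from innermost owner up to "self"
    name = attr
    for owner in owner_names:
        if owner == "self":
            break
        name = owner + "_" + name
    return name

assert _qualified_name(["conv", "block", "self"], "weight") == "block_conv_weight"
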
......@@ -230,11 +282,14 @@ class CallMethod(Expr):
outputs = self.outputs
if self.out_def:
outputs = self.out_def.unflatten(outputs)
return "%{}: {}{}.{}({})".format(
method = ".%s" % self.method
if method == ".__call__":
method = ""
return "%{}:\t{}{}{}({})".format(
self._id,
str(outputs) + " = " if outputs else "",
self.args[0],
self.method,
method,
", ".join([args, kwargs]),
)
......@@ -259,7 +314,7 @@ class Apply(Expr):
return apply(self.opdef, *inputs)
def __repr__(self):
return "%{}: {} = {}({})".format(
return "%{}:\t{} = {}({})".format(
self._id,
", ".join(str(i) for i in self.outputs),
self.opdef,
......@@ -314,10 +369,10 @@ class CallFunction(Expr):
outputs = self.outputs
if self.out_def:
outputs = self.out_def.unflatten(outputs)
return "%{}: {}{}({})".format(
return "%{}:\t{}{}({})".format(
self._id,
str(outputs) + " = " if outputs else "",
self.func.__module__ + "." + self.func.__name__,
self.func.__module__.rsplit(".")[-1] + "." + self.func.__name__,
", ".join([args, kwargs]),
)
......@@ -328,21 +383,25 @@ class Constant(Expr):
# TODO: constant cache to reduce the size of dumped model
_constant_cache = {}
def __init__(self, c):
def __init__(self, c, name=None):
super().__init__()
assert isinstance(c, (RawTensor, Module))
if isinstance(c, Module):
assert module_tracer.is_builtin(c)
self.value = c
self.name = name
self.inputs = []
node_cls = NodeMixin.get_wrapped_type(c)
self.outputs = [
node_cls(self),
node_cls(self, name=name),
]
@classmethod
def make(cls, *args, **kwargs):
expr = cls(*args, **kwargs)
name = "const_module" if isinstance(expr.value, Module) else "const_tensor"
name = active_module_tracer().current_scope()._create_unique_name(name)
expr.outputs[0]._name = name
active_module_tracer().current_scope().insert(expr)
return expr.outputs[0]
......@@ -352,8 +411,14 @@ class Constant(Expr):
return (self.value,)
def __repr__(self):
return "%{}: {} = Constant({})".format(
self._id, self.outputs[0], type(self.value)
name = self.name
if name is None:
name = type(self.value)
node_type = "Module"
if isinstance(self.outputs[0], TensorNode):
node_type = "Tensor"
return "%{}:\t{} = Constant({}) -> ({})".format(
self._id, self.outputs[0], name, node_type
)
def __getstate__(self):
......
......@@ -28,8 +28,9 @@ class Node:
expr = None
__total_id = 0
_id = None
_top_graph = None # type: weakref.ReferenceType
_name = None
_format_spec = ""
def __init__(self, expr: "Expr", name: str = None):
self.expr = expr
......@@ -43,10 +44,35 @@ class Node:
Node.__total_id = max(Node.__total_id, self._id) + 1
def __repr__(self):
if self._name is None:
return "%{}".format(self._id)
format_spec = Node._format_spec
return self.__format__(format_spec)
def __format__(self, format_spec: str) -> str:
if format_spec == "" or format_spec is None:
format_spec = Node._format_spec
name = self._name
if name is None:
name = ""
if format_spec in ["i", "p", "ip", "pi"]:
if "p" in format_spec:
graph = self.top_graph
prefix_name = ""
if graph is not None:
prefix_name = graph._name
if graph._prefix_name:
prefix_name = "{}_{}".format(
graph._prefix_name, prefix_name.lstrip("_")
)
if name:
name = "_" + name.lstrip("_")
name = "{}{}".format(prefix_name, name)
if "i" in format_spec:
if name:
name = "_" + name.lstrip("_")
name = "%{}{}".format(self._id, name)
return name
else:
return "%{}".format(self._name)
return name if name else ("%d" % self._id)
@property
def top_graph(self):
......@@ -54,6 +80,12 @@ class Node:
return self._top_graph()
return None
@classmethod
    def set_format_spec(cls, format_spec):
        old_format_spec = cls._format_spec
        cls._format_spec = format_spec
return old_format_spec
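
# Assumed reading of the new format specs handled by Node.__format__ above:
# "{:i}" prefixes the node id, "{:p}" prefixes the owning graph's name, and
# "{:ip}" combines both. A standalone mimic for specs "i", "p", "ip"/"pi":
def _format_node(node_id, name, graph_prefix, spec):
    out = name or ""
    if "p" in spec and graph_prefix:
        out = "{}_{}".format(graph_prefix, out.lstrip("_")) if out else graph_prefix
    if "i" in spec:
        out = "%{}_{}".format(node_id, out.lstrip("_")) if out else "%{}".format(node_id)
    return out if out else "%{}".format(node_id)

assert _format_node(3, "conv_out", "block", "i") == "%3_conv_out"
assert _format_node(3, "conv_out", "block", "ip") == "%3_block_conv_out"
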
class ModuleNode(Node):
"""
......@@ -72,12 +104,6 @@ class ModuleNode(Node):
super().__init__(expr, name)
self.actual_mnode = []
def __repr__(self):
if self._name is None:
return "%{}_({})".format(self._id, self.module_type.__name__)
else:
return "%{}_{}({})".format(self._id, self._name, self.module_type.__name__)
def __getstate__(self):
return {
"expr": self.expr,
......@@ -104,12 +130,6 @@ class TensorNode(Node):
qparam = None
device = None
def __repr__(self):
if self._name is None:
return "%{}_(Tensor)".format(self._id)
else:
return "%{}_{}(Tensor)".format(self._id, self._name)
def __getstate__(self):
return {
"expr": self.expr,
......@@ -119,6 +139,7 @@ class TensorNode(Node):
"shape": self.shape,
"dtype": self.dtype,
"device": self.device,
"_name": self._name,
}
......
......@@ -22,6 +22,16 @@ from ...quantization.utils import LSQParams, QParams, QuantMode
from ...tensor import Parameter, Tensor
from .node import ModuleNode, Node, NodeMixin, TensorNode
class ArgsIndex:
def __init__(self, index=0, name="") -> None:
self.index = index
self.name = name
def __repr__(self) -> str:
return self.name
SUPPORTED_TYPE = {}
# if type(obj) or obj itself is in SUPPORTED_LEAF_TYPE, the object can be treated as a leaf node of the pytree
......@@ -39,6 +49,7 @@ SUPPORTED_LEAF_TYPE = {
type(None),
type(Ellipsis),
QuantMode,
ArgsIndex,
}
# if isinstance(obj, SUPPORTED_LEAF_CLS) or issubclass(obj, SUPPORTED_LEAF_CLS), the object can be treated as a leaf node of the pytree
......@@ -121,11 +132,11 @@ def _is_leaf(obj):
def _leaf_type(node):
if isinstance(node, (RawTensor, TensorNode)):
return (Tensor, TensorNode)
return (Tensor, TensorNode, ArgsIndex)
elif isinstance(node, (NodeMixin, Module)):
return (Module, ModuleNode, NodeMixin)
return (Module, ModuleNode, NodeMixin, ArgsIndex)
else:
return type(node)
return (type(node), ArgsIndex)
def _is_const_leaf(node):
......
......@@ -6,12 +6,15 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import builtins
import collections
import copy
import fnmatch
import functools
import inspect
import keyword
import re
import weakref
from inspect import getmembers, isclass, ismethod
from inspect import getcallargs, getmembers, isclass, ismethod
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union
from ... import functional as F
......@@ -41,11 +44,19 @@ from .module_tracer import (
set_active_module_tracer,
)
from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import tree_flatten
from .pytree import ArgsIndex, tree_flatten
logger = get_logger(__name__)
def _is_builtin_name(name: str) -> bool:
return (
name in builtins.__dict__
or name in keyword.kwlist
or name in {"inf", "nan", "NoneType"}
)
def _is_leaf(node):
assert isinstance(node, RawTensor), "doesn't support {} in return values".format(
type(node)
......@@ -67,6 +78,7 @@ class _InsertExprs:
def __init__(self, graph, expr: Optional[Expr] = None, after: bool = True):
self.graph = graph
self.global_scope = InternalGraph()
self.global_scope._used_names.update(graph._used_names)
self.expr = expr
self.after = after
......@@ -91,6 +103,7 @@ class _InsertExprs:
for expr in self.global_scope._exprs:
self.graph._exprs.insert(index, expr)
index += 1
self.graph._used_names.update(self.global_scope._used_names)
class InternalGraph:
......@@ -107,17 +120,37 @@ class InternalGraph:
_inputs = None # type: List[Node]
_outputs = None # type: List[Node]
def __init__(self):
def __init__(self, name: str = None, prefix_name: str = ""):
self._exprs = []
self._inputs = []
self._outputs = []
self._watch_point = []
self._end_point = []
self._used_names = {}
self._rst = collections.defaultdict(list)
self._name = name
self._prefix_name = prefix_name
def insert(self, expr):
self._exprs.append(expr)
def _create_unique_name(self, name: str) -> str:
assert isinstance(name, str)
name = re.sub("[^0-9a-zA-Z_]+", "_", name)
if name[0].isdigit():
name = "_{}".format(name)
while name in self._used_names or _is_builtin_name(name):
match = re.match(r"(.*)_(\d+)$", name)
if match is None:
name = name + "_1"
else:
base, num = match.group(1, 2)
name = "{}_{}".format(base, int(num) + 1)
self._used_names.setdefault(name)
return name
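
# Behavior sketch of _create_unique_name above: sanitize the name into an
# identifier, then bump a numeric suffix until it is free. Standalone mimic:
import re

_used = set()

def _unique(name):
    name = re.sub("[^0-9a-zA-Z_]+", "_", name)
    if name[0].isdigit():
        name = "_" + name
    while name in _used:
        m = re.match(r"(.*)_(\d+)$", name)
        name = name + "_1" if m is None else "{}_{}".format(m.group(1), int(m.group(2)) + 1)
    _used.add(name)
    return name

assert _unique("conv.0") == "conv_0"
assert _unique("conv.0") == "conv_1"  # "conv_0" is taken -> bump the suffix
assert _unique("1x1") == "_1x1"       # a leading digit gets an underscore
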
@property
def inputs(self):
return self._inputs
......@@ -150,13 +183,16 @@ class InternalGraph:
def get_node_by_id(self, node_id: List[int] = None):
return self.node_filter.node_id(node_id)
def get_node_by_name(self, name: str = None, ignorecase: bool = True):
return self.node_filter.name(name, ignorecase)
def add_input(self, i):
self._inputs.append(i)
def add_output(self, o):
self._outputs.append(o)
def _replace_inputs_outputs(self, repl_dict):
def _replace_inputs_outputs_and_add_prefixname(self, repl_dict, prefix_name=""):
for node, repl_node in repl_dict.items():
assert node in self._inputs or node in self._outputs
......@@ -175,13 +211,29 @@ class InternalGraph:
for expr in self._exprs:
for idx, i in enumerate(expr.inputs):
assert i._name is not None
if i in repl_dict:
expr.inputs[idx] = repl_dict[i]
elif isinstance(i, TensorNode) and prefix_name not in i._name:
if i.top_graph != active_module_tracer().current_scope():
i._name = (
active_module_tracer()
.current_scope()
._create_unique_name(prefix_name + i._name.lstrip("_"))
)
for idx, o in enumerate(expr.outputs):
assert o._name is not None
if o in repl_dict:
expr.outputs[idx] = repl_dict[o]
expr.outputs[idx].expr = expr
                elif isinstance(o, TensorNode) and prefix_name not in o._name:
if o.top_graph != active_module_tracer().current_scope():
o._name = (
active_module_tracer()
.current_scope()
._create_unique_name(prefix_name + o._name.lstrip("_"))
)
def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]:
if not isinstance(nodes, Sequence):
......@@ -258,7 +310,7 @@ class InternalGraph:
# return formal_node_inputs[1:], actual_nodes
return formal_node_inputs[1:]
def add_input_node(self, shape, dtype="float32"):
def add_input_node(self, shape, dtype="float32", name="args"):
forma_mnode = self.inputs[0]
actual_mnodes = forma_mnode.actual_mnode
......@@ -271,11 +323,11 @@ class InternalGraph:
if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
call_nodes.append(c_expr)
def create_node(is_input: bool = True):
def create_node(name=None, is_input: bool = True):
if is_input:
node = Input(type=TensorNode).outputs[0]
node = Input(type=TensorNode, name=name).outputs[0]
else:
node = TensorNode(expr=None)
node = TensorNode(expr=None, name=None)
node.shape = shape
node.dtype = dtype
return node
......@@ -286,7 +338,7 @@ class InternalGraph:
org_argdef = call_nodes[0].arg_def
args, kwargs = org_argdef.unflatten(self._inputs)
formal_inp_node = create_node(True)
formal_inp_node = create_node(self._create_unique_name(name), True)
inputs, tree_def = tree_flatten(
((*args, formal_inp_node), kwargs),
is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)),
......@@ -524,11 +576,21 @@ class InternalGraph:
return self.interpret(*inp)
def __repr__(self):
return "InternalGraph ({}) {{\n\t{}\n\treturn {}\n}}".format(
return self.__format__()
def __format__(self, format_spec: str = "") -> str:
saved_format_spec = Node.set_format_spec(format_spec)
name = ""
if self._name:
name = "%s.Graph" % self._name
res = "{} ({}) {{\n\t{}\n\treturn {}\n}}".format(
name,
", ".join(str(i) for i in self._inputs),
"\n\t".join("{}".format(str(i)) for i in self._exprs),
", ".join(str(i) for i in self._outputs),
)
Node.set_format_spec(saved_format_spec)
return res
def _get_meth_name(obj, func):
......@@ -621,6 +683,7 @@ class TracedModuleBuilder(NodeMixin):
self._is_builtin = module_tracer.is_builtin(mod)
self._argdef_graph_map = {}
self._argdef_outdef_map = {}
self.nodes = set()
        # The builder will be passed to self._mod.forward as the 'self' argument. If
        # 'forward' uses super().xxx to call a method of its base classes, tracing will
        # throw an exception, because the builder does not inherit from self._mod.__bases__.
        # Modify self.__class__ so that the builder inherits from both TracedModuleBuilder
        # and mod.__class__.
......@@ -631,7 +694,7 @@ class TracedModuleBuilder(NodeMixin):
)
def build(self):
if self._is_builtin:
if self._is_builtin or isinstance(self._mod, TracedModule):
for node in self.nodes:
node.module_type = type(self._mod)
# node._owner = weakref.ref(self._mod)
......@@ -671,21 +734,38 @@ class TracedModuleBuilder(NodeMixin):
callnode.arg_def = tree_def
if self._is_builtin:
if (
self._is_builtin
or tree_def in self._argdef_graph_map
or isinstance(self._mod, TracedModule)
):
unset_module_tracing()
rst = self._mod(*args, **kwargs)
outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
set_module_tracing()
if self._is_builtin:
self._body = None
elif tree_def in self._argdef_graph_map:
self._body = self._argdef_graph_map[tree_def]
else:
self._mod._is_top = False
self._body = self._mod.graph
name = NodeMixin.get(self)._name
if name:
self._body._name = name
else:
self_node = None
if tree_def in self._argdef_graph_map:
self_node = self._argdef_graph_map[tree_def].inputs[0]
self._body = InternalGraph()
orig_self = NodeMixin.get(self)
top_graph = active_module_tracer().current_scope()
graph_prefix_name = top_graph._name
if top_graph._prefix_name:
graph_prefix_name = "{}_{}".format(
top_graph._prefix_name, graph_prefix_name.lstrip("_")
)
self._body = InternalGraph(orig_self._name, prefix_name=graph_prefix_name)
active_module_tracer().push_scope(self._body)
# rebind self to new input node
orig_self = NodeMixin.get(self)
if self_node:
NodeMixin.wrap_safe(self, self_node)
active_module_tracer().current_scope().add_input(self_node)
......@@ -698,16 +778,37 @@ class TracedModuleBuilder(NodeMixin):
)
origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]]
# prepare args and kwargs for inner graph
def wrap(x):
index_args, index_kwargs = tree_def.unflatten(
[
ArgsIndex(0),
*list(ArgsIndex(i + 1) for i in range(len(origin_inp_node))),
]
)
key2idx = getcallargs(type(self._mod).forward, *index_args, **index_kwargs)
idx2key = {}
for k, v in key2idx.items():
if isinstance(v, ArgsIndex):
idx2key[v.index] = k
else:
flatten_argidx, _ = tree_flatten(v)
for _i, v in enumerate(flatten_argidx):
if isinstance(v, ArgsIndex):
idx2key[v.index] = k + "_%d" % _i
def wrap(x, name):
if isinstance(x, (RawTensor, NodeMixin)):
NodeMixin.wrap(
x, lambda: Input.make(type=NodeMixin.get_wrapped_type(x)),
x,
lambda: Input.make(
type=NodeMixin.get_wrapped_type(x), name=name
),
)
return x
args = [self]
for i in inputs[1:]:
args.append(wrap(i))
for i, v in enumerate(inputs[1:]):
args.append(wrap(v, idx2key[i + 1]))
args, kwargs = tree_def.unflatten(args)
active_module_tracer().patcher.auto_patch(
getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
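
# The getcallargs trick above recovers, for each flattened input position, the
# name of the forward() parameter it binds to, so Input nodes can be named
# after the signature. Standalone illustration (_Idx stands in for ArgsIndex;
# the class and method here are invented for the example):
import inspect

class _M:
    def forward(self, x, y=None):
        pass

class _Idx:
    def __init__(self, index):
        self.index = index

key2idx = inspect.getcallargs(_M.forward, _Idx(0), _Idx(1), _Idx(2))
idx2key = {v.index: k for k, v in key2idx.items() if isinstance(v, _Idx)}
assert idx2key == {0: "self", 1: "x", 2: "y"}
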
......@@ -857,6 +958,9 @@ class NodeFilter(BaseFilter):
def node_id(self, node_id: List[int]):
return NodeFilterNodeId(self, node_id)
def name(self, name: str, ignorecase: bool = True):
return NodeFilterName(self, name, ignorecase)
class NodeFilterType(NodeFilter):
def __init__(self, expr_iter, owner_type, node_type):
......@@ -887,6 +991,33 @@ class NodeFilterNodeId(NodeFilter):
yield node
class NodeFilterName(NodeFilter):
_re = None
def __init__(self, node_iter, pattern, ignorecase):
super().__init__(node_iter)
self.pattern = pattern
self._re = self.make_re(pattern, ignorecase)
@classmethod
def make_re(cls, pattern, ignorecase=True):
assert isinstance(pattern, str), "bad pattern: {!r}".format(pattern)
assert isinstance(ignorecase, bool)
flags = 0
if ignorecase:
flags |= re.IGNORECASE
return re.compile(fnmatch.translate(pattern), flags=flags)
def __iter__(self):
for i in self._iter:
graph = i.top_graph
name = "{}_{}".format(graph._name, i._name.lstrip("_"))
if graph._prefix_name:
name = "{}_{}".format(graph._prefix_name, name.lstrip("_"))
if self.pattern == name or self._re.match(name):
yield i
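
# NodeFilterName matches fully qualified node names against shell-style
# wildcards via fnmatch. A sketch of the pattern semantics (names invented);
# e.g. graph.get_node_by_name("backbone_*_conv_out") would yield the first node:
import fnmatch
import re

regex = re.compile(fnmatch.translate("backbone_*_conv_out"), flags=re.IGNORECASE)
names = ["backbone_layer1_conv_out", "head_fc_out"]
assert [n for n in names if regex.match(n)] == ["backbone_layer1_conv_out"]
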
class ExprFilterCallFunction(ExprFilter):
def __init__(self, expr_iter, func: Callable = None):
super().__init__(expr_iter)
......@@ -1052,12 +1183,29 @@ class TracedModule(Module):
:return: :class:`TracedModule`
"""
new_module = copy.deepcopy(self)
module2name = {}
assert active_module_tracer() is None
set_active_module_tracer(module_tracer(lambda x: x))
active_module_tracer().push_scope(new_module.graph)
for n, m in new_module.named_modules():
module2name[id(m)] = n
def _flatten_subgraph(graph, module, call=None):
def _flatten_subgraph(
graph: InternalGraph, module: Module, call=None, prefix_name=""
):
if graph is not None and prefix_name and prefix_name[-1] != "_":
prefix_name += "_"
if graph is None:
assert not isinstance(module, TracedModule)
const = Constant(module)
const.outputs[0] = call.inputs[0]
const = Constant(module, "self.%s" % module2name[id(module)])
m_node = call.inputs[0]
if m_node.top_graph != active_module_tracer().current_scope():
m_node._name = (
active_module_tracer()
.current_scope()
._create_unique_name(prefix_name)
)
const.outputs[0] = m_node
const.outputs[0].expr = const
return [const, call]
if call is not None:
......@@ -1083,7 +1231,7 @@ class TracedModule(Module):
continue
repl_dict[out] = call.outputs[ind]
graph._replace_inputs_outputs(repl_dict)
graph._replace_inputs_outputs_and_add_prefixname(repl_dict, prefix_name)
for expr in graph._exprs:
if isinstance(expr, GetAttr):
......@@ -1109,7 +1257,14 @@ class TracedModule(Module):
if hasattr(obj, "argdef_graph_map")
else None
)
exprs.extend(_flatten_subgraph(expr_graph, obj, expr))
exprs.extend(
_flatten_subgraph(
expr_graph,
obj,
expr,
prefix_name + obj_node._name.lstrip("_"),
)
)
else:
# module has been replaced.
assert isinstance(pre_expr, Constant)
......@@ -1126,7 +1281,18 @@ class TracedModule(Module):
return exprs
new_module.graph._exprs = _flatten_subgraph(new_module.graph, new_module)
new_module.graph.compile()
set_active_module_tracer(None)
for _id, expr in enumerate(new_module.graph._exprs):
expr._id = _id
total_node_id = 0
for i in new_module.graph._inputs:
i._id = total_node_id
total_node_id += 1
for expr in new_module.graph._exprs:
for o in expr.outputs:
o._id = total_node_id
total_node_id += 1
return new_module
def __getstate__(self):
......@@ -1149,19 +1315,7 @@ def register_as_builtin(mod_cls: Type[Module]) -> None:
module_tracer.register_as_builtin(mod_cls)
def wrap(func: Union[Callable]):
assert callable(func)
if hasattr(func, "__code__"):
assert not isinstance(func, str)
fn_name = func.__code__.co_name
currentframe = inspect.currentframe()
assert currentframe is not None
f = currentframe.f_back
assert f is not None
if f.f_code.co_name != "<module>":
raise NotImplementedError("wrap must be called at the top level of a module")
Patcher._builtin_functions.append((f.f_globals, fn_name))
return func
wrap = _wrapped_function
def _register_all_builtin_module():
......@@ -1192,11 +1346,11 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
set_active_module_tracer(module_tracer(_wrapped_function))
with active_module_tracer().patcher:
global_scope = InternalGraph()
global_scope = InternalGraph(name="")
active_module_tracer().push_scope(global_scope)
builder = TracedModuleBuilder(mod, True)
NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode))
name = mod._name if mod._name else mod.__class__.__name__
NodeMixin.wrap_safe(builder, Input.make(name, ModuleNode))
inputs, _ = tree_flatten((args, kwargs))
for _, i in enumerate(inputs):
# assert isinstance(i, Tensor), "not support "
......