提交 fb20cb36 编写于 作者: M Megvii Engine Team

docs(mge/traced_module): update traced_module api doc

GitOrigin-RevId: 19a95d26c71e672376c5fda00a4e7dc6050e1c6a
上级 c7a8d945
...@@ -130,3 +130,4 @@ import megengine.optimizer ...@@ -130,3 +130,4 @@ import megengine.optimizer
import megengine.quantization import megengine.quantization
import megengine.random import megengine.random
import megengine.utils import megengine.utils
import megengine.traced_module
...@@ -33,15 +33,22 @@ def rstrip(s: str, __chars: str): ...@@ -33,15 +33,22 @@ def rstrip(s: str, __chars: str):
class Expr: class Expr:
"""``Expr`` represents the operations(i.e. CallMethod, CallFunction, Apply, GetAttr, Input, Constant) on ``Node``.""" r"""``Expr`` represents the operations (i.e. ``CallMethod``, ``CallFunction``, ``Apply``,
``GetAttr``, ``Input``, ``Constant``) on ``Node``.
"""
__total_id = 0
inputs = None # type: List[Node] inputs = None # type: List[Node]
r"""The input Nodes of this Expr."""
outputs = None # type: List[Node] outputs = None # type: List[Node]
r"""The output Nodes of this Expr."""
const_val = None # type: List[Any] const_val = None # type: List[Any]
r"""The non-tensor object in the input of the operation."""
arg_def = None # type: TreeDef arg_def = None # type: TreeDef
r"""The :class:`TreeDef` used to reconstruct the input of the operation."""
out_def = None # type: TreeDef out_def = None # type: TreeDef
r"""The :class:`TreeDef` used to reconstruct the output of the operation."""
_top_graph = None # type: weakref.ReferenceType _top_graph = None # type: weakref.ReferenceType
__total_id = 0
def __init__(self) -> None: def __init__(self) -> None:
self._id = Expr.__total_id self._id = Expr.__total_id
...@@ -125,6 +132,11 @@ class Expr: ...@@ -125,6 +132,11 @@ class Expr:
return inputs, {} return inputs, {}
def replace_inputs(self, repl_dict: Dict[Node, Node]): def replace_inputs(self, repl_dict: Dict[Node, Node]):
r"""Replace the input Nodes of this Expr.
Args:
repl_dict: the map {old_Node: new_Node} that specifies how to replace the input Nodes.
"""
while repl_dict: while repl_dict:
node, repl_node = repl_dict.popitem() node, repl_node = repl_dict.popitem()
assert type(node) == type(repl_node) assert type(node) == type(repl_node)
...@@ -147,16 +159,19 @@ class Expr: ...@@ -147,16 +159,19 @@ class Expr:
@property @property
def kwargs(self): def kwargs(self):
r"""Get the the keyword arguments of the operation corresponding to this Expr."""
_, kwargs = self.unflatten_args(self.inputs) _, kwargs = self.unflatten_args(self.inputs)
return kwargs return kwargs
@property @property
def args(self): def args(self):
r"""Get the the positional arguments of the operation corresponding to this Expr."""
args, _ = self.unflatten_args(self.inputs) args, _ = self.unflatten_args(self.inputs)
return args return args
@property @property
def top_graph(self): def top_graph(self):
r"""Get the parent graph of this Expr."""
if self._top_graph: if self._top_graph:
return self._top_graph() return self._top_graph()
return None return None
...@@ -168,17 +183,18 @@ class Expr: ...@@ -168,17 +183,18 @@ class Expr:
return state return state
@classmethod @classmethod
def get_total_id(cls): def _get_next_id(cls):
return cls.__total_id return cls.__total_id
@classmethod @classmethod
def set_total_id(cls, id: int = 0): def _set_next_id(cls, id: int = 0):
assert isinstance(id, int) assert isinstance(id, int)
cls.__total_id = id cls.__total_id = id
# expr: None (i.e. fake expression which is used to mark input) # expr: None (i.e. fake expression which is used to mark input)
class Input(Expr): class Input(Expr):
r"""A fake Expr which is used to mark the input of graph."""
name = None name = None
def __init__(self, name=None, type=None, orig_name=None): def __init__(self, name=None, type=None, orig_name=None):
...@@ -204,13 +220,15 @@ class Input(Expr): ...@@ -204,13 +220,15 @@ class Input(Expr):
return expr.outputs[0] return expr.outputs[0]
def __repr__(self): def __repr__(self):
return "%{}:\t{} = Input({})".format(self._id, self.outputs[0], self.name) return "%{}:\t{} = Input()".format(self._id, self.outputs[0])
# expr: outputs = getattr(inputs[0], self.name) # expr: outputs = getattr(inputs[0], self.name)
class GetAttr(Expr): class GetAttr(Expr):
name = None r"""``Getattr`` represents the fetch of an attribute from the ``Module`` hierarchy."""
name = None
r"""name: the qualified name of the attribute to be retrieved."""
def __init__(self, module, name, type=None, orig_name=None): def __init__(self, module, name, type=None, orig_name=None):
super().__init__() super().__init__()
assert isinstance(module, ModuleNode) assert isinstance(module, ModuleNode)
...@@ -251,6 +269,13 @@ class GetAttr(Expr): ...@@ -251,6 +269,13 @@ class GetAttr(Expr):
# expr: outputs = inputs[0].__call__(*inputs[1:]) # expr: outputs = inputs[0].__call__(*inputs[1:])
class CallMethod(Expr): class CallMethod(Expr):
r"""``CallMethod`` represents a call to the ``__call__`` method of ``Module`` or a method of ``Tensor``.
Args:
node: the Node to be called.
method: the method name.
Default: "__call__"
"""
def __init__(self, node, method="__call__"): def __init__(self, node, method="__call__"):
super().__init__() super().__init__()
if isinstance(node, type): if isinstance(node, type):
...@@ -320,8 +345,12 @@ class CallMethod(Expr): ...@@ -320,8 +345,12 @@ class CallMethod(Expr):
# expr: outputs = apply(self.opdef, *inputs) # expr: outputs = apply(self.opdef, *inputs)
class Apply(Expr): class Apply(Expr):
opdef = None r"""``Apply`` represents a call to :func:`apply`.
Args:
opdef: the applied :class:`OpDef`.
"""
opdef = None
def __init__(self, opdef): def __init__(self, opdef):
super().__init__() super().__init__()
assert isinstance(opdef, OpDef) assert isinstance(opdef, OpDef)
...@@ -388,6 +417,11 @@ class Apply(Expr): ...@@ -388,6 +417,11 @@ class Apply(Expr):
class CallFunction(Expr): class CallFunction(Expr):
r"""``CallFunction`` represents a call to a built-in function.
Args:
func: a built-in function.
"""
def __init__(self, func): def __init__(self, func):
super().__init__() super().__init__()
assert isinstance(func, Callable) assert isinstance(func, Callable)
...@@ -425,7 +459,14 @@ class CallFunction(Expr): ...@@ -425,7 +459,14 @@ class CallFunction(Expr):
# expr outputs = self.value # expr outputs = self.value
class Constant(Expr): class Constant(Expr):
r"""``Constant`` represents a ``Tensor`` or "Module" which is not the attribute of a Module.
Args:
c: a const Tensor or Module.
name: the name of output Node.
"""
value = None value = None
r"""The const Tensor or Module"""
# TODO: constant cache to reduce the size of dumped model # TODO: constant cache to reduce the size of dumped model
_constant_cache = {} _constant_cache = {}
......
...@@ -15,6 +15,8 @@ from ..quantization.utils import QParams, QuantMode, fake_quant_tensor ...@@ -15,6 +15,8 @@ from ..quantization.utils import QParams, QuantMode, fake_quant_tensor
class FakeQuantize(_FakeQuantize, QParamsModuleMixin): class FakeQuantize(_FakeQuantize, QParamsModuleMixin):
r"""A module to do quant and dequant according to :attr:`~.FakeQuantize.qparams`."""
def __init__( def __init__(
self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs
): ):
...@@ -35,9 +37,10 @@ class FakeQuantize(_FakeQuantize, QParamsModuleMixin): ...@@ -35,9 +37,10 @@ class FakeQuantize(_FakeQuantize, QParamsModuleMixin):
return self.qparams return self.qparams
def set_qparams(self, qparams: QParams): def set_qparams(self, qparams: QParams):
r""" r"""Initialize :attr:`~.FakeQuantize.qparams`.
Args: Args:
qparams: used to set initial scale. qparams: used to set initial ``scale`` and ``zero_point``.
""" """
if qparams.scale is None: if qparams.scale is None:
raise AssertionError("Can not get an initialized scale") raise AssertionError("Can not get an initialized scale")
......
...@@ -11,29 +11,29 @@ from typing import Any, Dict, List, Tuple, Type ...@@ -11,29 +11,29 @@ from typing import Any, Dict, List, Tuple, Type
import numpy import numpy
from .. import get_logger
from ..core._imperative_rt.core2 import Tensor as RawTensor from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..module import Module from ..module import Module
from ..tensor import Tensor from ..tensor import Tensor
logger = get_logger(__name__)
class Node:
r"""``Node`` represents the variables (Tensor/Module/other python object) used in Module's forward method.
They are inputs/outputs of Expr(the operations on variables).
Args: class Node:
expr: the Expr which produces the node r"""``Node`` represents the variables (``Tensor``, ``Module``) used in Module's forward method.
name: the name of the node They are inputs/outputs of Expr (the operations on variables).
""" """
expr = None expr = None # type: Expr
__total_id = 0 r"""The Expr which produces the Node."""
_id = None __total_id = 0 # type: int
_id = None # type: int
_top_graph = None # type: weakref.ReferenceType _top_graph = None # type: weakref.ReferenceType
_name = None _name = None # type: str
_orig_name = None _orig_name = None # type: str
_format_spec = "" _format_spec = "" # type: str
def __init__(self, expr: "Expr", name: str = None, orig_name: str = None): def __init__(self, expr: "Expr", name: str, orig_name: str):
self.expr = expr self.expr = expr
self.users = [] # List[Expr] self.users = [] # List[Expr]
self._id = Node.__total_id self._id = Node.__total_id
...@@ -73,24 +73,42 @@ class Node: ...@@ -73,24 +73,42 @@ class Node:
else: else:
return name if name else ("%d" % self._id) return name if name else ("%d" % self._id)
@property
def name(self):
r"""Return the name of this Node."""
return self._name
@name.setter
def name(self, new_name: str):
graph = self.top_graph
assert graph is not None, "The parent graph of this Node cannot be None."
assert new_name not in graph._used_names, (
"The name(%s) is already in use. Please try a different one again."
% (new_name)
)
new_name = graph._create_unique_name(new_name)
self._name = new_name
self._orig_name = new_name
@property @property
def top_graph(self): def top_graph(self):
r"""Get the parent graph of this Node."""
if self._top_graph: if self._top_graph:
return self._top_graph() return self._top_graph()
return None return None
@classmethod @classmethod
def set_format_spec(cls, str): def _set_format_spec(cls, str):
old_format_spec = cls._format_spec old_format_spec = cls._format_spec
cls._format_spec = str cls._format_spec = str
return old_format_spec return old_format_spec
@classmethod @classmethod
def get_total_id(cls): def _get_next_id(cls):
return cls.__total_id return cls.__total_id
@classmethod @classmethod
def set_total_id(cls, id: int = 0): def _set_next_id(cls, id: int = 0):
assert isinstance(id, int) assert isinstance(id, int)
cls.__total_id = id cls.__total_id = id
...@@ -99,6 +117,7 @@ class ModuleNode(Node): ...@@ -99,6 +117,7 @@ class ModuleNode(Node):
r"""``ModuleNode`` represents the Module objects.""" r"""``ModuleNode`` represents the Module objects."""
module_type = Module # type: Type[Module] module_type = Module # type: Type[Module]
r"""The type of the Module correspending to the ModuleNode."""
_owner = None # type: weakref.ReferenceType _owner = None # type: weakref.ReferenceType
def __init__(self, expr: "Expr", name: str = None, orig_name: str = None): def __init__(self, expr: "Expr", name: str = None, orig_name: str = None):
...@@ -116,6 +135,11 @@ class ModuleNode(Node): ...@@ -116,6 +135,11 @@ class ModuleNode(Node):
@property @property
def owner(self): def owner(self):
r"""Get the ``Module`` corresponding to this ``ModuleNode``.
Returns:
An :calss:`~.Module`.
"""
if self._owner: if self._owner:
return self._owner() return self._owner()
return None return None
...@@ -145,6 +169,7 @@ class TensorNode(Node): ...@@ -145,6 +169,7 @@ class TensorNode(Node):
@property @property
def shape(self): def shape(self):
r"""Get the shape of this Node."""
return self._shape return self._shape
@shape.setter @shape.setter
...@@ -153,6 +178,7 @@ class TensorNode(Node): ...@@ -153,6 +178,7 @@ class TensorNode(Node):
@property @property
def dtype(self): def dtype(self):
r"""Get the dtype of this Node."""
return self._dtype return self._dtype
@dtype.setter @dtype.setter
...@@ -161,6 +187,7 @@ class TensorNode(Node): ...@@ -161,6 +187,7 @@ class TensorNode(Node):
@property @property
def device(self): def device(self):
r"""Get the device of this Node pointed Tensor."""
return self._device return self._device
@device.setter @device.setter
...@@ -169,6 +196,7 @@ class TensorNode(Node): ...@@ -169,6 +196,7 @@ class TensorNode(Node):
@property @property
def qparams(self): def qparams(self):
r"""Get the :calss:`QParams` of this Node."""
return self._qparams return self._qparams
@qparams.setter @qparams.setter
...@@ -177,10 +205,16 @@ class TensorNode(Node): ...@@ -177,10 +205,16 @@ class TensorNode(Node):
@property @property
def value(self): def value(self):
r"""Get the bound Tensor of this Node."""
return self._value return self._value
@value.setter @value.setter
def value(self, value): def value(self, value):
r"""Bind a Tensor to this Node.
Args:
value: A :class:`Tensor`.
"""
if isinstance(value, RawTensor) and NodeMixin.get(value, None) is not None: if isinstance(value, RawTensor) and NodeMixin.get(value, None) is not None:
setattr(value, "_NodeMixin__node", None) setattr(value, "_NodeMixin__node", None)
self._value = value self._value = value
......
...@@ -150,6 +150,9 @@ def tree_flatten( ...@@ -150,6 +150,9 @@ def tree_flatten(
is_leaf: Callable = _is_leaf, is_leaf: Callable = _is_leaf,
is_const_leaf: Callable = _is_const_leaf, is_const_leaf: Callable = _is_const_leaf,
): ):
r"""Flattens a object into a list of values and a :calss:`TreeDef` that can be used
to reconstruct the object.
"""
if type(values) not in SUPPORTED_TYPE: if type(values) not in SUPPORTED_TYPE:
assert is_leaf(values), values assert is_leaf(values), values
node = LeafDef(leaf_type(values)) node = LeafDef(leaf_type(values))
...@@ -169,6 +172,15 @@ def tree_flatten( ...@@ -169,6 +172,15 @@ def tree_flatten(
class TreeDef: class TreeDef:
r"""A ``TreeDef`` represents the structure of a pytree.
Args:
type: the type of root Node of the pytree.
aux_data: some const data that is useful in unflattening the pytree.
children_defs: ``TreeDef`` for each child of the root Node.
num_leaves: the number of leaves.
"""
def __init__(self, type, aux_data, children_defs): def __init__(self, type, aux_data, children_defs):
self.type = type self.type = type
self.aux_data = aux_data self.aux_data = aux_data
...@@ -176,6 +188,9 @@ class TreeDef: ...@@ -176,6 +188,9 @@ class TreeDef:
self.num_leaves = sum(ch.num_leaves for ch in children_defs) self.num_leaves = sum(ch.num_leaves for ch in children_defs)
def unflatten(self, leaves): def unflatten(self, leaves):
r"""Given a list of values and a ``TreeDef``, builds a object.
This is the inverse operation of ``tree_flatten``.
"""
assert len(leaves) == self.num_leaves assert len(leaves) == self.num_leaves
start = 0 start = 0
children = [] children = []
...@@ -196,13 +211,10 @@ class TreeDef: ...@@ -196,13 +211,10 @@ class TreeDef:
) )
) )
def __lt__(self, other): def __ne__(self, other) -> bool:
return self.__hash__() < other.__hash__() return not self.__eq__(other)
def __gt__(self, other):
return self.__hash__() > other.__hash__()
def __eq__(self, other): def __eq__(self, other) -> bool:
return ( return (
self.type == other.type self.type == other.type
and self.aux_data == other.aux_data and self.aux_data == other.aux_data
...@@ -227,6 +239,9 @@ class LeafDef(TreeDef): ...@@ -227,6 +239,9 @@ class LeafDef(TreeDef):
assert isinstance(leaves[0], self.type), self.type assert isinstance(leaves[0], self.type), self.type
return leaves[0] return leaves[0]
def __ne__(self, other) -> bool:
return not self.__eq__(other)
def __eq__(self, other): def __eq__(self, other):
if isinstance(self.const_val, np.ndarray): if isinstance(self.const_val, np.ndarray):
return self.type == other.type and (self.const_val == other.const_val).all() return self.type == other.type and (self.const_val == other.const_val).all()
......
...@@ -18,7 +18,18 @@ import weakref ...@@ -18,7 +18,18 @@ import weakref
from inspect import getcallargs, getmembers, isclass, ismethod from inspect import getcallargs, getmembers, isclass, ismethod
from itertools import chain from itertools import chain
from types import FunctionType from types import FunctionType
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
)
from megengine import tensor from megengine import tensor
...@@ -261,8 +272,8 @@ class _InsertExprs: ...@@ -261,8 +272,8 @@ class _InsertExprs:
def __enter__(self): def __enter__(self):
self.use_sym_shape = set_symbolic_shape(True) self.use_sym_shape = set_symbolic_shape(True)
node_id, expr_id = self.root_graph._total_ids node_id, expr_id = self.root_graph._total_ids
Node.set_total_id(node_id) Node._set_next_id(node_id)
Expr.set_total_id(expr_id) Expr._set_next_id(expr_id)
set_module_tracing() set_module_tracing()
_set_convert_node_flag(True) _set_convert_node_flag(True)
assert active_module_tracer() is None assert active_module_tracer() is None
...@@ -341,18 +352,53 @@ class _InsertExprs: ...@@ -341,18 +352,53 @@ class _InsertExprs:
insert_index += 1 insert_index += 1
self.graph._used_names.update(self.global_scope._used_names) self.graph._used_names.update(self.global_scope._used_names)
self.root_graph._total_ids = (Node.get_total_id(), Expr.get_total_id()) self.root_graph._total_ids = (Node._get_next_id(), Expr._get_next_id())
self.root_graph.inputs[0].owner._update_ref() self.root_graph.inputs[0].owner._update_ref()
return True return True
class InternalGraph: class InternalGraph:
r"""``InternalGraph`` is a graph consist of ``Node`` and ``Expr``, it is used to represent the execution procedure of Module's forward method. r"""``InternalGraph`` is the main data structure used in the TracedModule.
It is used to represent the execution procedure of Module's forward method.
For example, the following code
.. code-block::
import megengine.random as rand
import megengine.functional as F
import megengine.module as M
import megengine.traced_module as tm
class MyModule(M.Module):
def __init__(self):
super().__init__()
self.param = rand.normal(size=(3, 4))
self.linear = M.Linear(4, 5)
def forward(self, x):
return F.relu(self.linear(x + self.param))
net = MyModule()
inp = F.zeros(shape = (3, 4))
traced_module = tm.trace_module(net, inp)
Will produce the following ``InternalGraph``::
print(traced_module.graph)
Attributes: .. code-block:: text
_exprs: List of Exprs in order of execution
_inputs: Input Nodes of InternalGraph MyModule.Graph (self, x) {
_outputs: Output Nodes of InternalGraph %2: linear = getattr(self, "linear") -> (Linear)
%3: param = getattr(self, "param") -> (Tensor)
%4: add_out = x.__add__(param, )
%5: linear_out = linear(add_out, )
%6: relu_out = nn.relu(linear_out, )
return relu_out
}
""" """
_exprs = None # type: List[Expr] _exprs = None # type: List[Expr]
...@@ -394,44 +440,154 @@ class InternalGraph: ...@@ -394,44 +440,154 @@ class InternalGraph:
return name return name
@property @property
def inputs(self): def inputs(self) -> List[Node]:
r"""Get the list of input Nodes of this graph.
Returns:
A list of ``Node``.
"""
return self._inputs return self._inputs
@property @property
def outputs(self): def outputs(self) -> List[Node]:
r"""Get the list of output Nodes of this graph.
Returns:
A list of Node.
"""
return self._outputs return self._outputs
@property @property
def top_graph(self): def top_graph(self):
r"""Get the parent graph of this graph.
Returns:
An ``InternalGraph``.
"""
if self._top_graph: if self._top_graph:
return self._top_graph() return self._top_graph()
return None return None
def exprs(self, recursive=True): def exprs(self, recursive=True):
r"""Get the Exprs that constitute this graph.
Args:
recursive: whether to get the Exprs in the subgraph.
Default: True
Returns:
A ``ExprFilter`` containing all Exprs of this graph.
"""
return ExprFilter(_expr_iter(self, recursive)) return ExprFilter(_expr_iter(self, recursive))
def nodes(self, recursive=True): def nodes(self, recursive=True):
r"""Get the Nodes that constitute this graph.
Args:
recursive: whether to get the Nodes in the subgraph.
Default: True
Returns:
A ``NodeFilter`` containing all Nodes of this graph.
"""
return NodeFilter(_node_iter(self, recursive)) return NodeFilter(_node_iter(self, recursive))
def get_function_by_type(self, func: Callable = None, recursive=True): def get_function_by_type(self, func: Callable = None, recursive=True):
r"""Filter Exprs by the type of ``CallFunction``.
Args:
func: a built-in function, such as ``F.relu``.
recursive: whether to get the Exprs in the subgraph.
Default: True
Returns:
A :class:`~.TracedModule.ExprFilterCallFunction`.
"""
return self.exprs(recursive).call_function(func) return self.exprs(recursive).call_function(func)
def get_method_by_type(self, method: str = None, recursive=True): def get_method_by_type(self, method: str = None, recursive=True):
r"""Filter Exprs by the type of ``CallMethod``.
Args:
method: a method string, such as "__add__".
recursive: whether to get the Exprs in the subgraph.
Default: True
Returns:
A :class:`~.TracedModule.ExprFilterCallMethod`.
"""
return self.exprs(recursive).call_method(method) return self.exprs(recursive).call_method(method)
def get_expr_by_id(self, expr_id: List[int] = None, recursive=True): def get_expr_by_id(self, expr_id: List[int] = None, recursive=True):
r"""Filter Exprs by their ``id``.
Args:
expr_id: a list of :class:`int`.
recursive: whether to get the Exprs in the subgraph.
Default: True
Returns:
A :class:`~.TracedModule.ExprFilterExprId`.
"""
return self.exprs(recursive).expr_id(expr_id) return self.exprs(recursive).expr_id(expr_id)
def get_module_by_type(self, module_cls: Module, recursive=True): def get_module_by_type(self, module_cls: Module, recursive=True):
r"""Filter Nodes by the ``module_type`` of ``ModuleNode``.
Args:
module_cls: a subclass of :class:`~.Module`.
recursive: whether to get the Nodes in the subgraph.
Default: True
Returns:
A :class:`~.TracedModule.NodeFilterType`.
"""
assert issubclass(module_cls, Module) assert issubclass(module_cls, Module)
return self.nodes(recursive).type(module_cls, ModuleNode) return self.nodes(recursive).type(module_cls)
def get_node_by_id(self, node_id: List[int] = None, recursive=True): def get_node_by_id(self, node_id: List[int] = None, recursive=True):
r"""Filter Nodes by their ``id``.
The ``id`` of the ``Node`` can be obtained by the following code
.. code-block::
# node : Node
print("{:i}".format(node))
print(node.__format__("i"))
# graph : InternalGraph
print("{:i}".format(graph))
print(graph.__format__("i"))
Args:
node_id: a list of :class:`int`.
recursive: whether to get the Nodes in the subgraph.
Default: True
Returns:
A :class:`~.TracedModule.NodeFilterNodeId`.
"""
return self.nodes(recursive).node_id(node_id) return self.nodes(recursive).node_id(node_id)
def get_node_by_name( def get_node_by_name(
self, name: str = None, ignorecase: bool = True, recursive=True self, name: str = None, ignorecase: bool = True, recursive=True
): ):
r"""Filter Nodes by their full name.
The full name of the ``Node`` can be obtained by the following code
.. code-block::
# node : Node
print("{:p}".format(node))
print(node.__format__("p"))
# graph : InternalGraph
print("{:p}".format(graph))
print(graph.__format__("p"))
Args:
name: a string in glob syntax that can contain ``?`` and
``*`` to match a single or arbitrary characters.
ignorecase: whether to ignroe case.
Default: True
recursive: whether to get the Nodes in the subgraph.
Default: True
Returns:
A :class:`~.TracedModule.NodeFilterName`.
"""
return self.nodes(recursive).name(name, ignorecase) return self.nodes(recursive).name(name, ignorecase)
def _add_input(self, i): def _add_input(self, i):
...@@ -490,6 +646,13 @@ class InternalGraph: ...@@ -490,6 +646,13 @@ class InternalGraph:
o._orig_name = "{}{}".format(module_name, o._orig_name) o._orig_name = "{}{}".format(module_name, o._orig_name)
def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]: def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]:
r"""Get the dependent Exprs of the ``nodes``.
Args:
nodes: a list of :class:`Node`.
Returns:
A list of dependent :class:`Expr`.
"""
if not isinstance(nodes, Sequence): if not isinstance(nodes, Sequence):
nodes = (nodes,) nodes = (nodes,)
ret = list() ret = list()
...@@ -560,11 +723,22 @@ class InternalGraph: ...@@ -560,11 +723,22 @@ class InternalGraph:
self._inputs[:] = formal_node_inputs self._inputs[:] = formal_node_inputs
moudle.argdef_graph_map[tree_def] = moudle.argdef_graph_map.pop(org_argdef) moudle.argdef_graph_map[tree_def] = moudle.argdef_graph_map.pop(org_argdef)
moudle.argdef_outdef_map[tree_def] = moudle.argdef_outdef_map.pop(org_argdef) moudle.argdef_outdef_map[tree_def] = moudle.argdef_outdef_map.pop(org_argdef)
# return formal_node_inputs[1:], actual_nodes
return formal_node_inputs[1:] return formal_node_inputs[1:]
def add_input_node(self, shape, dtype="float32", name="args"): def add_input_node(
self, shape: Tuple[int], dtype: str = "float32", name: str = "args"
):
r"""Add an input node to the graph.
The new Node will be the last of the positional arguments.
Args:
shape: the shape of the new input Node.
dtype: the dtype of the new input Node.
Default: float32
name: the name of the new input Node. When the name is used in the graph,
a suffix will be added to it.
"""
forma_mnode = self.inputs[0] forma_mnode = self.inputs[0]
actual_mnodes = forma_mnode.actual_node actual_mnodes = forma_mnode.actual_node
...@@ -613,18 +787,63 @@ class InternalGraph: ...@@ -613,18 +787,63 @@ class InternalGraph:
moudle.argdef_graph_map[tree_def] = moudle.argdef_graph_map.pop(org_argdef) moudle.argdef_graph_map[tree_def] = moudle.argdef_graph_map.pop(org_argdef)
moudle.argdef_outdef_map[tree_def] = moudle.argdef_outdef_map.pop(org_argdef) moudle.argdef_outdef_map[tree_def] = moudle.argdef_outdef_map.pop(org_argdef)
# return formal_inp_node, actual_inp_nodes
return formal_inp_node return formal_inp_node
def reset_outputs(self, outputs): def reset_outputs(self, outputs):
r"""Reset the output Nodes of the graph.
.. note::
This method only supports resetting the output of graphs
that do not have a parent graph.
Args:
outputs: an object which inner element is Node. Support tuple, list
dict, etc.
For example, the following code
.. code-block::
import megengine.functional as F
import megengine.module as M
import megengine.traced_module as tm
class MyModule(M.Module):
def forward(self, x):
x = x + 1
return x
net = MyModule()
inp = F.zeros(shape = (1, ))
traced_module = tm.trace_module(net, inp)
graph = traced_module.graph
inp_node = graph.inputs[1]
out_node = graph.outputs[0]
graph.reset_outputs((out_node, {"input": inp_node}))
out = traced_module(inp)
Will produce the following ``InternalGraph`` and ``out``::
print(graph)
print(out)
.. code-block:: text
MyModule.Graph (self, x) {
%2: add_out = x.__add__(1, )
return add_out, x
}
(Tensor([1.], device=xpux:0), {'input': Tensor([0.], device=xpux:0)})
"""
outputs, out_def = tree_flatten( outputs, out_def = tree_flatten(
outputs, is_leaf=lambda x: isinstance(x, TensorNode), outputs, is_leaf=lambda x: isinstance(x, TensorNode),
) )
forma_mnode = self.inputs[0] forma_mnode = self.inputs[0]
moudle = forma_mnode.owner moudle = forma_mnode.owner
assert moudle._is_top, "reset_outputs only support the top-level graph" assert moudle._is_top, "reset_outputs only support the top graph"
actual_mnodes = forma_mnode.actual_node actual_mnodes = forma_mnode.actual_node
call_nodes = [] call_nodes = []
...@@ -657,10 +876,53 @@ class InternalGraph: ...@@ -657,10 +876,53 @@ class InternalGraph:
return actual_nodes return actual_nodes
def add_output_node(self, node: TensorNode): def add_output_node(self, node: TensorNode):
r"""Add an output node to the Graph.
The Graph output will become a ``tuple`` after calling ``add_output_node``.
The first element of the ``tuple`` is the original output, and the second
is the ``node``.
For example, the following code
.. code-block::
import megengine.functional as F
import megengine.module as M
import megengine.traced_module as tm
class MyModule(M.Module):
def forward(self, x):
x = x + 1
return x
net = MyModule()
inp = F.zeros(shape = (1, ))
traced_module = tm.trace_module(net, inp)
graph = traced_module.graph
inp_node = graph.inputs[1]
out_node = graph.outputs[0]
graph.add_output_node(inp_node)
graph.add_output_node(out_node)
out = traced_module(inp)
Will produce the following ``InternalGraph`` and ``out``::
print(graph)
print(out)
.. code-block:: text
MyModule.Graph (self, x) {
%2: add_out = x.__add__(1, )
return add_out, x, add_out
}
((Tensor([1.], device=xpux:0), Tensor([0.], device=xpux:0)), Tensor([1.], device=xpux:0))
"""
forma_mnode = self.inputs[0] forma_mnode = self.inputs[0]
moudle = forma_mnode.owner moudle = forma_mnode.owner
assert moudle._is_top, "add_output_node only support the top-level graph" assert moudle._is_top, "add_output_node only support the top graph"
actual_mnodes = forma_mnode.actual_node actual_mnodes = forma_mnode.actual_node
call_nodes = [] call_nodes = []
...@@ -703,11 +965,33 @@ class InternalGraph: ...@@ -703,11 +965,33 @@ class InternalGraph:
return actual_out_nodes return actual_out_nodes
def insert_exprs(self, expr: Optional[Expr] = None): def insert_exprs(self, expr: Optional[Expr] = None):
r"""Initialize the trace mode and insertion position.
When used within a 'with' statement, this will temporary set the trace mode and
then restore normal mode when the with statement exits::
with graph.insert_exprs(e): # set the trace mode
... # trace function or module
... # inert exprs into graph and resotre normal mode
Args:
expr: the ``expr`` after which to insert. If None, the insertion position will be
automatically set based on the input node.
Returns:
A resource manager that will initialize trace mode on ``__enter__`` and
restore normal mode on ``__exit__``.
"""
if expr is not None: if expr is not None:
assert expr.top_graph == self, "Expr to insert after is not in graph." assert expr.top_graph == self, "Expr to insert after is not in graph."
return _InsertExprs(self, expr) return _InsertExprs(self, expr)
def replace_node(self, repl_dict: Dict[Node, Node]): def replace_node(self, repl_dict: Dict[Node, Node]):
r"""Replace the Nodes in the graph.
Args:
repl_dict: the map {old_Node: new_Node} that specifies how to replace the Nodes.
"""
while repl_dict: while repl_dict:
node, repl_node = repl_dict.popitem() node, repl_node = repl_dict.popitem()
assert type(node) == type( assert type(node) == type(
...@@ -746,7 +1030,7 @@ class InternalGraph: ...@@ -746,7 +1030,7 @@ class InternalGraph:
n.inputs[idx] = repl_node n.inputs[idx] = repl_node
def compile(self): def compile(self):
"""Delete unused expr.""" r"""Delete unused expr."""
dep_exprs = self.get_dep_exprs(self.outputs) dep_exprs = self.get_dep_exprs(self.outputs)
i = 0 i = 0
while i < len(self._exprs): while i < len(self._exprs):
...@@ -804,7 +1088,12 @@ class InternalGraph: ...@@ -804,7 +1088,12 @@ class InternalGraph:
return list(node2value[i][0] for i in self._outputs) return list(node2value[i][0] for i in self._outputs)
def eval(self, *inputs): def eval(self, *inputs: Tuple[Tensor]):
r"""Call this method to execute the graph.
Args:
inputs: the tensors corresponding to the ``graph.inputs[1:]``.
"""
assert len(inputs) == len(self._inputs) - 1 assert len(inputs) == len(self._inputs) - 1
inp = [self._inputs[0].owner] + list(inputs) inp = [self._inputs[0].owner] + list(inputs)
return self.interpret(*inp) return self.interpret(*inp)
...@@ -813,7 +1102,7 @@ class InternalGraph: ...@@ -813,7 +1102,7 @@ class InternalGraph:
return self.__format__() return self.__format__()
def __format__(self, format_spec: str = "") -> str: def __format__(self, format_spec: str = "") -> str:
saved_format_spec = Node.set_format_spec(format_spec) saved_format_spec = Node._set_format_spec(format_spec)
name = "" name = ""
if self._name: if self._name:
name = "%s.Graph" % self._name name = "%s.Graph" % self._name
...@@ -823,7 +1112,7 @@ class InternalGraph: ...@@ -823,7 +1112,7 @@ class InternalGraph:
"\n\t".join("{}".format(str(i)) for i in self._exprs), "\n\t".join("{}".format(str(i)) for i in self._exprs),
", ".join(str(i) for i in self._outputs), ", ".join(str(i) for i in self._outputs),
) )
Node.set_format_spec(saved_format_spec) Node._set_format_spec(saved_format_spec)
return res return res
def __getstate__(self): def __getstate__(self):
...@@ -1010,7 +1299,7 @@ class TracedModuleBuilder(NodeMixin): ...@@ -1010,7 +1299,7 @@ class TracedModuleBuilder(NodeMixin):
for _, g in self._argdef_graph_map.items(): for _, g in self._argdef_graph_map.items():
g.compile() g.compile()
if self._is_top: if self._is_top:
g._total_ids = (Node.get_total_id(), Expr.get_total_id()) g._total_ids = (Node._get_next_id(), Expr._get_next_id())
for k, v in self.__dict__.items(): for k, v in self.__dict__.items():
if k not in TracedModuleBuilder.__builder_attributes__: if k not in TracedModuleBuilder.__builder_attributes__:
...@@ -1298,59 +1587,106 @@ class _node_iter: ...@@ -1298,59 +1587,106 @@ class _node_iter:
class BaseFilter: class BaseFilter:
def __init__(self, expr_iter: Iterable): r"""``BaseFilter`` exposes some methods for converting ``_node_iter/_expr_iter`` to ``list``, ``dict``, etc."""
self._iter = expr_iter
def __init__(self, iter: Iterable):
self._iter = iter
def __iter__(self): def __iter__(self):
return iter(self._iter) return iter(self._iter)
def as_list(self): def as_list(self):
r"""Consume this iterator and return its content as a list.
Returns:
A list of ``Node`` or ``Expr``.
"""
return list(self) return list(self)
def as_dict(self): def as_dict(self):
r"""Construct an ordered dict to map from ``id`` to objects in this iterator.
Returns:
An :class:`OrderedDict`.
"""
return collections.OrderedDict((i._id, i) for i in self) return collections.OrderedDict((i._id, i) for i in self)
def as_unique(self): def as_unique(self):
"""Assert that this iterator yields only one ``Node`` or ``Expr`` and return it.
Rerurns:
A ``Node`` or ``Expr``.
"""
rst = self.as_list() rst = self.as_list()
assert len(rst) == 1, "{} elements found".format(len(rst)) assert len(rst) == 1, "{} elements found".format(len(rst))
(expr,) = self (elem,) = self
return expr return elem
def as_count(self): def as_count(self):
r"""Consume this iterator and get the number of elements."""
return sum(1 for _ in self) return sum(1 for _ in self)
class ExprFilter(BaseFilter): class ExprFilter(BaseFilter):
"""Filter on Expr iterator.
This class is an iterator of :class:`.Expr` objects and multiple
filtering conditions and mappers can be chained.
"""
def call_function(self, func): def call_function(self, func):
r"""Filter by specific ``CallFunction.func``.
See :meth:`~.InternalGraph.get_function_by_type` for details.
"""
return ExprFilterCallFunction(self, func) return ExprFilterCallFunction(self, func)
def call_method(self, method): def call_method(self, method):
r"""Filter by specific ``CallMethod.method``.
See :meth:`~.InternalGraph.get_function_by_type` for details.
"""
return ExprFilterCallMethod(self, method) return ExprFilterCallMethod(self, method)
def expr_id(self, expr_id: List[int]): def expr_id(self, expr_id: List[int]):
r"""Filter Exprs by their ``id``.
See :meth:`~.InternalGraph.get_function_by_type` for details.
"""
return ExprFilterExprId(self, expr_id) return ExprFilterExprId(self, expr_id)
class NodeFilter(BaseFilter): class NodeFilter(BaseFilter):
def type(self, owner_type, node_type): """Filter on Node iterator.
return NodeFilterType(self, owner_type, node_type) This class is an iterator of :class:`.Node` objects and multiple
filtering conditions and mappers can be chained.
"""
def type(self, owner_type):
r"""Filter by specific Module type.
See :meth:`~.InternalGraph.get_module_by_type` for details.
"""
return NodeFilterType(self, owner_type)
def node_id(self, node_id: List[int]): def node_id(self, node_id: List[int]):
r"""Filter Nodes by their ``id``.
See :meth:`~.InternalGraph.get_node_by_id` for details.
"""
return NodeFilterNodeId(self, node_id) return NodeFilterNodeId(self, node_id)
def name(self, name: str, ignorecase: bool = True): def name(self, name: str, ignorecase: bool = True):
r"""Filter Nodes by their full name.
See :meth:`~.InternalGraph.get_node_by_name` for details.
"""
return NodeFilterName(self, name, ignorecase) return NodeFilterName(self, name, ignorecase)
class NodeFilterType(NodeFilter): class NodeFilterType(NodeFilter):
def __init__(self, expr_iter, owner_type, node_type): """See :meth:`~.InternalGraph.get_module_by_type`"""
def __init__(self, expr_iter, owner_type):
super().__init__(expr_iter) super().__init__(expr_iter)
self.owner_type = owner_type self.owner_type = owner_type
self.node_type = node_type
def __iter__(self): def __iter__(self):
for node in self._iter: for node in self._iter:
if not isinstance(node, self.node_type): if not isinstance(node, ModuleNode):
continue continue
if not hasattr(node, "owner"): if not hasattr(node, "owner"):
continue continue
...@@ -1359,6 +1695,8 @@ class NodeFilterType(NodeFilter): ...@@ -1359,6 +1695,8 @@ class NodeFilterType(NodeFilter):
class NodeFilterNodeId(NodeFilter): class NodeFilterNodeId(NodeFilter):
"""See :meth:`~.InternalGraph.get_node_by_id`"""
def __init__(self, expr_iter, node_id: List[int]): def __init__(self, expr_iter, node_id: List[int]):
super().__init__(expr_iter) super().__init__(expr_iter)
if not isinstance(node_id, Sequence): if not isinstance(node_id, Sequence):
...@@ -1372,6 +1710,8 @@ class NodeFilterNodeId(NodeFilter): ...@@ -1372,6 +1710,8 @@ class NodeFilterNodeId(NodeFilter):
class NodeFilterName(NodeFilter): class NodeFilterName(NodeFilter):
"""See :meth:`~.InternalGraph.get_node_by_name`"""
_re = None _re = None
def __init__(self, node_iter, pattern, ignorecase): def __init__(self, node_iter, pattern, ignorecase):
...@@ -1399,6 +1739,8 @@ class NodeFilterName(NodeFilter): ...@@ -1399,6 +1739,8 @@ class NodeFilterName(NodeFilter):
class ExprFilterCallFunction(ExprFilter): class ExprFilterCallFunction(ExprFilter):
"""See :meth:`~.InternalGraph.get_function_by_type`"""
def __init__(self, expr_iter, func: Callable = None): def __init__(self, expr_iter, func: Callable = None):
super().__init__(expr_iter) super().__init__(expr_iter)
self.func = func self.func = func
...@@ -1412,6 +1754,8 @@ class ExprFilterCallFunction(ExprFilter): ...@@ -1412,6 +1754,8 @@ class ExprFilterCallFunction(ExprFilter):
class ExprFilterCallMethod(ExprFilter): class ExprFilterCallMethod(ExprFilter):
"""See :meth:`~.InternalGraph.get_method_by_type`"""
def __init__(self, expr_iter, method: str = None): def __init__(self, expr_iter, method: str = None):
super().__init__(expr_iter) super().__init__(expr_iter)
self.method = method self.method = method
...@@ -1425,6 +1769,8 @@ class ExprFilterCallMethod(ExprFilter): ...@@ -1425,6 +1769,8 @@ class ExprFilterCallMethod(ExprFilter):
class ExprFilterExprId(ExprFilter): class ExprFilterExprId(ExprFilter):
"""See :meth:`~.InternalGraph.get_expr_by_id`"""
def __init__(self, expr_iter, expr_id: List[int]): def __init__(self, expr_iter, expr_id: List[int]):
super().__init__(expr_iter) super().__init__(expr_iter)
if not isinstance(expr_id, Sequence): if not isinstance(expr_id, Sequence):
...@@ -1438,8 +1784,16 @@ class ExprFilterExprId(ExprFilter): ...@@ -1438,8 +1784,16 @@ class ExprFilterExprId(ExprFilter):
class TracedModule(Module): class TracedModule(Module):
r"""`TracedModule` is the Module created by tracing normal module. It owns an argdef to graph(InternalGraph) map. The forward method of `TracedModule` will get a graph from `argdef_graph_map` according to the argdef of input args/kwargs and interpret it.""" r"""``TracedModule`` is the Module created by tracing normal module.
It owns an argdef to graph(InternalGraph) map. The forward method of ``TracedModule``
will get a graph from ``argdef_graph_map`` according to the argdef of input ``args/kwargs``
and interpret it.
.. note::
``TracedModule`` can only be created by :func:`~.trace_module`. See :func:`~.trace_module`
for more details.
"""
# m_node = None # type: ModuleNode # m_node = None # type: ModuleNode
argdef_graph_map = None argdef_graph_map = None
argdef_outdef_map = None argdef_outdef_map = None
...@@ -1475,19 +1829,97 @@ class TracedModule(Module): ...@@ -1475,19 +1829,97 @@ class TracedModule(Module):
return outputs return outputs
def set_watch_points(self, nodes): def set_watch_points(self, nodes):
r"""Initialize the :attr:`~.TracedModule.watch_points`.
You can call this function to get the ``Tensor/Module`` corresponding to a ``Node`` at runtime.
Args:
nodes: a list of ``Node``.
For example, the following code
.. code-block::
import megengine.module as M
import megengine as mge
import megengine.traced_module as tm
class MyModule(M.Module):
def forward(self, x):
x = x + 1 + 2
return x
net = MyModule()
inp = mge.Tensor([0])
traced_module = tm.trace_module(net, inp)
add_1_node = traced_module.graph.get_node_by_id(2).as_unique()
traced_module.set_watch_points(add_1_node)
out = traced_module(inp)
Will get the following ``watch_node_value``::
print(traced_module.watch_node_value)
.. code-block:: text
{add_out: Tensor([1.], device=xpux:0)}
"""
if not isinstance(nodes, Sequence): if not isinstance(nodes, Sequence):
nodes = [nodes] nodes = [nodes]
self.watch_points = nodes self.watch_points = nodes
if nodes:
nodes[0].top_graph._watch_point = []
for n in nodes: for n in nodes:
n.top_graph._watch_point.append(n) n.top_graph._watch_point.append(n)
def clear_watch_points(self): def clear_watch_points(self):
r"""Clear the :attr:`~.TracedModule.watch_points` and :attr:`~.TracedModule.watch_node_value`.
"""
for n in self.watch_points: for n in self.watch_points:
n.top_graph._watch_point = [] n.top_graph._watch_point = []
self.watch_points = [] self.watch_points = []
self.watch_node_value = {} self.watch_node_value = {}
def set_end_points(self, nodes): def set_end_points(self, nodes: Sequence[Node]):
r"""Initialize the :attr:`~.TracedModule.end_points`.
When all the ``nodes`` are generated, the Module will stop execution and return directly.
Args:
nodes: a list of ``Node``.
For example, the following code
.. code-block::
import megengine.module as M
import megengine as mge
import megengine.traced_module as tm
class MyModule(M.Module):
def forward(self, x):
x = x + 1 + 2
return x
net = MyModule()
inp = mge.Tensor([0])
traced_module = tm.trace_module(net, inp)
add_1_node = traced_module.graph.get_node_by_id(2).as_unique()
traced_module.set_end_points(add_1_node)
out = traced_module(inp)
Will get the following ``out``::
print(out)
.. code-block:: text
[Tensor([1.], device=xpux:0)]
"""
if not isinstance(nodes, Sequence): if not isinstance(nodes, Sequence):
nodes = [nodes] nodes = [nodes]
self.end_points = nodes self.end_points = nodes
...@@ -1497,12 +1929,16 @@ class TracedModule(Module): ...@@ -1497,12 +1929,16 @@ class TracedModule(Module):
n.top_graph._end_point.append(n) n.top_graph._end_point.append(n)
def clear_end_points(self): def clear_end_points(self):
r"""Clear the :attr:`~.TracedModule.end_points`.
"""
for n in self.end_points: for n in self.end_points:
n.top_graph._end_point = [] n.top_graph._end_point = []
self.end_points = [] self.end_points = []
@property @property
def graph(self) -> InternalGraph: def graph(self) -> InternalGraph:
"""Return the ``InternalGraph`` of this ``TracedModule``
"""
if self._is_top: if self._is_top:
self._update_ref() self._update_ref()
assert len(self.argdef_graph_map) == 1 assert len(self.argdef_graph_map) == 1
...@@ -1559,9 +1995,10 @@ class TracedModule(Module): ...@@ -1559,9 +1995,10 @@ class TracedModule(Module):
obj._update_ref(mnode_map, graph) obj._update_ref(mnode_map, graph)
def flatten(self): def flatten(self):
r"""Get a new module, which eliminates ``GetAttr`` and has no hierarchy. r"""Get a new TracedModule, which eliminates ``GetAttr`` and has no hierarchy.
:return: :class:`TracedModule` Retruns:
A new :class:`TracedModule`.
""" """
new_module = copy.deepcopy(self) new_module = copy.deepcopy(self)
assert active_module_tracer() is None assert active_module_tracer() is None
...@@ -1690,16 +2127,35 @@ def cpp_apply_module_trace(opdef, *args): ...@@ -1690,16 +2127,35 @@ def cpp_apply_module_trace(opdef, *args):
def register_as_builtin(mod_cls: Type[Module]) -> None: def register_as_builtin(mod_cls: Type[Module]) -> None:
r"""Registers class ``mod_cls`` (subclass of megengine.module.Module) as builtin module. r"""Registers class ``mod_cls`` (subclass of :class:`~.Module`) as builtin module.
Args: Args:
mod_cls: the Module class which will be threated as builtin module in tracing mod_cls: the module class which will be treated as builtin module in tracing.
""" """
module_tracer.register_as_builtin(mod_cls) module_tracer.register_as_builtin(mod_cls)
def wrap(func: Callable): def wrap(func: Callable):
r"""Call this function to register func as a builtin function.""" r"""Call this function to register ``func`` as a builtin function.
This function can be called at module-level scope to register ``func`` as a builtin function.
A builtin function will be converted to a :class:`CallFunction` Expr in tracing::
def my_func(x, y):
return x + y
import megengine.traced_module as tm
tm.wrap(my_func)
This function can also equivalently be used as a decorator::
@tm.wrap
def my_func(x, y):
return x + y
Args:
func: the function of the global function to insert into the graph when it's called.
"""
assert callable(func), "func must be a callable" assert callable(func), "func must be a callable"
assert hasattr(func, "__code__") assert hasattr(func, "__code__")
fn_name = func.__code__.co_name fn_name = func.__code__.co_name
...@@ -1739,13 +2195,15 @@ def _register_all_builtin_module(): ...@@ -1739,13 +2195,15 @@ def _register_all_builtin_module():
module_tracer.register_as_builtin(TM_FakeQuant) module_tracer.register_as_builtin(TM_FakeQuant)
def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule: def trace_module(
r"""Traces module ``mod`` and returns corresponding TracedModule. mod: Module, *args: Tuple[Any], **kwargs: Dict[str, Any]
) -> TracedModule:
r"""Traces module ``mod`` and returns corresponding :class:`TracedModule`.
Args: Args:
mod: the module will be converted to TracedModule mod: the module will be converted to :class:`TracedModule`.
input: the positional arguments passed to forward method of ``mod`` args: the positional arguments passed to forward method of ``mod``.
kwargs: the keyword arguments passed to forward method of ``mod`` kwargs: the keyword arguments passed to forward method of ``mod``.
""" """
assert active_module_tracer() is None assert active_module_tracer() is None
assert isinstance(mod, Module) assert isinstance(mod, Module)
...@@ -1756,7 +2214,7 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule: ...@@ -1756,7 +2214,7 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
module_tracer(_wrapped_function, _init_id2name(mod, "self")) module_tracer(_wrapped_function, _init_id2name(mod, "self"))
) )
for cls in [Expr, Node]: for cls in [Expr, Node]:
cls.set_total_id(0) cls._set_next_id(0)
with active_module_tracer().patcher: with active_module_tracer().patcher:
global_scope = InternalGraph(name="") global_scope = InternalGraph(name="")
active_module_tracer().push_scope(global_scope) active_module_tracer().push_scope(global_scope)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册