diff --git a/imperative/python/megengine/__init__.py b/imperative/python/megengine/__init__.py index ef453fc7462d0a0c1b14910380cd4c7541f83da5..3cb920fcbfd2818f1adbbf705d36f4496944c304 100644 --- a/imperative/python/megengine/__init__.py +++ b/imperative/python/megengine/__init__.py @@ -130,3 +130,4 @@ import megengine.optimizer import megengine.quantization import megengine.random import megengine.utils +import megengine.traced_module diff --git a/imperative/python/megengine/traced_module/expr.py b/imperative/python/megengine/traced_module/expr.py index 0247b861d0b6d3364a38d5a4fd70c586192fd3d6..df0b16165a694ac6c1477f092c62b688acff668a 100644 --- a/imperative/python/megengine/traced_module/expr.py +++ b/imperative/python/megengine/traced_module/expr.py @@ -33,15 +33,22 @@ def rstrip(s: str, __chars: str): class Expr: - """``Expr`` represents the operations(i.e. CallMethod, CallFunction, Apply, GetAttr, Input, Constant) on ``Node``.""" - - __total_id = 0 + r"""``Expr`` represents the operations (i.e. ``CallMethod``, ``CallFunction``, ``Apply``, + ``GetAttr``, ``Input``, ``Constant``) on ``Node``. + """ + inputs = None # type: List[Node] + r"""The input Nodes of this Expr.""" outputs = None # type: List[Node] + r"""The output Nodes of this Expr.""" const_val = None # type: List[Any] + r"""The non-tensor object in the input of the operation.""" arg_def = None # type: TreeDef + r"""The :class:`TreeDef` used to reconstruct the input of the operation.""" out_def = None # type: TreeDef + r"""The :class:`TreeDef` used to reconstruct the output of the operation.""" _top_graph = None # type: weakref.ReferenceType + __total_id = 0 def __init__(self) -> None: self._id = Expr.__total_id @@ -125,6 +132,11 @@ class Expr: return inputs, {} def replace_inputs(self, repl_dict: Dict[Node, Node]): + r"""Replace the input Nodes of this Expr. + + Args: + repl_dict: the map {old_Node: new_Node} that specifies how to replace the input Nodes. + """ while repl_dict: node, repl_node = repl_dict.popitem() assert type(node) == type(repl_node) @@ -147,16 +159,19 @@ class Expr: @property def kwargs(self): + r"""Get the the keyword arguments of the operation corresponding to this Expr.""" _, kwargs = self.unflatten_args(self.inputs) return kwargs @property def args(self): + r"""Get the the positional arguments of the operation corresponding to this Expr.""" args, _ = self.unflatten_args(self.inputs) return args @property def top_graph(self): + r"""Get the parent graph of this Expr.""" if self._top_graph: return self._top_graph() return None @@ -168,17 +183,18 @@ class Expr: return state @classmethod - def get_total_id(cls): + def _get_next_id(cls): return cls.__total_id @classmethod - def set_total_id(cls, id: int = 0): + def _set_next_id(cls, id: int = 0): assert isinstance(id, int) cls.__total_id = id # expr: None (i.e. fake expression which is used to mark input) class Input(Expr): + r"""A fake Expr which is used to mark the input of graph.""" name = None def __init__(self, name=None, type=None, orig_name=None): @@ -204,13 +220,15 @@ class Input(Expr): return expr.outputs[0] def __repr__(self): - return "%{}:\t{} = Input({})".format(self._id, self.outputs[0], self.name) + return "%{}:\t{} = Input()".format(self._id, self.outputs[0]) # expr: outputs = getattr(inputs[0], self.name) class GetAttr(Expr): - name = None + r"""``Getattr`` represents the fetch of an attribute from the ``Module`` hierarchy.""" + name = None + r"""name: the qualified name of the attribute to be retrieved.""" def __init__(self, module, name, type=None, orig_name=None): super().__init__() assert isinstance(module, ModuleNode) @@ -251,6 +269,13 @@ class GetAttr(Expr): # expr: outputs = inputs[0].__call__(*inputs[1:]) class CallMethod(Expr): + r"""``CallMethod`` represents a call to the ``__call__`` method of ``Module`` or a method of ``Tensor``. + + Args: + node: the Node to be called. + method: the method name. + Default: "__call__" + """ def __init__(self, node, method="__call__"): super().__init__() if isinstance(node, type): @@ -320,8 +345,12 @@ class CallMethod(Expr): # expr: outputs = apply(self.opdef, *inputs) class Apply(Expr): - opdef = None + r"""``Apply`` represents a call to :func:`apply`. + Args: + opdef: the applied :class:`OpDef`. + """ + opdef = None def __init__(self, opdef): super().__init__() assert isinstance(opdef, OpDef) @@ -388,6 +417,11 @@ class Apply(Expr): class CallFunction(Expr): + r"""``CallFunction`` represents a call to a built-in function. + + Args: + func: a built-in function. + """ def __init__(self, func): super().__init__() assert isinstance(func, Callable) @@ -425,7 +459,14 @@ class CallFunction(Expr): # expr outputs = self.value class Constant(Expr): + r"""``Constant`` represents a ``Tensor`` or "Module" which is not the attribute of a Module. + + Args: + c: a const Tensor or Module. + name: the name of output Node. + """ value = None + r"""The const Tensor or Module""" # TODO: constant cache to reduce the size of dumped model _constant_cache = {} diff --git a/imperative/python/megengine/traced_module/fake_quant.py b/imperative/python/megengine/traced_module/fake_quant.py index 8dd29aa0bbd9a801b5a3978cfb198b352db98c93..0e355f56d0ed997ddcd7b6174df1057575a6c5c9 100644 --- a/imperative/python/megengine/traced_module/fake_quant.py +++ b/imperative/python/megengine/traced_module/fake_quant.py @@ -15,6 +15,8 @@ from ..quantization.utils import QParams, QuantMode, fake_quant_tensor class FakeQuantize(_FakeQuantize, QParamsModuleMixin): + r"""A module to do quant and dequant according to :attr:`~.FakeQuantize.qparams`.""" + def __init__( self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs ): @@ -35,9 +37,10 @@ class FakeQuantize(_FakeQuantize, QParamsModuleMixin): return self.qparams def set_qparams(self, qparams: QParams): - r""" + r"""Initialize :attr:`~.FakeQuantize.qparams`. + Args: - qparams: used to set initial scale. + qparams: used to set initial ``scale`` and ``zero_point``. """ if qparams.scale is None: raise AssertionError("Can not get an initialized scale") diff --git a/imperative/python/megengine/traced_module/node.py b/imperative/python/megengine/traced_module/node.py index 2ae32e70b02d32d2106ef767c29892b1b651981d..056043c7e886114a568b0251e68ac143bd7bf1ef 100644 --- a/imperative/python/megengine/traced_module/node.py +++ b/imperative/python/megengine/traced_module/node.py @@ -11,29 +11,29 @@ from typing import Any, Dict, List, Tuple, Type import numpy +from .. import get_logger from ..core._imperative_rt.core2 import Tensor as RawTensor from ..module import Module from ..tensor import Tensor +logger = get_logger(__name__) -class Node: - r"""``Node`` represents the variables (Tensor/Module/other python object) used in Module's forward method. - They are inputs/outputs of Expr(the operations on variables). - Args: - expr: the Expr which produces the node - name: the name of the node +class Node: + r"""``Node`` represents the variables (``Tensor``, ``Module``) used in Module's forward method. + They are inputs/outputs of Expr (the operations on variables). """ - expr = None - __total_id = 0 - _id = None + expr = None # type: Expr + r"""The Expr which produces the Node.""" + __total_id = 0 # type: int + _id = None # type: int _top_graph = None # type: weakref.ReferenceType - _name = None - _orig_name = None - _format_spec = "" + _name = None # type: str + _orig_name = None # type: str + _format_spec = "" # type: str - def __init__(self, expr: "Expr", name: str = None, orig_name: str = None): + def __init__(self, expr: "Expr", name: str, orig_name: str): self.expr = expr self.users = [] # List[Expr] self._id = Node.__total_id @@ -73,32 +73,51 @@ class Node: else: return name if name else ("%d" % self._id) + @property + def name(self): + r"""Return the name of this Node.""" + return self._name + + @name.setter + def name(self, new_name: str): + graph = self.top_graph + assert graph is not None, "The parent graph of this Node cannot be None." + assert new_name not in graph._used_names, ( + "The name(%s) is already in use. Please try a different one again." + % (new_name) + ) + new_name = graph._create_unique_name(new_name) + self._name = new_name + self._orig_name = new_name + @property def top_graph(self): + r"""Get the parent graph of this Node.""" if self._top_graph: return self._top_graph() return None @classmethod - def set_format_spec(cls, str): + def _set_format_spec(cls, str): old_format_spec = cls._format_spec cls._format_spec = str return old_format_spec @classmethod - def get_total_id(cls): + def _get_next_id(cls): return cls.__total_id @classmethod - def set_total_id(cls, id: int = 0): + def _set_next_id(cls, id: int = 0): assert isinstance(id, int) cls.__total_id = id class ModuleNode(Node): r"""``ModuleNode`` represents the Module objects.""" - + module_type = Module # type: Type[Module] + r"""The type of the Module correspending to the ModuleNode.""" _owner = None # type: weakref.ReferenceType def __init__(self, expr: "Expr", name: str = None, orig_name: str = None): @@ -116,6 +135,11 @@ class ModuleNode(Node): @property def owner(self): + r"""Get the ``Module`` corresponding to this ``ModuleNode``. + + Returns: + An :calss:`~.Module`. + """ if self._owner: return self._owner() return None @@ -145,6 +169,7 @@ class TensorNode(Node): @property def shape(self): + r"""Get the shape of this Node.""" return self._shape @shape.setter @@ -153,6 +178,7 @@ class TensorNode(Node): @property def dtype(self): + r"""Get the dtype of this Node.""" return self._dtype @dtype.setter @@ -161,6 +187,7 @@ class TensorNode(Node): @property def device(self): + r"""Get the device of this Node pointed Tensor.""" return self._device @device.setter @@ -169,6 +196,7 @@ class TensorNode(Node): @property def qparams(self): + r"""Get the :calss:`QParams` of this Node.""" return self._qparams @qparams.setter @@ -177,10 +205,16 @@ class TensorNode(Node): @property def value(self): + r"""Get the bound Tensor of this Node.""" return self._value @value.setter def value(self, value): + r"""Bind a Tensor to this Node. + + Args: + value: A :class:`Tensor`. + """ if isinstance(value, RawTensor) and NodeMixin.get(value, None) is not None: setattr(value, "_NodeMixin__node", None) self._value = value diff --git a/imperative/python/megengine/traced_module/pytree.py b/imperative/python/megengine/traced_module/pytree.py index e4f13e4bcef9116155bb9c39ca977fb55bd118c4..8be744e93bddfced88d2bfccd4a051ecbc519005 100644 --- a/imperative/python/megengine/traced_module/pytree.py +++ b/imperative/python/megengine/traced_module/pytree.py @@ -150,6 +150,9 @@ def tree_flatten( is_leaf: Callable = _is_leaf, is_const_leaf: Callable = _is_const_leaf, ): + r"""Flattens a object into a list of values and a :calss:`TreeDef` that can be used + to reconstruct the object. + """ if type(values) not in SUPPORTED_TYPE: assert is_leaf(values), values node = LeafDef(leaf_type(values)) @@ -169,6 +172,15 @@ def tree_flatten( class TreeDef: + r"""A ``TreeDef`` represents the structure of a pytree. + + Args: + type: the type of root Node of the pytree. + aux_data: some const data that is useful in unflattening the pytree. + children_defs: ``TreeDef`` for each child of the root Node. + num_leaves: the number of leaves. + """ + def __init__(self, type, aux_data, children_defs): self.type = type self.aux_data = aux_data @@ -176,6 +188,9 @@ class TreeDef: self.num_leaves = sum(ch.num_leaves for ch in children_defs) def unflatten(self, leaves): + r"""Given a list of values and a ``TreeDef``, builds a object. + This is the inverse operation of ``tree_flatten``. + """ assert len(leaves) == self.num_leaves start = 0 children = [] @@ -196,13 +211,10 @@ class TreeDef: ) ) - def __lt__(self, other): - return self.__hash__() < other.__hash__() - - def __gt__(self, other): - return self.__hash__() > other.__hash__() + def __ne__(self, other) -> bool: + return not self.__eq__(other) - def __eq__(self, other): + def __eq__(self, other) -> bool: return ( self.type == other.type and self.aux_data == other.aux_data @@ -227,6 +239,9 @@ class LeafDef(TreeDef): assert isinstance(leaves[0], self.type), self.type return leaves[0] + def __ne__(self, other) -> bool: + return not self.__eq__(other) + def __eq__(self, other): if isinstance(self.const_val, np.ndarray): return self.type == other.type and (self.const_val == other.const_val).all() diff --git a/imperative/python/megengine/traced_module/traced_module.py b/imperative/python/megengine/traced_module/traced_module.py index 56b88e871015bbb1b5ffba48f01f101a6b199ec4..c6a0fae3867033ec4a4c2e1ed07d89c9c16b40a7 100644 --- a/imperative/python/megengine/traced_module/traced_module.py +++ b/imperative/python/megengine/traced_module/traced_module.py @@ -18,7 +18,18 @@ import weakref from inspect import getcallargs, getmembers, isclass, ismethod from itertools import chain from types import FunctionType -from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Sequence, + Tuple, + Type, + Union, +) from megengine import tensor @@ -261,8 +272,8 @@ class _InsertExprs: def __enter__(self): self.use_sym_shape = set_symbolic_shape(True) node_id, expr_id = self.root_graph._total_ids - Node.set_total_id(node_id) - Expr.set_total_id(expr_id) + Node._set_next_id(node_id) + Expr._set_next_id(expr_id) set_module_tracing() _set_convert_node_flag(True) assert active_module_tracer() is None @@ -341,18 +352,53 @@ class _InsertExprs: insert_index += 1 self.graph._used_names.update(self.global_scope._used_names) - self.root_graph._total_ids = (Node.get_total_id(), Expr.get_total_id()) + self.root_graph._total_ids = (Node._get_next_id(), Expr._get_next_id()) self.root_graph.inputs[0].owner._update_ref() return True class InternalGraph: - r"""``InternalGraph`` is a graph consist of ``Node`` and ``Expr``, it is used to represent the execution procedure of Module's forward method. + r"""``InternalGraph`` is the main data structure used in the TracedModule. + It is used to represent the execution procedure of Module's forward method. + + For example, the following code + + .. code-block:: + + import megengine.random as rand + import megengine.functional as F + import megengine.module as M + + import megengine.traced_module as tm + + class MyModule(M.Module): + def __init__(self): + super().__init__() + self.param = rand.normal(size=(3, 4)) + self.linear = M.Linear(4, 5) + + def forward(self, x): + return F.relu(self.linear(x + self.param)) - Attributes: - _exprs: List of Exprs in order of execution - _inputs: Input Nodes of InternalGraph - _outputs: Output Nodes of InternalGraph + net = MyModule() + + inp = F.zeros(shape = (3, 4)) + traced_module = tm.trace_module(net, inp) + + Will produce the following ``InternalGraph``:: + + print(traced_module.graph) + + .. code-block:: text + + MyModule.Graph (self, x) { + %2: linear = getattr(self, "linear") -> (Linear) + %3: param = getattr(self, "param") -> (Tensor) + %4: add_out = x.__add__(param, ) + %5: linear_out = linear(add_out, ) + %6: relu_out = nn.relu(linear_out, ) + return relu_out + } """ _exprs = None # type: List[Expr] @@ -394,44 +440,154 @@ class InternalGraph: return name @property - def inputs(self): + def inputs(self) -> List[Node]: + r"""Get the list of input Nodes of this graph. + + Returns: + A list of ``Node``. + """ return self._inputs @property - def outputs(self): + def outputs(self) -> List[Node]: + r"""Get the list of output Nodes of this graph. + + Returns: + A list of Node. + """ return self._outputs @property def top_graph(self): + r"""Get the parent graph of this graph. + + Returns: + An ``InternalGraph``. + """ if self._top_graph: return self._top_graph() return None def exprs(self, recursive=True): + r"""Get the Exprs that constitute this graph. + + Args: + recursive: whether to get the Exprs in the subgraph. + Default: True + Returns: + A ``ExprFilter`` containing all Exprs of this graph. + """ return ExprFilter(_expr_iter(self, recursive)) def nodes(self, recursive=True): + r"""Get the Nodes that constitute this graph. + + Args: + recursive: whether to get the Nodes in the subgraph. + Default: True + Returns: + A ``NodeFilter`` containing all Nodes of this graph. + """ return NodeFilter(_node_iter(self, recursive)) def get_function_by_type(self, func: Callable = None, recursive=True): + r"""Filter Exprs by the type of ``CallFunction``. + + Args: + func: a built-in function, such as ``F.relu``. + recursive: whether to get the Exprs in the subgraph. + Default: True + Returns: + A :class:`~.TracedModule.ExprFilterCallFunction`. + """ return self.exprs(recursive).call_function(func) def get_method_by_type(self, method: str = None, recursive=True): + r"""Filter Exprs by the type of ``CallMethod``. + + Args: + method: a method string, such as "__add__". + recursive: whether to get the Exprs in the subgraph. + Default: True + Returns: + A :class:`~.TracedModule.ExprFilterCallMethod`. + """ return self.exprs(recursive).call_method(method) def get_expr_by_id(self, expr_id: List[int] = None, recursive=True): + r"""Filter Exprs by their ``id``. + + Args: + expr_id: a list of :class:`int`. + recursive: whether to get the Exprs in the subgraph. + Default: True + Returns: + A :class:`~.TracedModule.ExprFilterExprId`. + """ return self.exprs(recursive).expr_id(expr_id) def get_module_by_type(self, module_cls: Module, recursive=True): + r"""Filter Nodes by the ``module_type`` of ``ModuleNode``. + + Args: + module_cls: a subclass of :class:`~.Module`. + recursive: whether to get the Nodes in the subgraph. + Default: True + Returns: + A :class:`~.TracedModule.NodeFilterType`. + """ assert issubclass(module_cls, Module) - return self.nodes(recursive).type(module_cls, ModuleNode) + return self.nodes(recursive).type(module_cls) def get_node_by_id(self, node_id: List[int] = None, recursive=True): + r"""Filter Nodes by their ``id``. + + The ``id`` of the ``Node`` can be obtained by the following code + + .. code-block:: + + # node : Node + print("{:i}".format(node)) + print(node.__format__("i")) + # graph : InternalGraph + print("{:i}".format(graph)) + print(graph.__format__("i")) + + Args: + node_id: a list of :class:`int`. + recursive: whether to get the Nodes in the subgraph. + Default: True + Returns: + A :class:`~.TracedModule.NodeFilterNodeId`. + """ return self.nodes(recursive).node_id(node_id) def get_node_by_name( self, name: str = None, ignorecase: bool = True, recursive=True ): + r"""Filter Nodes by their full name. + + The full name of the ``Node`` can be obtained by the following code + + .. code-block:: + + # node : Node + print("{:p}".format(node)) + print(node.__format__("p")) + # graph : InternalGraph + print("{:p}".format(graph)) + print(graph.__format__("p")) + + Args: + name: a string in glob syntax that can contain ``?`` and + ``*`` to match a single or arbitrary characters. + ignorecase: whether to ignroe case. + Default: True + recursive: whether to get the Nodes in the subgraph. + Default: True + Returns: + A :class:`~.TracedModule.NodeFilterName`. + """ return self.nodes(recursive).name(name, ignorecase) def _add_input(self, i): @@ -490,6 +646,13 @@ class InternalGraph: o._orig_name = "{}{}".format(module_name, o._orig_name) def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]: + r"""Get the dependent Exprs of the ``nodes``. + + Args: + nodes: a list of :class:`Node`. + Returns: + A list of dependent :class:`Expr`. + """ if not isinstance(nodes, Sequence): nodes = (nodes,) ret = list() @@ -560,11 +723,22 @@ class InternalGraph: self._inputs[:] = formal_node_inputs moudle.argdef_graph_map[tree_def] = moudle.argdef_graph_map.pop(org_argdef) moudle.argdef_outdef_map[tree_def] = moudle.argdef_outdef_map.pop(org_argdef) - - # return formal_node_inputs[1:], actual_nodes return formal_node_inputs[1:] - def add_input_node(self, shape, dtype="float32", name="args"): + def add_input_node( + self, shape: Tuple[int], dtype: str = "float32", name: str = "args" + ): + r"""Add an input node to the graph. + + The new Node will be the last of the positional arguments. + + Args: + shape: the shape of the new input Node. + dtype: the dtype of the new input Node. + Default: float32 + name: the name of the new input Node. When the name is used in the graph, + a suffix will be added to it. + """ forma_mnode = self.inputs[0] actual_mnodes = forma_mnode.actual_node @@ -613,18 +787,63 @@ class InternalGraph: moudle.argdef_graph_map[tree_def] = moudle.argdef_graph_map.pop(org_argdef) moudle.argdef_outdef_map[tree_def] = moudle.argdef_outdef_map.pop(org_argdef) - - # return formal_inp_node, actual_inp_nodes return formal_inp_node def reset_outputs(self, outputs): + r"""Reset the output Nodes of the graph. + + .. note:: + + This method only supports resetting the output of graphs + that do not have a parent graph. + + Args: + outputs: an object which inner element is Node. Support tuple, list + dict, etc. + + For example, the following code + + .. code-block:: + + import megengine.functional as F + import megengine.module as M + import megengine.traced_module as tm + + class MyModule(M.Module): + def forward(self, x): + x = x + 1 + return x + + net = MyModule() + + inp = F.zeros(shape = (1, )) + traced_module = tm.trace_module(net, inp) + graph = traced_module.graph + inp_node = graph.inputs[1] + out_node = graph.outputs[0] + graph.reset_outputs((out_node, {"input": inp_node})) + out = traced_module(inp) + + Will produce the following ``InternalGraph`` and ``out``:: + + print(graph) + print(out) + + .. code-block:: text + + MyModule.Graph (self, x) { + %2: add_out = x.__add__(1, ) + return add_out, x + } + (Tensor([1.], device=xpux:0), {'input': Tensor([0.], device=xpux:0)}) + """ outputs, out_def = tree_flatten( outputs, is_leaf=lambda x: isinstance(x, TensorNode), ) forma_mnode = self.inputs[0] moudle = forma_mnode.owner - assert moudle._is_top, "reset_outputs only support the top-level graph" + assert moudle._is_top, "reset_outputs only support the top graph" actual_mnodes = forma_mnode.actual_node call_nodes = [] @@ -657,10 +876,53 @@ class InternalGraph: return actual_nodes def add_output_node(self, node: TensorNode): + r"""Add an output node to the Graph. + + The Graph output will become a ``tuple`` after calling ``add_output_node``. + The first element of the ``tuple`` is the original output, and the second + is the ``node``. + + For example, the following code + + .. code-block:: + + import megengine.functional as F + import megengine.module as M + import megengine.traced_module as tm + + class MyModule(M.Module): + def forward(self, x): + x = x + 1 + return x + + net = MyModule() + + inp = F.zeros(shape = (1, )) + traced_module = tm.trace_module(net, inp) + graph = traced_module.graph + inp_node = graph.inputs[1] + out_node = graph.outputs[0] + graph.add_output_node(inp_node) + graph.add_output_node(out_node) + out = traced_module(inp) + + Will produce the following ``InternalGraph`` and ``out``:: + + print(graph) + print(out) + + .. code-block:: text + + MyModule.Graph (self, x) { + %2: add_out = x.__add__(1, ) + return add_out, x, add_out + } + ((Tensor([1.], device=xpux:0), Tensor([0.], device=xpux:0)), Tensor([1.], device=xpux:0)) + """ forma_mnode = self.inputs[0] moudle = forma_mnode.owner - assert moudle._is_top, "add_output_node only support the top-level graph" + assert moudle._is_top, "add_output_node only support the top graph" actual_mnodes = forma_mnode.actual_node call_nodes = [] @@ -703,11 +965,33 @@ class InternalGraph: return actual_out_nodes def insert_exprs(self, expr: Optional[Expr] = None): + r"""Initialize the trace mode and insertion position. + + When used within a 'with' statement, this will temporary set the trace mode and + then restore normal mode when the with statement exits:: + + with graph.insert_exprs(e): # set the trace mode + ... # trace function or module + ... # inert exprs into graph and resotre normal mode + + Args: + expr: the ``expr`` after which to insert. If None, the insertion position will be + automatically set based on the input node. + + Returns: + A resource manager that will initialize trace mode on ``__enter__`` and + restore normal mode on ``__exit__``. + """ if expr is not None: assert expr.top_graph == self, "Expr to insert after is not in graph." return _InsertExprs(self, expr) def replace_node(self, repl_dict: Dict[Node, Node]): + r"""Replace the Nodes in the graph. + + Args: + repl_dict: the map {old_Node: new_Node} that specifies how to replace the Nodes. + """ while repl_dict: node, repl_node = repl_dict.popitem() assert type(node) == type( @@ -746,7 +1030,7 @@ class InternalGraph: n.inputs[idx] = repl_node def compile(self): - """Delete unused expr.""" + r"""Delete unused expr.""" dep_exprs = self.get_dep_exprs(self.outputs) i = 0 while i < len(self._exprs): @@ -804,7 +1088,12 @@ class InternalGraph: return list(node2value[i][0] for i in self._outputs) - def eval(self, *inputs): + def eval(self, *inputs: Tuple[Tensor]): + r"""Call this method to execute the graph. + + Args: + inputs: the tensors corresponding to the ``graph.inputs[1:]``. + """ assert len(inputs) == len(self._inputs) - 1 inp = [self._inputs[0].owner] + list(inputs) return self.interpret(*inp) @@ -813,7 +1102,7 @@ class InternalGraph: return self.__format__() def __format__(self, format_spec: str = "") -> str: - saved_format_spec = Node.set_format_spec(format_spec) + saved_format_spec = Node._set_format_spec(format_spec) name = "" if self._name: name = "%s.Graph" % self._name @@ -823,7 +1112,7 @@ class InternalGraph: "\n\t".join("{}".format(str(i)) for i in self._exprs), ", ".join(str(i) for i in self._outputs), ) - Node.set_format_spec(saved_format_spec) + Node._set_format_spec(saved_format_spec) return res def __getstate__(self): @@ -1010,7 +1299,7 @@ class TracedModuleBuilder(NodeMixin): for _, g in self._argdef_graph_map.items(): g.compile() if self._is_top: - g._total_ids = (Node.get_total_id(), Expr.get_total_id()) + g._total_ids = (Node._get_next_id(), Expr._get_next_id()) for k, v in self.__dict__.items(): if k not in TracedModuleBuilder.__builder_attributes__: @@ -1298,59 +1587,106 @@ class _node_iter: class BaseFilter: - def __init__(self, expr_iter: Iterable): - self._iter = expr_iter + r"""``BaseFilter`` exposes some methods for converting ``_node_iter/_expr_iter`` to ``list``, ``dict``, etc.""" + + def __init__(self, iter: Iterable): + self._iter = iter def __iter__(self): return iter(self._iter) def as_list(self): + r"""Consume this iterator and return its content as a list. + + Returns: + A list of ``Node`` or ``Expr``. + """ return list(self) def as_dict(self): + r"""Construct an ordered dict to map from ``id`` to objects in this iterator. + + Returns: + An :class:`OrderedDict`. + """ return collections.OrderedDict((i._id, i) for i in self) def as_unique(self): + """Assert that this iterator yields only one ``Node`` or ``Expr`` and return it. + + Rerurns: + A ``Node`` or ``Expr``. + """ rst = self.as_list() assert len(rst) == 1, "{} elements found".format(len(rst)) - (expr,) = self - return expr + (elem,) = self + return elem def as_count(self): + r"""Consume this iterator and get the number of elements.""" return sum(1 for _ in self) class ExprFilter(BaseFilter): + """Filter on Expr iterator. + This class is an iterator of :class:`.Expr` objects and multiple + filtering conditions and mappers can be chained. + """ + def call_function(self, func): + r"""Filter by specific ``CallFunction.func``. + See :meth:`~.InternalGraph.get_function_by_type` for details. + """ return ExprFilterCallFunction(self, func) def call_method(self, method): + r"""Filter by specific ``CallMethod.method``. + See :meth:`~.InternalGraph.get_function_by_type` for details. + """ return ExprFilterCallMethod(self, method) def expr_id(self, expr_id: List[int]): + r"""Filter Exprs by their ``id``. + See :meth:`~.InternalGraph.get_function_by_type` for details. + """ return ExprFilterExprId(self, expr_id) class NodeFilter(BaseFilter): - def type(self, owner_type, node_type): - return NodeFilterType(self, owner_type, node_type) + """Filter on Node iterator. + This class is an iterator of :class:`.Node` objects and multiple + filtering conditions and mappers can be chained. + """ + + def type(self, owner_type): + r"""Filter by specific Module type. + See :meth:`~.InternalGraph.get_module_by_type` for details. + """ + return NodeFilterType(self, owner_type) def node_id(self, node_id: List[int]): + r"""Filter Nodes by their ``id``. + See :meth:`~.InternalGraph.get_node_by_id` for details. + """ return NodeFilterNodeId(self, node_id) def name(self, name: str, ignorecase: bool = True): + r"""Filter Nodes by their full name. + See :meth:`~.InternalGraph.get_node_by_name` for details. + """ return NodeFilterName(self, name, ignorecase) class NodeFilterType(NodeFilter): - def __init__(self, expr_iter, owner_type, node_type): + """See :meth:`~.InternalGraph.get_module_by_type`""" + + def __init__(self, expr_iter, owner_type): super().__init__(expr_iter) self.owner_type = owner_type - self.node_type = node_type def __iter__(self): for node in self._iter: - if not isinstance(node, self.node_type): + if not isinstance(node, ModuleNode): continue if not hasattr(node, "owner"): continue @@ -1359,6 +1695,8 @@ class NodeFilterType(NodeFilter): class NodeFilterNodeId(NodeFilter): + """See :meth:`~.InternalGraph.get_node_by_id`""" + def __init__(self, expr_iter, node_id: List[int]): super().__init__(expr_iter) if not isinstance(node_id, Sequence): @@ -1372,6 +1710,8 @@ class NodeFilterNodeId(NodeFilter): class NodeFilterName(NodeFilter): + """See :meth:`~.InternalGraph.get_node_by_name`""" + _re = None def __init__(self, node_iter, pattern, ignorecase): @@ -1399,6 +1739,8 @@ class NodeFilterName(NodeFilter): class ExprFilterCallFunction(ExprFilter): + """See :meth:`~.InternalGraph.get_function_by_type`""" + def __init__(self, expr_iter, func: Callable = None): super().__init__(expr_iter) self.func = func @@ -1412,6 +1754,8 @@ class ExprFilterCallFunction(ExprFilter): class ExprFilterCallMethod(ExprFilter): + """See :meth:`~.InternalGraph.get_method_by_type`""" + def __init__(self, expr_iter, method: str = None): super().__init__(expr_iter) self.method = method @@ -1425,6 +1769,8 @@ class ExprFilterCallMethod(ExprFilter): class ExprFilterExprId(ExprFilter): + """See :meth:`~.InternalGraph.get_expr_by_id`""" + def __init__(self, expr_iter, expr_id: List[int]): super().__init__(expr_iter) if not isinstance(expr_id, Sequence): @@ -1438,8 +1784,16 @@ class ExprFilterExprId(ExprFilter): class TracedModule(Module): - r"""`TracedModule` is the Module created by tracing normal module. It owns an argdef to graph(InternalGraph) map. The forward method of `TracedModule` will get a graph from `argdef_graph_map` according to the argdef of input args/kwargs and interpret it.""" - + r"""``TracedModule`` is the Module created by tracing normal module. + + It owns an argdef to graph(InternalGraph) map. The forward method of ``TracedModule`` + will get a graph from ``argdef_graph_map`` according to the argdef of input ``args/kwargs`` + and interpret it. + + .. note:: + ``TracedModule`` can only be created by :func:`~.trace_module`. See :func:`~.trace_module` + for more details. + """ # m_node = None # type: ModuleNode argdef_graph_map = None argdef_outdef_map = None @@ -1475,19 +1829,97 @@ class TracedModule(Module): return outputs def set_watch_points(self, nodes): + r"""Initialize the :attr:`~.TracedModule.watch_points`. + + You can call this function to get the ``Tensor/Module`` corresponding to a ``Node`` at runtime. + + Args: + nodes: a list of ``Node``. + + For example, the following code + + .. code-block:: + + import megengine.module as M + import megengine as mge + import megengine.traced_module as tm + + class MyModule(M.Module): + def forward(self, x): + x = x + 1 + 2 + return x + + net = MyModule() + + inp = mge.Tensor([0]) + traced_module = tm.trace_module(net, inp) + add_1_node = traced_module.graph.get_node_by_id(2).as_unique() + traced_module.set_watch_points(add_1_node) + + out = traced_module(inp) + + Will get the following ``watch_node_value``:: + + print(traced_module.watch_node_value) + + .. code-block:: text + + {add_out: Tensor([1.], device=xpux:0)} + """ if not isinstance(nodes, Sequence): nodes = [nodes] self.watch_points = nodes + if nodes: + nodes[0].top_graph._watch_point = [] for n in nodes: n.top_graph._watch_point.append(n) def clear_watch_points(self): + r"""Clear the :attr:`~.TracedModule.watch_points` and :attr:`~.TracedModule.watch_node_value`. + """ for n in self.watch_points: n.top_graph._watch_point = [] self.watch_points = [] self.watch_node_value = {} - def set_end_points(self, nodes): + def set_end_points(self, nodes: Sequence[Node]): + r"""Initialize the :attr:`~.TracedModule.end_points`. + + When all the ``nodes`` are generated, the Module will stop execution and return directly. + + Args: + nodes: a list of ``Node``. + + For example, the following code + + .. code-block:: + + import megengine.module as M + import megengine as mge + import megengine.traced_module as tm + + class MyModule(M.Module): + def forward(self, x): + x = x + 1 + 2 + return x + + net = MyModule() + + inp = mge.Tensor([0]) + traced_module = tm.trace_module(net, inp) + add_1_node = traced_module.graph.get_node_by_id(2).as_unique() + traced_module.set_end_points(add_1_node) + + out = traced_module(inp) + + Will get the following ``out``:: + + print(out) + + .. code-block:: text + + [Tensor([1.], device=xpux:0)] + """ if not isinstance(nodes, Sequence): nodes = [nodes] self.end_points = nodes @@ -1497,12 +1929,16 @@ class TracedModule(Module): n.top_graph._end_point.append(n) def clear_end_points(self): + r"""Clear the :attr:`~.TracedModule.end_points`. + """ for n in self.end_points: n.top_graph._end_point = [] self.end_points = [] @property def graph(self) -> InternalGraph: + """Return the ``InternalGraph`` of this ``TracedModule`` + """ if self._is_top: self._update_ref() assert len(self.argdef_graph_map) == 1 @@ -1559,9 +1995,10 @@ class TracedModule(Module): obj._update_ref(mnode_map, graph) def flatten(self): - r"""Get a new module, which eliminates ``GetAttr`` and has no hierarchy. + r"""Get a new TracedModule, which eliminates ``GetAttr`` and has no hierarchy. - :return: :class:`TracedModule` + Retruns: + A new :class:`TracedModule`. """ new_module = copy.deepcopy(self) assert active_module_tracer() is None @@ -1690,16 +2127,35 @@ def cpp_apply_module_trace(opdef, *args): def register_as_builtin(mod_cls: Type[Module]) -> None: - r"""Registers class ``mod_cls`` (subclass of megengine.module.Module) as builtin module. + r"""Registers class ``mod_cls`` (subclass of :class:`~.Module`) as builtin module. Args: - mod_cls: the Module class which will be threated as builtin module in tracing + mod_cls: the module class which will be treated as builtin module in tracing. """ module_tracer.register_as_builtin(mod_cls) def wrap(func: Callable): - r"""Call this function to register func as a builtin function.""" + r"""Call this function to register ``func`` as a builtin function. + + This function can be called at module-level scope to register ``func`` as a builtin function. + A builtin function will be converted to a :class:`CallFunction` Expr in tracing:: + + def my_func(x, y): + return x + y + + import megengine.traced_module as tm + tm.wrap(my_func) + + This function can also equivalently be used as a decorator:: + + @tm.wrap + def my_func(x, y): + return x + y + + Args: + func: the function of the global function to insert into the graph when it's called. + """ assert callable(func), "func must be a callable" assert hasattr(func, "__code__") fn_name = func.__code__.co_name @@ -1739,13 +2195,15 @@ def _register_all_builtin_module(): module_tracer.register_as_builtin(TM_FakeQuant) -def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule: - r"""Traces module ``mod`` and returns corresponding TracedModule. +def trace_module( + mod: Module, *args: Tuple[Any], **kwargs: Dict[str, Any] +) -> TracedModule: + r"""Traces module ``mod`` and returns corresponding :class:`TracedModule`. Args: - mod: the module will be converted to TracedModule - input: the positional arguments passed to forward method of ``mod`` - kwargs: the keyword arguments passed to forward method of ``mod`` + mod: the module will be converted to :class:`TracedModule`. + args: the positional arguments passed to forward method of ``mod``. + kwargs: the keyword arguments passed to forward method of ``mod``. """ assert active_module_tracer() is None assert isinstance(mod, Module) @@ -1756,7 +2214,7 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule: module_tracer(_wrapped_function, _init_id2name(mod, "self")) ) for cls in [Expr, Node]: - cls.set_total_id(0) + cls._set_next_id(0) with active_module_tracer().patcher: global_scope = InternalGraph(name="") active_module_tracer().push_scope(global_scope)