提交 fb20cb36 编写于 作者: M Megvii Engine Team

docs(mge/traced_module): update traced_module api doc

GitOrigin-RevId: 19a95d26c71e672376c5fda00a4e7dc6050e1c6a
上级 c7a8d945
...@@ -130,3 +130,4 @@ import megengine.optimizer ...@@ -130,3 +130,4 @@ import megengine.optimizer
import megengine.quantization import megengine.quantization
import megengine.random import megengine.random
import megengine.utils import megengine.utils
import megengine.traced_module
...@@ -33,15 +33,22 @@ def rstrip(s: str, __chars: str): ...@@ -33,15 +33,22 @@ def rstrip(s: str, __chars: str):
class Expr: class Expr:
"""``Expr`` represents the operations(i.e. CallMethod, CallFunction, Apply, GetAttr, Input, Constant) on ``Node``.""" r"""``Expr`` represents the operations (i.e. ``CallMethod``, ``CallFunction``, ``Apply``,
``GetAttr``, ``Input``, ``Constant``) on ``Node``.
__total_id = 0 """
inputs = None # type: List[Node] inputs = None # type: List[Node]
r"""The input Nodes of this Expr."""
outputs = None # type: List[Node] outputs = None # type: List[Node]
r"""The output Nodes of this Expr."""
const_val = None # type: List[Any] const_val = None # type: List[Any]
r"""The non-tensor object in the input of the operation."""
arg_def = None # type: TreeDef arg_def = None # type: TreeDef
r"""The :class:`TreeDef` used to reconstruct the input of the operation."""
out_def = None # type: TreeDef out_def = None # type: TreeDef
r"""The :class:`TreeDef` used to reconstruct the output of the operation."""
_top_graph = None # type: weakref.ReferenceType _top_graph = None # type: weakref.ReferenceType
__total_id = 0
def __init__(self) -> None: def __init__(self) -> None:
self._id = Expr.__total_id self._id = Expr.__total_id
...@@ -125,6 +132,11 @@ class Expr: ...@@ -125,6 +132,11 @@ class Expr:
return inputs, {} return inputs, {}
def replace_inputs(self, repl_dict: Dict[Node, Node]): def replace_inputs(self, repl_dict: Dict[Node, Node]):
r"""Replace the input Nodes of this Expr.
Args:
repl_dict: the map {old_Node: new_Node} that specifies how to replace the input Nodes.
"""
while repl_dict: while repl_dict:
node, repl_node = repl_dict.popitem() node, repl_node = repl_dict.popitem()
assert type(node) == type(repl_node) assert type(node) == type(repl_node)
...@@ -147,16 +159,19 @@ class Expr: ...@@ -147,16 +159,19 @@ class Expr:
@property @property
def kwargs(self): def kwargs(self):
r"""Get the the keyword arguments of the operation corresponding to this Expr."""
_, kwargs = self.unflatten_args(self.inputs) _, kwargs = self.unflatten_args(self.inputs)
return kwargs return kwargs
@property @property
def args(self): def args(self):
r"""Get the the positional arguments of the operation corresponding to this Expr."""
args, _ = self.unflatten_args(self.inputs) args, _ = self.unflatten_args(self.inputs)
return args return args
@property @property
def top_graph(self): def top_graph(self):
r"""Get the parent graph of this Expr."""
if self._top_graph: if self._top_graph:
return self._top_graph() return self._top_graph()
return None return None
...@@ -168,17 +183,18 @@ class Expr: ...@@ -168,17 +183,18 @@ class Expr:
return state return state
@classmethod @classmethod
def get_total_id(cls): def _get_next_id(cls):
return cls.__total_id return cls.__total_id
@classmethod @classmethod
def set_total_id(cls, id: int = 0): def _set_next_id(cls, id: int = 0):
assert isinstance(id, int) assert isinstance(id, int)
cls.__total_id = id cls.__total_id = id
# expr: None (i.e. fake expression which is used to mark input) # expr: None (i.e. fake expression which is used to mark input)
class Input(Expr): class Input(Expr):
r"""A fake Expr which is used to mark the input of graph."""
name = None name = None
def __init__(self, name=None, type=None, orig_name=None): def __init__(self, name=None, type=None, orig_name=None):
...@@ -204,13 +220,15 @@ class Input(Expr): ...@@ -204,13 +220,15 @@ class Input(Expr):
return expr.outputs[0] return expr.outputs[0]
def __repr__(self): def __repr__(self):
return "%{}:\t{} = Input({})".format(self._id, self.outputs[0], self.name) return "%{}:\t{} = Input()".format(self._id, self.outputs[0])
# expr: outputs = getattr(inputs[0], self.name) # expr: outputs = getattr(inputs[0], self.name)
class GetAttr(Expr): class GetAttr(Expr):
name = None r"""``Getattr`` represents the fetch of an attribute from the ``Module`` hierarchy."""
name = None
r"""name: the qualified name of the attribute to be retrieved."""
def __init__(self, module, name, type=None, orig_name=None): def __init__(self, module, name, type=None, orig_name=None):
super().__init__() super().__init__()
assert isinstance(module, ModuleNode) assert isinstance(module, ModuleNode)
...@@ -251,6 +269,13 @@ class GetAttr(Expr): ...@@ -251,6 +269,13 @@ class GetAttr(Expr):
# expr: outputs = inputs[0].__call__(*inputs[1:]) # expr: outputs = inputs[0].__call__(*inputs[1:])
class CallMethod(Expr): class CallMethod(Expr):
r"""``CallMethod`` represents a call to the ``__call__`` method of ``Module`` or a method of ``Tensor``.
Args:
node: the Node to be called.
method: the method name.
Default: "__call__"
"""
def __init__(self, node, method="__call__"): def __init__(self, node, method="__call__"):
super().__init__() super().__init__()
if isinstance(node, type): if isinstance(node, type):
...@@ -320,8 +345,12 @@ class CallMethod(Expr): ...@@ -320,8 +345,12 @@ class CallMethod(Expr):
# expr: outputs = apply(self.opdef, *inputs) # expr: outputs = apply(self.opdef, *inputs)
class Apply(Expr): class Apply(Expr):
opdef = None r"""``Apply`` represents a call to :func:`apply`.
Args:
opdef: the applied :class:`OpDef`.
"""
opdef = None
def __init__(self, opdef): def __init__(self, opdef):
super().__init__() super().__init__()
assert isinstance(opdef, OpDef) assert isinstance(opdef, OpDef)
...@@ -388,6 +417,11 @@ class Apply(Expr): ...@@ -388,6 +417,11 @@ class Apply(Expr):
class CallFunction(Expr): class CallFunction(Expr):
r"""``CallFunction`` represents a call to a built-in function.
Args:
func: a built-in function.
"""
def __init__(self, func): def __init__(self, func):
super().__init__() super().__init__()
assert isinstance(func, Callable) assert isinstance(func, Callable)
...@@ -425,7 +459,14 @@ class CallFunction(Expr): ...@@ -425,7 +459,14 @@ class CallFunction(Expr):
# expr outputs = self.value # expr outputs = self.value
class Constant(Expr): class Constant(Expr):
r"""``Constant`` represents a ``Tensor`` or "Module" which is not the attribute of a Module.
Args:
c: a const Tensor or Module.
name: the name of output Node.
"""
value = None value = None
r"""The const Tensor or Module"""
# TODO: constant cache to reduce the size of dumped model # TODO: constant cache to reduce the size of dumped model
_constant_cache = {} _constant_cache = {}
......
...@@ -15,6 +15,8 @@ from ..quantization.utils import QParams, QuantMode, fake_quant_tensor ...@@ -15,6 +15,8 @@ from ..quantization.utils import QParams, QuantMode, fake_quant_tensor
class FakeQuantize(_FakeQuantize, QParamsModuleMixin): class FakeQuantize(_FakeQuantize, QParamsModuleMixin):
r"""A module to do quant and dequant according to :attr:`~.FakeQuantize.qparams`."""
def __init__( def __init__(
self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs
): ):
...@@ -35,9 +37,10 @@ class FakeQuantize(_FakeQuantize, QParamsModuleMixin): ...@@ -35,9 +37,10 @@ class FakeQuantize(_FakeQuantize, QParamsModuleMixin):
return self.qparams return self.qparams
def set_qparams(self, qparams: QParams): def set_qparams(self, qparams: QParams):
r""" r"""Initialize :attr:`~.FakeQuantize.qparams`.
Args: Args:
qparams: used to set initial scale. qparams: used to set initial ``scale`` and ``zero_point``.
""" """
if qparams.scale is None: if qparams.scale is None:
raise AssertionError("Can not get an initialized scale") raise AssertionError("Can not get an initialized scale")
......
...@@ -11,29 +11,29 @@ from typing import Any, Dict, List, Tuple, Type ...@@ -11,29 +11,29 @@ from typing import Any, Dict, List, Tuple, Type
import numpy import numpy
from .. import get_logger
from ..core._imperative_rt.core2 import Tensor as RawTensor from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..module import Module from ..module import Module
from ..tensor import Tensor from ..tensor import Tensor
logger = get_logger(__name__)
class Node:
r"""``Node`` represents the variables (Tensor/Module/other python object) used in Module's forward method.
They are inputs/outputs of Expr(the operations on variables).
Args: class Node:
expr: the Expr which produces the node r"""``Node`` represents the variables (``Tensor``, ``Module``) used in Module's forward method.
name: the name of the node They are inputs/outputs of Expr (the operations on variables).
""" """
expr = None expr = None # type: Expr
__total_id = 0 r"""The Expr which produces the Node."""
_id = None __total_id = 0 # type: int
_id = None # type: int
_top_graph = None # type: weakref.ReferenceType _top_graph = None # type: weakref.ReferenceType
_name = None _name = None # type: str
_orig_name = None _orig_name = None # type: str
_format_spec = "" _format_spec = "" # type: str
def __init__(self, expr: "Expr", name: str = None, orig_name: str = None): def __init__(self, expr: "Expr", name: str, orig_name: str):
self.expr = expr self.expr = expr
self.users = [] # List[Expr] self.users = [] # List[Expr]
self._id = Node.__total_id self._id = Node.__total_id
...@@ -73,32 +73,51 @@ class Node: ...@@ -73,32 +73,51 @@ class Node:
else: else:
return name if name else ("%d" % self._id) return name if name else ("%d" % self._id)
@property
def name(self):
r"""Return the name of this Node."""
return self._name
@name.setter
def name(self, new_name: str):
graph = self.top_graph
assert graph is not None, "The parent graph of this Node cannot be None."
assert new_name not in graph._used_names, (
"The name(%s) is already in use. Please try a different one again."
% (new_name)
)
new_name = graph._create_unique_name(new_name)
self._name = new_name
self._orig_name = new_name
@property @property
def top_graph(self): def top_graph(self):
r"""Get the parent graph of this Node."""
if self._top_graph: if self._top_graph:
return self._top_graph() return self._top_graph()
return None return None
@classmethod @classmethod
def set_format_spec(cls, str): def _set_format_spec(cls, str):
old_format_spec = cls._format_spec old_format_spec = cls._format_spec
cls._format_spec = str cls._format_spec = str
return old_format_spec return old_format_spec
@classmethod @classmethod
def get_total_id(cls): def _get_next_id(cls):
return cls.__total_id return cls.__total_id
@classmethod @classmethod
def set_total_id(cls, id: int = 0): def _set_next_id(cls, id: int = 0):
assert isinstance(id, int) assert isinstance(id, int)
cls.__total_id = id cls.__total_id = id
class ModuleNode(Node): class ModuleNode(Node):
r"""``ModuleNode`` represents the Module objects.""" r"""``ModuleNode`` represents the Module objects."""
module_type = Module # type: Type[Module] module_type = Module # type: Type[Module]
r"""The type of the Module correspending to the ModuleNode."""
_owner = None # type: weakref.ReferenceType _owner = None # type: weakref.ReferenceType
def __init__(self, expr: "Expr", name: str = None, orig_name: str = None): def __init__(self, expr: "Expr", name: str = None, orig_name: str = None):
...@@ -116,6 +135,11 @@ class ModuleNode(Node): ...@@ -116,6 +135,11 @@ class ModuleNode(Node):
@property @property
def owner(self): def owner(self):
r"""Get the ``Module`` corresponding to this ``ModuleNode``.
Returns:
An :calss:`~.Module`.
"""
if self._owner: if self._owner:
return self._owner() return self._owner()
return None return None
...@@ -145,6 +169,7 @@ class TensorNode(Node): ...@@ -145,6 +169,7 @@ class TensorNode(Node):
@property @property
def shape(self): def shape(self):
r"""Get the shape of this Node."""
return self._shape return self._shape
@shape.setter @shape.setter
...@@ -153,6 +178,7 @@ class TensorNode(Node): ...@@ -153,6 +178,7 @@ class TensorNode(Node):
@property @property
def dtype(self): def dtype(self):
r"""Get the dtype of this Node."""
return self._dtype return self._dtype
@dtype.setter @dtype.setter
...@@ -161,6 +187,7 @@ class TensorNode(Node): ...@@ -161,6 +187,7 @@ class TensorNode(Node):
@property @property
def device(self): def device(self):
r"""Get the device of this Node pointed Tensor."""
return self._device return self._device
@device.setter @device.setter
...@@ -169,6 +196,7 @@ class TensorNode(Node): ...@@ -169,6 +196,7 @@ class TensorNode(Node):
@property @property
def qparams(self): def qparams(self):
r"""Get the :calss:`QParams` of this Node."""
return self._qparams return self._qparams
@qparams.setter @qparams.setter
...@@ -177,10 +205,16 @@ class TensorNode(Node): ...@@ -177,10 +205,16 @@ class TensorNode(Node):
@property @property
def value(self): def value(self):
r"""Get the bound Tensor of this Node."""
return self._value return self._value
@value.setter @value.setter
def value(self, value): def value(self, value):
r"""Bind a Tensor to this Node.
Args:
value: A :class:`Tensor`.
"""
if isinstance(value, RawTensor) and NodeMixin.get(value, None) is not None: if isinstance(value, RawTensor) and NodeMixin.get(value, None) is not None:
setattr(value, "_NodeMixin__node", None) setattr(value, "_NodeMixin__node", None)
self._value = value self._value = value
......
...@@ -150,6 +150,9 @@ def tree_flatten( ...@@ -150,6 +150,9 @@ def tree_flatten(
is_leaf: Callable = _is_leaf, is_leaf: Callable = _is_leaf,
is_const_leaf: Callable = _is_const_leaf, is_const_leaf: Callable = _is_const_leaf,
): ):
r"""Flattens a object into a list of values and a :calss:`TreeDef` that can be used
to reconstruct the object.
"""
if type(values) not in SUPPORTED_TYPE: if type(values) not in SUPPORTED_TYPE:
assert is_leaf(values), values assert is_leaf(values), values
node = LeafDef(leaf_type(values)) node = LeafDef(leaf_type(values))
...@@ -169,6 +172,15 @@ def tree_flatten( ...@@ -169,6 +172,15 @@ def tree_flatten(
class TreeDef: class TreeDef:
r"""A ``TreeDef`` represents the structure of a pytree.
Args:
type: the type of root Node of the pytree.
aux_data: some const data that is useful in unflattening the pytree.
children_defs: ``TreeDef`` for each child of the root Node.
num_leaves: the number of leaves.
"""
def __init__(self, type, aux_data, children_defs): def __init__(self, type, aux_data, children_defs):
self.type = type self.type = type
self.aux_data = aux_data self.aux_data = aux_data
...@@ -176,6 +188,9 @@ class TreeDef: ...@@ -176,6 +188,9 @@ class TreeDef:
self.num_leaves = sum(ch.num_leaves for ch in children_defs) self.num_leaves = sum(ch.num_leaves for ch in children_defs)
def unflatten(self, leaves): def unflatten(self, leaves):
r"""Given a list of values and a ``TreeDef``, builds a object.
This is the inverse operation of ``tree_flatten``.
"""
assert len(leaves) == self.num_leaves assert len(leaves) == self.num_leaves
start = 0 start = 0
children = [] children = []
...@@ -196,13 +211,10 @@ class TreeDef: ...@@ -196,13 +211,10 @@ class TreeDef:
) )
) )
def __lt__(self, other): def __ne__(self, other) -> bool:
return self.__hash__() < other.__hash__() return not self.__eq__(other)
def __gt__(self, other):
return self.__hash__() > other.__hash__()
def __eq__(self, other): def __eq__(self, other) -> bool:
return ( return (
self.type == other.type self.type == other.type
and self.aux_data == other.aux_data and self.aux_data == other.aux_data
...@@ -227,6 +239,9 @@ class LeafDef(TreeDef): ...@@ -227,6 +239,9 @@ class LeafDef(TreeDef):
assert isinstance(leaves[0], self.type), self.type assert isinstance(leaves[0], self.type), self.type
return leaves[0] return leaves[0]
def __ne__(self, other) -> bool:
return not self.__eq__(other)
def __eq__(self, other): def __eq__(self, other):
if isinstance(self.const_val, np.ndarray): if isinstance(self.const_val, np.ndarray):
return self.type == other.type and (self.const_val == other.const_val).all() return self.type == other.type and (self.const_val == other.const_val).all()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册