提交 20cdd11d 编写于 作者: M Megvii Engine Team

docs(mge): fix docstring of cgtools

GitOrigin-RevId: ee8c93b9f7d4f4841b7690c3c764b41ec9f93dc1
上级 57c4eccf
......@@ -66,10 +66,10 @@ class Function:
:param output_grads: gradients of outputs that are returned by :meth:`~.function.Function.forward`.
.. note::
.. note::
In case when some tensors of outputs are not related to loss function, the corresponding
values in ``output_grads`` would be ``None``.
In case when some tensors of outputs are not related to loss function, the corresponding
values in ``output_grads`` would be ``None``.
.. note::
......
......@@ -225,8 +225,8 @@ def square(x: Tensor) -> Tensor:
"""
Returns a new tensor with the square of the elements of input tensor.
:param inp: The input tensor
:return: The computed tensor
:param inp: input tensor.
:return: computed tensor.
Examples:
......
......@@ -72,7 +72,7 @@ def isinf(inp: Tensor) -> Tensor:
r"""Returns a new tensor representing if each element is ``Inf`` or not.
:param inp: input tensor.
:return: c.
:return: result tensor.
Examples:
......
......@@ -88,9 +88,9 @@ def load(f, map_location=None, pickle_module=pickle):
:type map_location: str, dict or a function specifying the map rules
:param map_location: Default: ``None``.
.. note::
.. note::
map_location defines device mapping. See examples for usage.
map_location defines device mapping. See examples for usage.
:type pickle_module:
:param pickle_module: Default: ``pickle``.
......
......@@ -17,8 +17,8 @@ from ..core.tensor.raw_tensor import as_raw_tensor
def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]:
"""return :class:`.tensor.core.megbrain_graph.VarNode` of type ``var_type`` that input ``var``
depands on. If ``var_type`` is None, return all types.
"""Returns :class:`.tensor.core.megbrain_graph.VarNode` of type ``var_type`` that input ``var``
depands on. If ``var_type`` is None, returns all types.
"""
outputs = []
memo = set()
......@@ -46,14 +46,14 @@ def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]:
def get_owner_opr_inputs(var: VarNode) -> List[VarNode]:
"""get the inputs of owner opr of a variable
"""Gets the inputs of owner opr of a variable.
"""
assert isinstance(var, VarNode)
return var.owner.inputs
def get_owner_opr_type(var: VarNode) -> str:
"""get the type of owner opr of a variable
"""Gets the type of owner opr of a variable.
"""
assert isinstance(var, VarNode)
......@@ -61,16 +61,16 @@ def get_owner_opr_type(var: VarNode) -> str:
def get_opr_type(opr: OperatorNode) -> str:
"""get the type of a opr
"""Gets the type of an opr.
"""
assert isinstance(opr, OperatorNode)
return opr.type
def graph_traversal(outputs: VarNode):
"""helper function to traverse the computing graph and return enough useful information
"""Helper function to traverse the computing graph and return enough useful information.
:param outputs: model outputs
:param outputs: model outputs.
:return: tuple (map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree)
WHERE
map_oprs is dict from opr_id to actual opr
......@@ -124,11 +124,11 @@ def graph_traversal(outputs: VarNode):
def get_oprs_seq(outputs: List[VarNode], prune_reshape=False) -> List[OperatorNode]:
"""get oprs in some topological order for a dumped model
"""Gets oprs in some topological order for a dumped model.
:param outputs: model outputs
:param prune_reshape: whether to prune the operators useless during inference
:return: opr list with some correct execution order
:param outputs: model outputs.
:param prune_reshape: whether to prune the useless operators during inference.
:return: opr list with some correct execution order.
"""
def topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree):
......@@ -194,13 +194,13 @@ def get_oprs_seq(outputs: List[VarNode], prune_reshape=False) -> List[OperatorNo
def replace_vars(dst: VarNode, varmap: Dict[VarNode, VarNode]) -> List[VarNode]:
"""replace vars in the graph
"""Replaces vars in the graph.
:param dst: target vars representing the graph
:param varmap: the map that specifies how to replace the vars
:param dst: target vars representing the graph.
:param varmap: the map that specifies how to replace the vars.
:return: new vars that correspond to ``dst`` with all the dependencies
replaced
replaced.
"""
dst_vec = []
repl_src_vec = []
......@@ -221,13 +221,13 @@ def replace_vars(dst: VarNode, varmap: Dict[VarNode, VarNode]) -> List[VarNode]:
def replace_oprs(
dst: List[VarNode], oprmap: Dict[OperatorNode, OperatorNode]
) -> List[VarNode]:
"""Replace operators in the graph.
"""Replaces operators in the graph.
:param dst: target vars representing the graph
:param oprmap: the map that specifies how to replace the operators
:param dst: target vars representing the graph.
:param oprmap: the map that specifies how to replace the operators.
:return: new vars that correspond to ``dst`` with all the dependencies
replaced
replaced.
"""
dst_vec = []
repl_src_vec = []
......@@ -246,9 +246,9 @@ def replace_oprs(
def set_priority_to_id(dest_vars):
"""For all oprs in the subgraph constructed by dest_vars
set its priority to id if its original priority is zero
:param dest_vars: target vars representing the graph
"""For all oprs in the subgraph constructed by dest_vars,
sets its priority to id if its original priority is zero.
:param dest_vars: target vars representing the graph.
"""
dest_vec = []
for i in dest_vars:
......@@ -258,11 +258,11 @@ def set_priority_to_id(dest_vars):
def load_and_inference(file, inp_data_list: List[numpy.ndarray]) -> List[numpy.ndarray]:
"""Load a serialized computing graph and run inference with input data.
"""Loads a serialized computing graph and run inference with input data.
:param file: Path or Handle of the input file.
:param inp_data_list: List of input data.
:return: List of inference results.
:param file: path or handle of the input file.
:param inp_data_list: list of input data.
:return: list of inference results.
"""
*_, out_list = G.load_graph(file)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册