Commit 0c54f2dc authored by Megvii Engine Team, committed by Xinran Xu

docs(mge/core): refine the docstring of several apis

GitOrigin-RevId: c03fecfa6abf86b9e47e8e23e45adf55131621e6
Parent 80d9e56b
@@ -27,7 +27,7 @@ def get_device_count(device_type: str) -> int:
 def is_cuda_available() -> bool:
-    """ Returns whether cuda is avaiable.
+    """Returns whether cuda device is available on this system.
     """
     return mgb.config.get_device_count("gpu", warn=False) > 0
...
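A small usage sketch for the refined docstring, assuming `is_cuda_available` is re-exported at the top-level `megengine` package as in released MegEngine:

```python
import megengine as mge

# "gpu0" and "cpu0" are MegEngine comp-node name strings
device = "gpu0" if mge.is_cuda_available() else "cpu0"
print("selected device:", device)
```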
@@ -100,10 +100,11 @@ class Function(metaclass=ABCMeta):
         Users can call :meth:`~.function.Function.save_for_backward` in this method to save tensors.

         :param input: Input tensors.
+        :return: A tuple of Tensor or a single Tensor.

         .. note::

-            This method should return a tuple of Tensor representing the output
-            of the function.
+            This method should return a tuple of Tensor or a single Tensor representing the output
+            of the function.
         """
         raise NotImplementedError
@@ -113,7 +114,7 @@ class Function(metaclass=ABCMeta):
         self, *output_grads: Iterable[Union[Tensor, None]]
     ) -> Union[Tuple[Tensor], Tensor]:
         """
-        Compute the gradient of the function. It must be overriden by all subclasses.
+        Compute the gradient of the forward function. It must be overridden by all subclasses.

         :param output_grads: gradients of outputs that are returned by :meth:`~.function.Function.forward`
@@ -124,17 +125,17 @@ class Function(metaclass=ABCMeta):
         .. note::

-            This method should return a tuple containing gradients of all
-            inputs, in the same order as the ``inputs`` argument of :meth:`~.function.Function.forward` . A
-            ``Tensor`` could be returned instead if there is only one input.
-            If users want to stop the propagation of some gradients, the corresponding returned values should be ``None`` .
+            This method should return a tuple containing the gradients of all inputs, in the same order
+            as the ``inputs`` argument of :meth:`~.function.Function.forward` . A ``Tensor`` could be returned
+            instead if there is only one input. If users want to stop the propagation of some gradients,
+            the corresponding returned values should be set to ``None`` .
         """
         raise NotImplementedError

     def save_for_backward(self, *tensors: Iterable[Tensor]):
         """
-        Saves tensors for gradient computation. This method should be called only
+        Saves tensors needed for gradient computation. This method should be called only
         once in :meth:`~.function.Function.forward`, additional calls will replace values saved previously.

         The saved tensors can be accessed through the ``saved_tensors`` attribute.
...
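A minimal sketch of how `forward`, `backward`, and `save_for_backward` fit together under the rules documented above. The import paths and the `Square` example are illustrative assumptions, not part of the commit:

```python
from megengine import tensor            # assumed top-level constructor
from megengine.core import Function     # import path is an assumption

class Square(Function):
    """y = x * x with a hand-written gradient."""

    def forward(self, x):
        # Only one save_for_backward call: a second call would replace
        # the previously saved tensors.
        self.save_for_backward(x)
        return x * x  # a single Tensor is an acceptable return value

    def backward(self, grad_y):
        # One gradient per forward input, in the same order; returning
        # None here would stop propagation through x.
        (x,) = self.saved_tensors
        return 2 * x * grad_y

y = Square()(tensor([3.0]))  # calling the instance invokes forward
```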
@@ -36,7 +36,7 @@ _default_graph = _DefaultGraph()
 class Graph(mgb.CompGraph):
     r"""
-    A ``comp_graph`` class supporting context management.
+    A computing graph that supports context management.

     :param check_env_var: whether to check environment vars including ``MGB_COMP_GRAPH_OPT``.
     :param eager_evaluation: use dynamic graph(``True``) or static graph(``False``).
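A sketch of the context management the new docstring describes; the exact import path (`megengine.core.graph.Graph`) and the behavior of `__enter__` returning the graph are assumptions:

```python
from megengine.core.graph import Graph  # import path is an assumption

with Graph(eager_evaluation=False) as cg:  # static graph
    # tensors and operators created in this block are attached to `cg`
    # rather than to the default computing graph
    ...
```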
@@ -97,7 +97,7 @@ def _use_default_if_none(device, comp_graph):
 def dump(outputs, fpath, optimize_options=None, **kwargs):
     r"""
-    Serializes this computing graph and writes result to a file.
+    Serializes this computing graph and writes it to a file.

     :type outputs: ``Tensor`` or a collection of ``Tensor``
     :param outputs: output variables that need to be retrieved when
@@ -105,7 +105,7 @@ def dump(outputs, fpath, optimize_options=None, **kwargs):
     :type fpath: ``str``
     :param fpath: path for the output file
     :type optimize_options: ``list``
     :param optimize_options: ``['f16_io_f32_comp', 'f16_io_comp', 'use_nhwcd4', 'fuse_conv_bias_nonlinearity']`` , four elements are optional, it can be an empty list, None or a list containing any of them.

     .. note::
@@ -115,7 +115,7 @@ def dump(outputs, fpath, optimize_options=None, **kwargs):
     ``use_nhwcd4`` – whether to use NHWCD4 data format. This is faster on some OpenCL devices;
-    ``fuse_conv_bias_nonlinearity`` – whether to fuse conv+bias+nonlinearty into one opr. This is supported only in NHWCD4 format.
+    ``fuse_conv_bias_nonlinearity`` – whether to fuse conv+bias+nonlinearity into one opr. This is supported only when ``use_nhwcd4`` is set.
     """
     from .tensor import Tensor
...
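An illustrative call matching the `optimize_options` semantics above. The import path, the `net_outputs` placeholder, and the file name are assumptions; note that per the docstring, `fuse_conv_bias_nonlinearity` only takes effect together with `use_nhwcd4`:

```python
from megengine.core.graph import dump  # import path is an assumption

net_outputs = ...  # placeholder: Tensor(s) produced by a traced network

dump(
    net_outputs,                 # Tensor or collection of Tensors to retrieve
    "model.mgb",                 # path of the serialized file
    optimize_options=["use_nhwcd4", "fuse_conv_bias_nonlinearity"],
)
```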
@@ -139,7 +139,7 @@ class Tensor:
         return tensor(data=obj, device=self.device)

     def numpy(self):
-        r"""Return the tensor value in ndarray format.
+        r"""Return the tensor value in numpy.ndarray format.
         """
         if self.__val is not None:
             assert self.__sym is None
@@ -235,6 +235,8 @@ class Tensor:
         self.__val.reset_zero()

     def to(self, device):
+        r"""Performs Tensor device conversion and returns a Tensor on the specified device.
+        """
         return wrap_io_tensor(mgb.opr.copy)(self, comp_node=device)

     # https://docs.python.org/3/reference/datamodel.html#object.__hash__
...
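A quick sketch combining the two documented methods, assuming `megengine.tensor` creates a `Tensor` on the default device and "cpu0" is a valid comp node on the machine:

```python
from megengine import tensor  # assumed top-level constructor

x = tensor([1.0, 2.0, 3.0])
y = x.to("cpu0")   # device conversion: a copy of x on comp node "cpu0"
arr = y.numpy()    # plain numpy.ndarray holding the tensor value
```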
@@ -22,6 +22,9 @@ def scalar(
     device: Optional[mgb.CompNode] = None,
     comp_graph: Optional[mgb.CompGraph] = None,
 ) -> Tensor:
+    """
+    Convert ``value`` to a :class:`~.Tensor`.
+    """
     device, comp_graph = _use_default_if_none(device, comp_graph)
     return Tensor(mgb.make_immutable(device, comp_graph, value, dtype=dtype, name=None))
...
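A hedged sketch of the newly documented `scalar` helper; the module housing it is not named in the diff, so this import path is a guess, and the default device and graph come from `_use_default_if_none` as shown in the hunk:

```python
from megengine.core.tensor_factory import scalar  # import path is a guess

half = scalar(0.5, dtype="float32")  # immutable Tensor on the default device/graph
```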