Commit f71251fd authored by Megvii Engine Team, committed by Xinran Xu

docs(mge/functional): refine the docstring of several apis

GitOrigin-RevId: 55f61be2a8e215cb1bb9c0f3938408c4e189aa52
Parent 0ba52ad7
......@@ -24,17 +24,19 @@ def grad(
use_virtual_grad: bool = None,
return_zero_for_nodep: bool = True,
) -> Union[Tensor, Iterable[Optional[Tensor]], None]:
r"""compute symbolic grad
r"""Compute the symbolic gradient of ``target`` with repect to ``wrt``.
:param target: grad target var
:param wrt: with respect to which to compute the grad
``wrt`` can either be a single tensor or a sequence of tensors.
:param target: ``grad`` target tensor
:param wrt: with respect to which to compute the gradient
:param warn_mid_wrt: whether to give warning if ``wrt`` is not endpoint
:param use_virtual_grad: whether to use virtual grad opr, so fwd graph can
be optimized before applying grad; if ``None`` is given, then virtual
grad would be used if ``graph_opt_level >= 2``
:param use_virtual_grad: whether to use virtual ``grad`` opr, so fwd graph can
be optimized before applying ``grad``; if ``None`` is given, then virtual
``grad`` would be used if ``graph_opt_level >= 2``
:param return_zero_for_nodep: if ``target`` does not depend on ``wrt``, set to True to return
a zero-valued :class:`~.Tensor` rather than ``None``; can't be set to False when using
virtual grad opr.
virtual ``grad`` opr.
:return: :math:`\partial\text{target} / \partial\text{wrt}`
"""
if not isinstance(wrt, mgb.SymbolVar):
......
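An illustrative sketch of how ``grad`` might be called (not part of the diff); the ``megengine.functional`` import alias and the eager-looking tensor construction are assumptions based on this era of the API:

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    x = mge.tensor(np.ones((3,), dtype="float32"))
    y = F.sum(x * x)       # scalar target that depends on x
    dx = F.grad(y, x)      # d(y)/d(x); expected to equal 2 * x
    # With return_zero_for_nodep=True (the default), a target that does not
    # depend on x would yield a zero-valued tensor instead of None.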
......@@ -48,12 +48,12 @@ def sum(inp: Tensor, axis: Optional[int] = None, keepdims: bool = False) -> Tens
@wrap_io_tensor
def prod(inp: Tensor, axis: Optional[int] = None, keepdims=False) -> Tensor:
r"""
Returns prod of input tensor along given *axis*.
Returns the product of the elements of the input tensor along the given *axis*.
:param inp: The input tensor
:param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: ``None``
:param keepdims: Whether the output tensor has *axis* retained or not. Default: ``False``
:return: The output tensor
:param inp: The input tensor
:param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: ``None``
:param keepdims: Whether the output tensor has *axis* retained or not. Default: ``False``
:return: The output tensor
Examples:
......
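A small usage sketch for ``prod`` (editor's illustration, not taken from the diff), assuming ``megengine.functional`` is importable as ``F`` and ``megengine.tensor`` accepts numpy arrays:

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    x = mge.tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype="float32"))
    F.prod(x)                          # product over all elements -> 720
    F.prod(x, axis=1)                  # per-row products -> [6., 120.]
    F.prod(x, axis=1, keepdims=True)   # same values, but with shape (2, 1)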
......@@ -27,6 +27,11 @@ def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor
"""Applies a linear transformation to the input.
Refer to :class:`~.Linear` for more information.
:param inp: the input tensor with shape `(N, in_features)`.
:param weight: the weight with shape `(out_features, in_features)`.
:param bias: the bias with shape `(out_features,)`.
Default: ``None``
"""
orig_shape = inp.shape
inp = inp.reshape(-1, orig_shape[-1])
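A hedged sketch of a ``linear`` call using the shapes documented above; the import alias ``F`` and the tensor construction are assumptions:

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    inp = mge.tensor(np.random.rand(4, 3).astype("float32"))     # (N, in_features)
    weight = mge.tensor(np.random.rand(5, 3).astype("float32"))  # (out_features, in_features)
    bias = mge.tensor(np.zeros((5,), dtype="float32"))           # (out_features,)
    out = F.linear(inp, weight, bias)                            # expected shape: (4, 5)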
......@@ -51,6 +56,8 @@ def conv2d(
) -> Tensor:
"""2D convolution operation.
Refer to :class:`~.Conv2d` for more information.
:param inp: The feature map of the convolution operation
:param weight: The convolution kernel
:param bias: The bias added to the result of convolution (if given)
......@@ -73,7 +80,6 @@ def conv2d(
Float32 would be used for accumulator and intermediate result, but only
effective when input and output are of Float16 dtype.
Refer to :class:`~.Conv2d` for more information.
"""
ph, pw = _pair(padding)
sh, sw = _pair_nonzero(stride)
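A sketch of a plain ``conv2d`` call; the NCHW input layout, the (C_out, C_in, KH, KW) weight layout, and the default stride/padding are assumptions based on common convention rather than the diff:

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    inp = mge.tensor(np.random.rand(2, 3, 8, 8).astype("float32"))     # (N, C_in, H, W)
    weight = mge.tensor(np.random.rand(4, 3, 3, 3).astype("float32"))  # (C_out, C_in, KH, KW)
    out = F.conv2d(inp, weight)  # bias omitted; expected shape (2, 4, 6, 6) with stride 1, no padding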
......@@ -114,6 +120,8 @@ def conv_transpose2d(
) -> Tensor:
"""2D transposed convolution operation.
Refer to :class:`~.ConvTranspose2d` for more information.
:param inp: The feature map of the convolution operation
:param weight: The convolution kernel
:param bias: The bias added to the result of convolution (if given)
......@@ -136,7 +144,6 @@ def conv_transpose2d(
Float32 would be used for accumulator and intermediate result, but only
effective when input and output are of Float16 dtype.
Refer to :class:`~.ConvTranspose2d` for more information.
"""
ph, pw = _pair(padding)
sh, sw = _pair_nonzero(stride)
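A sketch of ``conv_transpose2d`` under similar assumptions; a square kernel with equal input and output channels is used so the exact weight layout does not matter for the illustration:

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    inp = mge.tensor(np.random.rand(1, 3, 8, 8).astype("float32"))
    weight = mge.tensor(np.random.rand(3, 3, 3, 3).astype("float32"))  # 3 -> 3 channels, 3x3 kernel
    out = F.conv_transpose2d(inp, weight)  # expected shape (1, 3, 10, 10) with stride 1, no padding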
......@@ -172,13 +179,14 @@ def max_pool2d(
) -> Tensor:
"""Applies a 2D max pooling over an input.
Refer to :class:`~.MaxPool2d` for more information.
:param inp: The input tensor.
:param kernel_size: The size of the window.
:param stride: The stride of the window. If not provided, its value is set to ``kernel_size``.
Default: None
:param padding: Implicit zero padding to be added on both sides. Default: 0
Refer to :class:`~.MaxPool2d` for more information.
"""
kh, kw = _pair_nonzero(kernel_size)
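A sketch of ``max_pool2d``, relying on the documented default that ``stride`` falls back to ``kernel_size``; the NCHW input and import alias are assumptions:

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    inp = mge.tensor(np.random.rand(1, 3, 8, 8).astype("float32"))
    out = F.max_pool2d(inp, kernel_size=2)  # stride defaults to kernel_size -> expected shape (1, 3, 4, 4)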
......@@ -207,13 +215,14 @@ def avg_pool2d(
) -> Tensor:
""" Applies a 2D average pooling over an input.
Refer to :class:`~.AvgPool2d` for more information.
:param inp: The input tensor.
:param kernel_size: The size of the window.
:param stride: The stride of the window. If not provided, its value is set to ``kernel_size``.
Default: None
:param padding: Implicit zero padding to be added on both sides. Default: 0
Refer to :class:`~.AvgPool2d` for more information.
"""
kh, kw = _pair_nonzero(kernel_size)
sh, sw = _pair_nonzero(stride or kernel_size)
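The same kind of sketch for ``avg_pool2d``, again with an assumed NCHW input:

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    inp = mge.tensor(np.random.rand(1, 3, 8, 8).astype("float32"))
    out = F.avg_pool2d(inp, kernel_size=2, padding=0)  # expected shape (1, 3, 4, 4)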
......@@ -343,6 +352,8 @@ def batch_norm2d(
) -> Tensor:
"""Applies batch normalization to the input.
Refer to :class:`~.BatchNorm2d` and :class:`~.BatchNorm1d` for more information.
:param inp: input tensor.
:param running_mean: tensor to store running mean.
:param running_var: tensor to store running variance.
......@@ -358,7 +369,6 @@ def batch_norm2d(
:param eps: a value added to the denominator for numerical stability.
Default: 1e-5.
Refer to :class:`~.BatchNorm2d` and :class:`~.BatchNorm1d` for more information.
"""
inp = mgb.opr.mark_no_broadcast_elemwise(inp)
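An inference-style sketch of ``batch_norm2d``; the (1, C, 1, 1) shape of the running statistics and the omission of the affine weight/bias arguments are assumptions, not something the diff specifies:

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    inp = mge.tensor(np.random.rand(2, 3, 4, 4).astype("float32"))
    running_mean = mge.tensor(np.zeros((1, 3, 1, 1), dtype="float32"))  # assumed per-channel layout
    running_var = mge.tensor(np.ones((1, 3, 1, 1), dtype="float32"))
    out = F.batch_norm2d(inp, running_mean, running_var, eps=1e-5)  # normalize with stored statistics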
......@@ -539,10 +549,10 @@ def eye(
Fills the 2-dimensional input :class:`SymbolVar` with the identity matrix.
:param n: The number of rows
:param m: The number of columns, default to None
:param dtype: The data type, default to None
:param device: Compute node of the matrix, defaults to None
:param comp_graph: Compute graph of the matrix, defaults to None
:param m: The number of columns. Default: None
:param dtype: The data type. Default: None
:param device: Compute node of the matrix. Default: None
:param comp_graph: Compute graph of the matrix. Default: None
:return: The eye matrix
Examples:
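(The original example is elided by the hunk boundary; the following is an editor's sketch of the signature above, with the import alias ``F`` assumed.)

    import megengine.functional as F

    m = F.eye(3, dtype="float32")  # 3x3 identity; pass m for a non-square (n x m) matrix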
......@@ -669,9 +679,7 @@ def interpolate(
:param size: size of the output tensor. Default: ``None``
:param scale_factor: scaling factor of the output tensor. Default: ``None``
:param mode: interpolation methods, acceptable values are:
'bilinear'(default), 'linear', 'nearest' (todo), 'cubic' (todo), 'area' (todo)
'BILINEAR', 'LINEAR'. Default: ``BILINEAR``
Examples:
......@@ -701,7 +709,7 @@ def interpolate(
"""
mode = mode.upper()
if mode not in ["BILINEAR", "LINEAR"]:
raise ValueError("interpolate only support bilinear mode")
raise ValueError("interpolate only support linear or bilinear mode")
if mode not in ["BILINEAR", "LINEAR"]:
if align_corners is not None:
raise ValueError(
......
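A sketch of an upsampling call with the bilinear mode listed above; the 4-D NCHW input and the ``scale_factor`` usage are assumptions:

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    x = mge.tensor(np.random.rand(1, 3, 4, 4).astype("float32"))
    y = F.interpolate(x, scale_factor=2.0, mode="BILINEAR")  # expected shape (1, 3, 8, 8)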
......@@ -179,10 +179,9 @@ def concat(
Concat some tensors
:param inps: Input tensors to concat
:param axis: the dimension over which the tensors are concatenated,
default to 0
:param device: The comp node output on, default to None
:param comp_graph: The graph in which output is, default to None
:param axis: the dimension over which the tensors are concatenated. Default: 0
:param device: The comp node output on. Default: None
:param comp_graph: The graph in which output is. Default: None
:return: The output tensor
Examples:
......
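A small sketch of ``concat`` along the documented default axis; tensor construction and import alias are assumptions:

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    a = mge.tensor(np.arange(6, dtype="float32").reshape(2, 3))
    b = mge.tensor(np.arange(6, 12, dtype="float32").reshape(2, 3))
    out = F.concat([a, b], axis=0)  # expected shape (4, 3); axis=1 would give (2, 6)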
......@@ -23,10 +23,10 @@ def _decide_comp_node_and_comp_graph(*args: mgb.SymbolVar):
return _use_default_if_none(None, None)
def accuracy(logits: Tensor, target: Tensor, topk: Union[int, Iterable[int]] = 1):
r"""
Classification accuracy given model predictions and ground-truth labels,
result between 0. to 1.
def accuracy(
logits: Tensor, target: Tensor, topk: Union[int, Iterable[int]] = 1
) -> Union[Tensor, Iterable[Tensor]]:
r"""Calculate the classification accuracy given predicted logits and ground-truth labels.
:param logits: Model predictions of shape [batch_size, num_classes],
representing the probability (likelihood) of each class.
......
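A sketch of ``accuracy``, assuming integer class labels for ``target`` and the import alias ``F``; per the annotation above, passing several ``topk`` values would return one tensor per value:

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    logits = mge.tensor(np.array([[0.1, 0.9], [0.8, 0.2]], dtype="float32"))  # (batch_size, num_classes)
    target = mge.tensor(np.array([1, 0], dtype="int32"))
    acc = F.accuracy(logits, target, topk=1)  # both predictions correct -> 1.0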