提交 22286579 编写于 作者: M Megvii Engine Team

docs(mge/functional): enhance functional statstic func docstring

GitOrigin-RevId: 1045aecf15801d8d7c22bdd691f13038c7d1c2ce
上级 213f4043
...@@ -120,24 +120,58 @@ def sum( ...@@ -120,24 +120,58 @@ def sum(
axis: Optional[Union[int, Sequence[int]]] = None, axis: Optional[Union[int, Sequence[int]]] = None,
keepdims: bool = False, keepdims: bool = False,
) -> Tensor: ) -> Tensor:
r"""Returns the sum of input tensor along given axis. If axis is a list of dimensions, r"""Calculates the sum of tensor elements over a given axis (or axes).
reduce over all of them.
Args: Args:
inp: input tensor. inp: input tensor. Should have a numeric data type.
axis: dimension to reduce. If None, all dimensions will be reduced. axis: axis or axes along which sums must be computed.
Default: None By default, the sum must be computed over the entire tensor.
keepdims: whether the output tensor has axis retained or not. If a sequence of integers, sums must be computed over multiple axes.
Default: False keepdims: if ``True``, the reduced axes (dimensions) must be included in the result as singleton dimensions,
and, accordingly, the result must be compatible with the input tensor (see :ref:`broadcasting-rule`).
Otherwise, if ``False``, the reduced axes (dimensions) must not be included in the result.
Returns: Returns:
output tensor. if the sum was computed over the entire tensor, a zero-dimensional tensor containing the sum;
otherwise, a tensor containing the sums.
The returned tensor must have a data type determined by :ref:`dtype-promotion`.
Examples: .. admonition:: Special Cases
>>> import numpy as np
>>> x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3)) Let ``N`` equal the number of elements over which to compute the sum.
* If ``N`` is 0, the sum is ``0`` (i.e., the empty sum).
* If :math:`x_i` is ``NaN``, the sum is ``NaN`` (i.e., ``NaN`` values propagate).
.. warning::
If the accumulator is too small, overflow occurs:
>>> x = F.ones(128, dtype="int8")
>>> F.sum(x) >>> F.sum(x)
Tensor(21, dtype=int32, device=xpux:0) Tensor(-128, dtype=int8, device=xpux:0)
Examples:
The sum of an empty tensor is the neutral element 0:
>>> F.sum(Tensor([]))
Tensor(0.0, device=xpux:0)
Normal case:
>>> F.sum(Tensor([1, 2, 3]))
Tensor(6, dtype=int32, device=xpux:0)
>>> F.sum(Tensor([0.5, 1.5]))
Tensor(2.0, device=xpux:0)
Along an axis:
>>> F.sum(Tensor([[1, 2, 3], [4, 5, 6]]), axis=0)
Tensor([5 7 9], dtype=int32, device=xpux:0)
>>> F.sum(Tensor([[1, 2, 3], [4, 5, 6]]), axis=1)
Tensor([ 6 15], dtype=int32, device=xpux:0)
""" """
return inp.sum(axis=axis, keepdims=keepdims) return inp.sum(axis=axis, keepdims=keepdims)
...@@ -145,22 +179,58 @@ def sum( ...@@ -145,22 +179,58 @@ def sum(
def prod( def prod(
inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None, keepdims=False inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None, keepdims=False
) -> Tensor: ) -> Tensor:
r"""Returns the product of input tensor along given axis. If axis is a list of dimensions, r"""Calculates the product of tensor elements over a given axis (or axes).
reduce over all of them.
Args: Args:
inp: input tensor. inp: input tensor. Should have a numeric data type.
axis: dimension to reduce. If None, all dimensions will be reduced. Default: None axis: axis or axes along which products must be computed.
keepdims: whether the output tensor has axis retained or not. Default: False By default, the product must be computed over the entire tensor.
If a sequence of integers, products must be computed over multiple axes.
keepdims: if ``True``, the reduced axes (dimensions) must be included in the result as singleton dimensions,
and, accordingly, the result must be compatible with the input tensor (see :ref:`broadcasting-rule`).
Otherwise, if ``False``, the reduced axes (dimensions) must not be included in the result.
Returns: Returns:
output tensor. if the product was computed over the entire tensor, a zero-dimensional tensor containing the products;
otherwise, a non-zero-dimensional tensor containing the products.
The returned tensor must have a data type determined by :ref:`dtype-promotion`.
Examples: .. admonition:: Special Cases
>>> import numpy as np
>>> x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3)) Let ``N`` equal the number of elements over which to compute the product.
* If ``N`` is 0, the product is ``1`` (i.e., the empty product).
* If :math:`x_i` is ``NaN``, the product is ``NaN`` (i.e., ``NaN`` values propagate).
.. warning::
Arithmetic is modular when using integer types, and no error is raised on overflow:
>>> x = Tensor([536870910, 536870910, 536870910, 536870910])
>>> F.prod(x) >>> F.prod(x)
Tensor(720, dtype=int32, device=xpux:0) Tensor(16, dtype=int32, device=xpux:0)
Examples:
The product of an empty tensor is the neutral element 1:
>>> F.prod(Tensor([]))
Tensor(1.0, device=xpux:0)
Normal case:
>>> F.prod(Tensor([1, 2, 3]))
Tensor(6, dtype=int32, device=xpux:0)
>>> F.prod(Tensor([0.5, 1.5]))
Tensor(0.75, device=xpux:0)
Along an axis:
>>> F.prod(Tensor([[1, 2, 3], [4, 5, 6]]), axis=0)
Tensor([ 4 10 18], dtype=int32, device=xpux:0)
>>> F.prod(Tensor([[1, 2, 3], [4, 5, 6]]), axis=1)
Tensor([ 6 120], dtype=int32, device=xpux:0)
""" """
return inp.prod(axis=axis, keepdims=keepdims) return inp.prod(axis=axis, keepdims=keepdims)
...@@ -170,24 +240,44 @@ def mean( ...@@ -170,24 +240,44 @@ def mean(
axis: Optional[Union[int, Sequence[int]]] = None, axis: Optional[Union[int, Sequence[int]]] = None,
keepdims: bool = False, keepdims: bool = False,
) -> Tensor: ) -> Tensor:
r"""Returns the mean value of input tensor along r"""Calculates the mean of tensor elements over a given axis (or axes).
given axis. If axis is a list of dimensions,
reduce over all of them.
Args: Args:
inp: input tensor. inp: input tensor. Should have a numeric data type.
axis: dimension to reduce. If None, all dimensions will be reduced. Default: None axis: axis or axes along which means must be computed.
keepdims: whether the output tensor has axis retained or not. Default: False By default, the mean must be computed over the entire tensor.
If a sequence of integers, means must be computed over multiple axes.
keepdims: if ``True``, the reduced axes (dimensions) must be included in the result as singleton dimensions,
and, accordingly, the result must be compatible with the input tensor (see :ref:`broadcasting-rule`).
Otherwise, if ``False``, the reduced axes (dimensions) must not be included in the result.
Returns: Returns:
output tensor. if the mean was computed over the entire tensor, a zero-dimensional tensor containing the mean;
otherwise, a non-zero-dimensional tensor containing the means.
The returned tensor must have a data type determined by :ref:`dtype-promotion`.
.. admonition:: Special Cases
Let ``N`` equal the number of elements over which to compute the mean.
* If ``N`` is 0, the mean is ``NaN``.
* If :math:`x_i` is ``NaN``, the mean is ``NaN`` (i.e., ``NaN`` values propagate).
Examples: Examples:
>>> F.mean(Tensor([1, 2, 3]))
Tensor(2.0, device=xpux:0)
>>> import numpy as np >>> import numpy as np
>>> x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3)) >>> F.mean(Tensor([1, np.nan, 3]))
>>> out = F.mean(x) Tensor(nan, device=xpux:0)
>>> out.numpy()
array(3.5, dtype=float32) Along an axis:
>>> F.mean(Tensor([[1, 2, 3], [4, 5, 6]]), axis=0)
Tensor([2.5 3.5 4.5], device=xpux:0)
>>> F.mean(Tensor([[1, 2, 3], [4, 5, 6]]), axis=1)
Tensor([2. 5.], device=xpux:0)
""" """
return inp.mean(axis=axis, keepdims=keepdims) return inp.mean(axis=axis, keepdims=keepdims)
...@@ -197,24 +287,36 @@ def var( ...@@ -197,24 +287,36 @@ def var(
axis: Optional[Union[int, Sequence[int]]] = None, axis: Optional[Union[int, Sequence[int]]] = None,
keepdims: bool = False, keepdims: bool = False,
) -> Tensor: ) -> Tensor:
r"""Returns the variance value of input tensor along r"""Calculates the variance of tensor elements over a given axis (or axes).
given axis. If axis is a list of dimensions,
reduce over all of them.
Args: Args:
inp: input tensor. inp: input tensor. Should have a numeric data type.
axis: dimension to reduce. If None, all dimensions will be reduced. Default: None axis: axis or axes along which variances must be computed.
keepdims: whether the output tensor has axis retained or not. Default: False By default, the variance must be computed over the entire tensor.
If a sequence of integers, variances must be computed over multiple axes.
keepdims: if ``True``, the reduced axes (dimensions) must be included in the result as singleton dimensions,
and, accordingly, the result must be compatible with the input tensor (see :ref:`broadcasting-rule`).
Otherwise, if ``False``, the reduced axes (dimensions) must not be included in the result.
Returns: Returns:
output tensor. if the variance was computed over the entire tensor, a zero-dimensional tensor containing the variance;
otherwise, a non-zero-dimensional tensor containing the variances.
The returned tensor must have a data type determined by :ref:`dtype-promotion`.
.. note::
The variance is the average of the squared deviations from the mean,
i.e., ``var = mean(x)``, where ``x = abs(a - a.mean())**2``.
Examples: Examples:
>>> import numpy as np >>> x = Tensor([[1, 2], [3, 4]])
>>> data = Tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3)) >>> F.var(x)
>>> out = F.var(data) Tensor(1.25, device=xpux:0)
>>> out.numpy().round(decimals=4)
2.9167 >>> x = Tensor([[14, 8, 11, 10], [7, 9, 10, 11], [10, 15, 5, 10]])
>>> F.var(x)
Tensor(6.8333335, device=xpux:0)
""" """
if axis is None: if axis is None:
m = mean(inp, axis=axis, keepdims=False) m = mean(inp, axis=axis, keepdims=False)
...@@ -229,24 +331,35 @@ def std( ...@@ -229,24 +331,35 @@ def std(
axis: Optional[Union[int, Sequence[int]]] = None, axis: Optional[Union[int, Sequence[int]]] = None,
keepdims: bool = False, keepdims: bool = False,
) -> Tensor: ) -> Tensor:
r"""Returns the standard deviation of input tensor along r"""Calculates the standard deviation of tensor elements over a given axis (or axes).
given axis. If axis is a list of dimensions,
reduce over all of them.
Args: Args:
inp: input tensor. inp: input tensor. Should have a numeric data type.
axis: dimension to reduce. If None, all dimensions will be reduced. Default: None axis: axis or axes along which standard deviations must be computed.
keepdims: whether the output tensor has axis retained or not. Default: False By default, the standard deviation must be computed over the entire tensor.
If a sequence of integers, standard deviations must be computed over multiple axes.
keepdims: if ``True``, the reduced axes (dimensions) must be included in the result as singleton dimensions,
and, accordingly, the result must be compatible with the input tensor (see :ref:`broadcasting-rule`).
Otherwise, if ``False``, the reduced axes (dimensions) must not be included in the result.
Returns: Returns:
output tensor. if the standard deviation was computed over the entire tensor, a zero-dimensional tensor containing the standard deviation;
otherwise, a non-zero-dimensional tensor containing the standard deviations.
.. note::
The standard deviation is the square root of the average of the squared deviations from the mean,
i.e., ``std = sqrt(mean(x))``, where ``x = abs(a - a.mean())**2``.
Examples: Examples:
>>> import numpy as np >>> x = Tensor([[1, 2], [3, 4]])
>>> data = Tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3)) >>> F.std(x)
>>> out = F.std(data, axis=1) Tensor(1.118034, device=xpux:0)
>>> out.numpy().round(decimals=4)
array([0.8165, 0.8165], dtype=float32) >>> x = Tensor([[14, 8, 11, 10], [7, 9, 10, 11], [10, 15, 5, 10]])
>>> F.std(x)
Tensor(2.6140645, device=xpux:0)
""" """
return var(inp, axis=axis, keepdims=keepdims) ** 0.5 return var(inp, axis=axis, keepdims=keepdims) ** 0.5
...@@ -256,23 +369,38 @@ def min( ...@@ -256,23 +369,38 @@ def min(
axis: Optional[Union[int, Sequence[int]]] = None, axis: Optional[Union[int, Sequence[int]]] = None,
keepdims: bool = False, keepdims: bool = False,
) -> Tensor: ) -> Tensor:
r"""Returns the min value of input tensor along r"""Calculates the minimum of tensor elements over a given axis (or axes).
given axis. If axis is a list of dimensions,
reduce over all of them.
Args: Args:
inp: input tensor. inp: input tensor. Should have a numeric data type.
axis: dimension to reduce. If None, all dimensions will be reduced. Default: None axis: axis or axes along which minimums must be computed.
keepdims: whether the output tensor has axis retained or not. Default: False By default, the minimum must be computed over the entire tensor.
If a sequence of integers, minimums must be computed over multiple axes.
keepdims: if ``True``, the reduced axes (dimensions) must be included in the result as singleton dimensions,
and, accordingly, the result must be compatible with the input tensor (see :ref:`broadcasting-rule`).
Otherwise, if ``False``, the reduced axes (dimensions) must not be included in the result.
Returns: Returns:
output tensor. if the minimum was computed over the entire tensor, a zero-dimensional tensor containing the minimum;
otherwise, a non-zero-dimensional tensor containing the minimums.
.. admonition:: Special Cases
If :math:`x_i` is ``NaN``, the minimum is ``NaN`` (i.e., ``NaN`` values propagate).
Examples: Examples:
>>> import numpy as np
>>> x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3)) >>> x = Tensor([[1, 2], [3, 4]])
>>> F.min(x) >>> F.min(x)
Tensor(1, dtype=int32, device=xpux:0) Tensor(1, dtype=int32, device=xpux:0)
Along an axis:
>>> F.min(x, axis=0)
Tensor([1 2], dtype=int32, device=xpux:0)
>>> F.min(x, axis=1)
Tensor([1 3], dtype=int32, device=xpux:0)
""" """
return inp.min(axis=axis, keepdims=keepdims) return inp.min(axis=axis, keepdims=keepdims)
...@@ -282,61 +410,42 @@ def max( ...@@ -282,61 +410,42 @@ def max(
axis: Optional[Union[int, Sequence[int]]] = None, axis: Optional[Union[int, Sequence[int]]] = None,
keepdims: bool = False, keepdims: bool = False,
) -> Tensor: ) -> Tensor:
r"""Returns the max value of the input tensor along r"""Calculates the maximum of tensor elements over a given axis (or axes).
given axis. If axis is a list of dimensions,
reduce over all of them.
Args: Args:
inp: input tensor. inp: input tensor. Should have a numeric data type.
axis: dimension to reduce. If None, all dimensions will be reduced. Default: None axis: axis or axes along which maximums must be computed.
keepdims: whether the output tensor has axis retained or not. Default: False By default, the maximum must be computed over the entire tensor.
If a sequence of integers, maximums must be computed over multiple axes.
keepdims: if ``True``, the reduced axes (dimensions) must be included in the result as singleton dimensions,
and, accordingly, the result must be compatible with the input tensor (see :ref:`broadcasting-rule`).
Otherwise, if ``False``, the reduced axes (dimensions) must not be included in the result.
Returns: Returns:
output tensor. if the maximum was computed over the entire tensor, a zero-dimensional tensor containing the maximum;
otherwise, a non-zero-dimensional tensor containing the maximums.
Examples: .. admonition:: Special Cases
>>> import numpy as np
>>> x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
>>> F.max(x)
Tensor(6, dtype=int32, device=xpux:0)
"""
return inp.max(axis=axis, keepdims=keepdims)
If :math:`x_i` is ``NaN``, the maximum is ``NaN`` (i.e., ``NaN`` values propagate).
def norm( Examples:
inp: Tensor, ord: float = None, axis: int = None, keepdims=False,
):
r"""Calculates ``p``-norm of input tensor along
given axis.
Args: >>> x = Tensor([[1, 2], [3, 4]])
inp: input tensor. >>> F.max(x)
ord: power of value applied to inp. Default: 2 Tensor(4, dtype=int32, device=xpux:0)
axis: dimension to reduce. If None, input must be a vector. Default: None
keepdims: whether the output tensor has axis retained or not. Default: False
Returns: Along an axis:
output tensor.
Examples: >>> F.max(x, axis=0)
>>> import numpy as np Tensor([3 4], dtype=int32, device=xpux:0)
>>> x = Tensor(np.arange(-3, 3, dtype=np.float32)) >>> F.max(x, axis=1)
>>> out = F.norm(x) Tensor([2 4], dtype=int32, device=xpux:0)
>>> out.numpy().round(decimals=4)
4.3589
""" """
if axis is None: return inp.max(axis=axis, keepdims=keepdims)
if inp.ndim != 1:
raise TypeError("axis is required unless input is a vector")
if ord is None: # searching functions
ord = 2
if ord == 0:
return sum(inp != 0, axis=axis, keepdims=keepdims)
if ord == math.inf:
return max(abs(inp))
if ord == -math.inf:
return min(abs(inp))
return sum(abs(inp) ** ord, axis=axis, keepdims=keepdims) ** (1.0 / ord)
def argmin( def argmin(
...@@ -433,31 +542,7 @@ def argmax( ...@@ -433,31 +542,7 @@ def argmax(
return result return result
def normalize( # sorting functions
inp: Tensor, ord: float = None, axis: int = None, eps: float = 1e-12,
) -> Tensor:
r"""Performs :math:`L_p` normalization of input tensor along
given axis.
For a tensor of shape :math:`(n_0, ..., n_{dim}, ..., n_k)`, each
:math:`n_{dim}` -element vector :math:`v` along dimension :attr:`axis` is transformed as:
.. math::
v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}.
Args:
inp: input tensor.
ord: power of value applied to input tensor. Default: 2
axis: dimension to reduce.If None, input must be a vector. Default: None
eps: a small value to avoid division by zero. Default: 1e-12
Returns:
normalized output tensor.
"""
if axis is None:
return inp / clip(norm(inp, ord, axis), lower=eps)
else:
return inp / clip(norm(inp, ord, axis, keepdims=True), lower=eps)
def argsort(inp: Tensor, descending: bool = False) -> Tensor: def argsort(inp: Tensor, descending: bool = False) -> Tensor:
...@@ -589,6 +674,9 @@ def topk( ...@@ -589,6 +674,9 @@ def topk(
return tns, ind return tns, ind
# linear algebra functions
def matinv(inp: Tensor) -> Tensor: def matinv(inp: Tensor) -> Tensor:
r"""Computes the inverse of a batch of matrices; input must has shape [..., n, n]. r"""Computes the inverse of a batch of matrices; input must has shape [..., n, n].
...@@ -725,6 +813,109 @@ def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor: ...@@ -725,6 +813,109 @@ def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor:
return U, S, Vh return U, S, Vh
def norm(
inp: Tensor, ord: float = None, axis: int = None, keepdims=False,
):
r"""Calculates the norm of tensor elements over a given axis.
This function is able to return different matrix norms,
or one of an infinite number of vector norms (described below), depending on the value of the ord parameter.
Args:
inp: input tensor. Should have a numeric data type.
ord: Order of the norm (see table under Notes). If not specified, the default is 2.
axis: Axis along which to compute vector norms.
If axis is an integer, it specifies the axis of inp along which to compute the vector norms.
keepdims: If this is set to ``True``,
the axes which are normed over are left in the result as dimensions with size one.
Returns:
Norm of the matrix or vector(s).
.. note::
Now the following norms can be calculated:
* inf: norm-:math:`\infty` (maximum of absolute values).
* -inf: norm-:math:`-\infty` (minimum of absolute values).
* 2: 2-norm (largest singluar value).
The Frobenius norm is given by to ``sum(abs(x)**ord)**(1./ord)``:
.. math::
\|A\|_F=\left[\sum_{i, j} a b s\left(a_{i, j}\right)^2\right]^{1 / 2}
.. seealso:: :func:`numpy.linalg.norm` / :func:`~.functional.normalize`
Examples:
>>> import math
>>> x = Tensor([1, 2, 3])
>>> F.norm(x, ord=math.inf)
Tensor(3, dtype=int32, device=xpux:0)
>>> F.norm(x, ord=-math.inf)
Tensor(1, dtype=int32, device=xpux:0)
>>> x = Tensor([[1, 2, 3], [4, 5, 6]])
>>> F.norm(x, ord=2, axis=0)
Tensor([4.1231 5.3852 6.7082], device=xpux:0)
>>> F.norm(x, ord=2, axis=1)
Tensor([3.7417 8.775 ], device=xpux:0)
"""
if axis is None:
if inp.ndim != 1:
raise TypeError("axis is required unless input is a vector")
if ord is None:
ord = 2
if ord == 0:
return sum(inp != 0, axis=axis, keepdims=keepdims)
if ord == math.inf:
return max(abs(inp))
if ord == -math.inf:
return min(abs(inp))
return sum(abs(inp) ** ord, axis=axis, keepdims=keepdims) ** (1.0 / ord)
def normalize(
inp: Tensor, ord: float = None, axis: int = None, eps: float = 1e-12,
) -> Tensor:
r"""Performs :math:`L_p` normalization of input tensor along given axis.
For a tensor of shape :math:`(n_0, ..., n_{dim}, ..., n_k)`,
each :math:`n_{dim}` -element vector :math:`v` along dimension :attr:`axis` is transformed as:
.. math::
v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}.
Args:
inp: input tensor.
ord: power of value applied to input tensor.
axis: dimension to reduce.If None, input must be a vector.
eps: a small value to avoid division by zero.
Returns:
normalized output tensor.
seealso:: :func:`numpy.linalg.norm` / :func:`~.functional.norm`
Examples:
>>> x = Tensor([[1, 2, 3], [4, 5, 6]])
>>> F.normalize(x, ord=2, axis=0)
Tensor([[0.2425 0.3714 0.4472]
[0.9701 0.9285 0.8944]], device=xpux:0)
>>> F.normalize(x, ord=2, axis=1)
Tensor([[0.2673 0.5345 0.8018]
[0.4558 0.5698 0.6838]], device=xpux:0)
"""
if axis is None:
return inp / clip(norm(inp, ord, axis), lower=eps)
else:
return inp / clip(norm(inp, ord, axis, keepdims=True), lower=eps)
def _check_non_finite(inps: Iterable[Tensor], scale=1.0) -> Tensor: def _check_non_finite(inps: Iterable[Tensor], scale=1.0) -> Tensor:
r"""Check whether input contains infinite or nan value. r"""Check whether input contains infinite or nan value.
......
...@@ -1126,17 +1126,28 @@ def roll( ...@@ -1126,17 +1126,28 @@ def roll(
def cumsum(inp: Tensor, axis: int): def cumsum(inp: Tensor, axis: int):
r"""Computes the cumulative sum of elements along given axis. r"""Calculates the cumulative sum of tensor elements over a given axis.
Args: Args:
inp: input tensor. inp: input tensor. Should have a numeric data type.
axis: axis along which cumsum is performed. axis: axis along which cumulative sums must be computed.
Returns:
a tensor containing the cumulative sums.
Examples: Examples:
>>> x = Tensor([[1, 2, 3], [4, 5, 6]], "int32")
>>> F.cumsum(x, 1) If :math:`x_i` is ``NaN``, the cumulative sums is ``NaN`` (i.e., ``NaN`` values propagate).
Examples:
>>> x = Tensor([[1, 2, 3], [4, 5, 6]])
>>> F.cumsum(x, axis = 0)
Tensor([[1 2 3]
[5 7 9]], dtype=int32, device=xpux:0)
>>> F.cumsum(x, axis = 1)
Tensor([[ 1 3 6] Tensor([[ 1 3 6]
[ 4 9 15]], dtype=int32, device=xpux:0) [ 4 9 15]], dtype=int32, device=xpux:0)
""" """
assert isinstance(inp, Tensor), "input of cumsum must be type of Tensor" assert isinstance(inp, Tensor), "input of cumsum must be type of Tensor"
op = builtin.Cumsum(axis=axis, exclusive=False, reverse=False) op = builtin.Cumsum(axis=axis, exclusive=False, reverse=False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册