Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
22286579
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
22286579
编写于
9月 16, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
docs(mge/functional): enhance functional statstic func docstring
GitOrigin-RevId: 1045aecf15801d8d7c22bdd691f13038c7d1c2ce
上级
213f4043
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
343 addition
and
141 deletion
+343
-141
imperative/python/megengine/functional/math.py
imperative/python/megengine/functional/math.py
+327
-136
imperative/python/megengine/functional/tensor.py
imperative/python/megengine/functional/tensor.py
+16
-5
未找到文件。
imperative/python/megengine/functional/math.py
浏览文件 @
22286579
...
...
@@ -120,24 +120,58 @@ def sum(
axis
:
Optional
[
Union
[
int
,
Sequence
[
int
]]]
=
None
,
keepdims
:
bool
=
False
,
)
->
Tensor
:
r
"""Returns the sum of input tensor along given axis. If axis is a list of dimensions,
reduce over all of them.
r
"""Calculates the sum of tensor elements over a given axis (or axes).
Args:
inp: input tensor.
axis: dimension to reduce. If None, all dimensions will be reduced.
Default: None
keepdims: whether the output tensor has axis retained or not.
Default: False
inp: input tensor. Should have a numeric data type.
axis: axis or axes along which sums must be computed.
By default, the sum must be computed over the entire tensor.
If a sequence of integers, sums must be computed over multiple axes.
keepdims: if ``True``, the reduced axes (dimensions) must be included in the result as singleton dimensions,
and, accordingly, the result must be compatible with the input tensor (see :ref:`broadcasting-rule`).
Otherwise, if ``False``, the reduced axes (dimensions) must not be included in the result.
Returns:
output tensor.
if the sum was computed over the entire tensor, a zero-dimensional tensor containing the sum;
otherwise, a tensor containing the sums.
The returned tensor must have a data type determined by :ref:`dtype-promotion`.
.. admonition:: Special Cases
Let ``N`` equal the number of elements over which to compute the sum.
* If ``N`` is 0, the sum is ``0`` (i.e., the empty sum).
* If :math:`x_i` is ``NaN``, the sum is ``NaN`` (i.e., ``NaN`` values propagate).
.. warning::
If the accumulator is too small, overflow occurs:
>>> x = F.ones(128, dtype="int8")
>>> F.sum(x)
Tensor(-128, dtype=int8, device=xpux:0)
Examples:
>>> import numpy as np
>>> x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
>>> F.sum(x)
Tensor(21, dtype=int32, device=xpux:0)
The sum of an empty tensor is the neutral element 0:
>>> F.sum(Tensor([]))
Tensor(0.0, device=xpux:0)
Normal case:
>>> F.sum(Tensor([1, 2, 3]))
Tensor(6, dtype=int32, device=xpux:0)
>>> F.sum(Tensor([0.5, 1.5]))
Tensor(2.0, device=xpux:0)
Along an axis:
>>> F.sum(Tensor([[1, 2, 3], [4, 5, 6]]), axis=0)
Tensor([5 7 9], dtype=int32, device=xpux:0)
>>> F.sum(Tensor([[1, 2, 3], [4, 5, 6]]), axis=1)
Tensor([ 6 15], dtype=int32, device=xpux:0)
"""
return
inp
.
sum
(
axis
=
axis
,
keepdims
=
keepdims
)
...
...
@@ -145,22 +179,58 @@ def sum(
def
prod
(
inp
:
Tensor
,
axis
:
Optional
[
Union
[
int
,
Sequence
[
int
]]]
=
None
,
keepdims
=
False
)
->
Tensor
:
r
"""Returns the product of input tensor along given axis. If axis is a list of dimensions,
reduce over all of them.
r
"""Calculates the product of tensor elements over a given axis (or axes).
Args:
inp: input tensor.
axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
keepdims: whether the output tensor has axis retained or not. Default: False
inp: input tensor. Should have a numeric data type.
axis: axis or axes along which products must be computed.
By default, the product must be computed over the entire tensor.
If a sequence of integers, products must be computed over multiple axes.
keepdims: if ``True``, the reduced axes (dimensions) must be included in the result as singleton dimensions,
and, accordingly, the result must be compatible with the input tensor (see :ref:`broadcasting-rule`).
Otherwise, if ``False``, the reduced axes (dimensions) must not be included in the result.
Returns:
output tensor.
if the product was computed over the entire tensor, a zero-dimensional tensor containing the products;
otherwise, a non-zero-dimensional tensor containing the products.
The returned tensor must have a data type determined by :ref:`dtype-promotion`.
.. admonition:: Special Cases
Let ``N`` equal the number of elements over which to compute the product.
* If ``N`` is 0, the product is ``1`` (i.e., the empty product).
* If :math:`x_i` is ``NaN``, the product is ``NaN`` (i.e., ``NaN`` values propagate).
.. warning::
Arithmetic is modular when using integer types, and no error is raised on overflow:
>>> x = Tensor([536870910, 536870910, 536870910, 536870910])
>>> F.prod(x)
Tensor(16, dtype=int32, device=xpux:0)
Examples:
>>> import numpy as np
>>> x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
>>> F.prod(x)
Tensor(720, dtype=int32, device=xpux:0)
The product of an empty tensor is the neutral element 1:
>>> F.prod(Tensor([]))
Tensor(1.0, device=xpux:0)
Normal case:
>>> F.prod(Tensor([1, 2, 3]))
Tensor(6, dtype=int32, device=xpux:0)
>>> F.prod(Tensor([0.5, 1.5]))
Tensor(0.75, device=xpux:0)
Along an axis:
>>> F.prod(Tensor([[1, 2, 3], [4, 5, 6]]), axis=0)
Tensor([ 4 10 18], dtype=int32, device=xpux:0)
>>> F.prod(Tensor([[1, 2, 3], [4, 5, 6]]), axis=1)
Tensor([ 6 120], dtype=int32, device=xpux:0)
"""
return
inp
.
prod
(
axis
=
axis
,
keepdims
=
keepdims
)
...
...
@@ -170,24 +240,44 @@ def mean(
axis
:
Optional
[
Union
[
int
,
Sequence
[
int
]]]
=
None
,
keepdims
:
bool
=
False
,
)
->
Tensor
:
r
"""Returns the mean value of input tensor along
given axis. If axis is a list of dimensions,
reduce over all of them.
r
"""Calculates the mean of tensor elements over a given axis (or axes).
Args:
inp: input tensor.
axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
keepdims: whether the output tensor has axis retained or not. Default: False
inp: input tensor. Should have a numeric data type.
axis: axis or axes along which means must be computed.
By default, the mean must be computed over the entire tensor.
If a sequence of integers, means must be computed over multiple axes.
keepdims: if ``True``, the reduced axes (dimensions) must be included in the result as singleton dimensions,
and, accordingly, the result must be compatible with the input tensor (see :ref:`broadcasting-rule`).
Otherwise, if ``False``, the reduced axes (dimensions) must not be included in the result.
Returns:
output tensor.
if the mean was computed over the entire tensor, a zero-dimensional tensor containing the mean;
otherwise, a non-zero-dimensional tensor containing the means.
The returned tensor must have a data type determined by :ref:`dtype-promotion`.
.. admonition:: Special Cases
Let ``N`` equal the number of elements over which to compute the mean.
* If ``N`` is 0, the mean is ``NaN``.
* If :math:`x_i` is ``NaN``, the mean is ``NaN`` (i.e., ``NaN`` values propagate).
Examples:
>>> F.mean(Tensor([1, 2, 3]))
Tensor(2.0, device=xpux:0)
>>> import numpy as np
>>> x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
>>> out = F.mean(x)
>>> out.numpy()
array(3.5, dtype=float32)
>>> F.mean(Tensor([1, np.nan, 3]))
Tensor(nan, device=xpux:0)
Along an axis:
>>> F.mean(Tensor([[1, 2, 3], [4, 5, 6]]), axis=0)
Tensor([2.5 3.5 4.5], device=xpux:0)
>>> F.mean(Tensor([[1, 2, 3], [4, 5, 6]]), axis=1)
Tensor([2. 5.], device=xpux:0)
"""
return
inp
.
mean
(
axis
=
axis
,
keepdims
=
keepdims
)
...
...
@@ -197,24 +287,36 @@ def var(
axis
:
Optional
[
Union
[
int
,
Sequence
[
int
]]]
=
None
,
keepdims
:
bool
=
False
,
)
->
Tensor
:
r
"""Returns the variance value of input tensor along
given axis. If axis is a list of dimensions,
reduce over all of them.
r
"""Calculates the variance of tensor elements over a given axis (or axes).
Args:
inp: input tensor.
axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
keepdims: whether the output tensor has axis retained or not. Default: False
inp: input tensor. Should have a numeric data type.
axis: axis or axes along which variances must be computed.
By default, the variance must be computed over the entire tensor.
If a sequence of integers, variances must be computed over multiple axes.
keepdims: if ``True``, the reduced axes (dimensions) must be included in the result as singleton dimensions,
and, accordingly, the result must be compatible with the input tensor (see :ref:`broadcasting-rule`).
Otherwise, if ``False``, the reduced axes (dimensions) must not be included in the result.
Returns:
output tensor.
if the variance was computed over the entire tensor, a zero-dimensional tensor containing the variance;
otherwise, a non-zero-dimensional tensor containing the variances.
The returned tensor must have a data type determined by :ref:`dtype-promotion`.
.. note::
The variance is the average of the squared deviations from the mean,
i.e., ``var = mean(x)``, where ``x = abs(a - a.mean())**2``.
Examples:
>>> import numpy as np
>>> data = Tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3))
>>> out = F.var(data)
>>> out.numpy().round(decimals=4)
2.9167
>>> x = Tensor([[1, 2], [3, 4]])
>>> F.var(x)
Tensor(1.25, device=xpux:0)
>>> x = Tensor([[14, 8, 11, 10], [7, 9, 10, 11], [10, 15, 5, 10]])
>>> F.var(x)
Tensor(6.8333335, device=xpux:0)
"""
if
axis
is
None
:
m
=
mean
(
inp
,
axis
=
axis
,
keepdims
=
False
)
...
...
@@ -229,24 +331,35 @@ def std(
axis
:
Optional
[
Union
[
int
,
Sequence
[
int
]]]
=
None
,
keepdims
:
bool
=
False
,
)
->
Tensor
:
r
"""Returns the standard deviation of input tensor along
given axis. If axis is a list of dimensions,
reduce over all of them.
r
"""Calculates the standard deviation of tensor elements over a given axis (or axes).
Args:
inp: input tensor.
axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
keepdims: whether the output tensor has axis retained or not. Default: False
inp: input tensor. Should have a numeric data type.
axis: axis or axes along which standard deviations must be computed.
By default, the standard deviation must be computed over the entire tensor.
If a sequence of integers, standard deviations must be computed over multiple axes.
keepdims: if ``True``, the reduced axes (dimensions) must be included in the result as singleton dimensions,
and, accordingly, the result must be compatible with the input tensor (see :ref:`broadcasting-rule`).
Otherwise, if ``False``, the reduced axes (dimensions) must not be included in the result.
Returns:
output tensor.
if the standard deviation was computed over the entire tensor, a zero-dimensional tensor containing the standard deviation;
otherwise, a non-zero-dimensional tensor containing the standard deviations.
.. note::
The standard deviation is the square root of the average of the squared deviations from the mean,
i.e., ``std = sqrt(mean(x))``, where ``x = abs(a - a.mean())**2``.
Examples:
>>> import numpy as np
>>> data = Tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3))
>>> out = F.std(data, axis=1)
>>> out.numpy().round(decimals=4)
array([0.8165, 0.8165], dtype=float32)
>>> x = Tensor([[1, 2], [3, 4]])
>>> F.std(x)
Tensor(1.118034, device=xpux:0)
>>> x = Tensor([[14, 8, 11, 10], [7, 9, 10, 11], [10, 15, 5, 10]])
>>> F.std(x)
Tensor(2.6140645, device=xpux:0)
"""
return
var
(
inp
,
axis
=
axis
,
keepdims
=
keepdims
)
**
0.5
...
...
@@ -256,23 +369,38 @@ def min(
axis
:
Optional
[
Union
[
int
,
Sequence
[
int
]]]
=
None
,
keepdims
:
bool
=
False
,
)
->
Tensor
:
r
"""Returns the min value of input tensor along
given axis. If axis is a list of dimensions,
reduce over all of them.
r
"""Calculates the minimum of tensor elements over a given axis (or axes).
Args:
inp: input tensor.
axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
keepdims: whether the output tensor has axis retained or not. Default: False
inp: input tensor. Should have a numeric data type.
axis: axis or axes along which minimums must be computed.
By default, the minimum must be computed over the entire tensor.
If a sequence of integers, minimums must be computed over multiple axes.
keepdims: if ``True``, the reduced axes (dimensions) must be included in the result as singleton dimensions,
and, accordingly, the result must be compatible with the input tensor (see :ref:`broadcasting-rule`).
Otherwise, if ``False``, the reduced axes (dimensions) must not be included in the result.
Returns:
output tensor.
if the minimum was computed over the entire tensor, a zero-dimensional tensor containing the minimum;
otherwise, a non-zero-dimensional tensor containing the minimums.
.. admonition:: Special Cases
If :math:`x_i` is ``NaN``, the minimum is ``NaN`` (i.e., ``NaN`` values propagate).
Examples:
>>> import numpy as np
>>> x = Tensor(
np.arange(1, 7, dtype=np.int32).reshape(2,3)
)
>>> x = Tensor(
[[1, 2], [3, 4]]
)
>>> F.min(x)
Tensor(1, dtype=int32, device=xpux:0)
Along an axis:
>>> F.min(x, axis=0)
Tensor([1 2], dtype=int32, device=xpux:0)
>>> F.min(x, axis=1)
Tensor([1 3], dtype=int32, device=xpux:0)
"""
return
inp
.
min
(
axis
=
axis
,
keepdims
=
keepdims
)
...
...
@@ -282,61 +410,42 @@ def max(
axis
:
Optional
[
Union
[
int
,
Sequence
[
int
]]]
=
None
,
keepdims
:
bool
=
False
,
)
->
Tensor
:
r
"""Returns the max value of the input tensor along
given axis. If axis is a list of dimensions,
reduce over all of them.
r
"""Calculates the maximum of tensor elements over a given axis (or axes).
Args:
inp: input tensor.
axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
keepdims: whether the output tensor has axis retained or not. Default: False
inp: input tensor. Should have a numeric data type.
axis: axis or axes along which maximums must be computed.
By default, the maximum must be computed over the entire tensor.
If a sequence of integers, maximums must be computed over multiple axes.
keepdims: if ``True``, the reduced axes (dimensions) must be included in the result as singleton dimensions,
and, accordingly, the result must be compatible with the input tensor (see :ref:`broadcasting-rule`).
Otherwise, if ``False``, the reduced axes (dimensions) must not be included in the result.
Returns:
output tensor.
if the maximum was computed over the entire tensor, a zero-dimensional tensor containing the maximum;
otherwise, a non-zero-dimensional tensor containing the maximums.
Examples:
>>> import numpy as np
>>> x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
>>> F.max(x)
Tensor(6, dtype=int32, device=xpux:0)
"""
return
inp
.
max
(
axis
=
axis
,
keepdims
=
keepdims
)
.. admonition:: Special Cases
If :math:`x_i` is ``NaN``, the maximum is ``NaN`` (i.e., ``NaN`` values propagate).
def
norm
(
inp
:
Tensor
,
ord
:
float
=
None
,
axis
:
int
=
None
,
keepdims
=
False
,
):
r
"""Calculates ``p``-norm of input tensor along
given axis.
Examples:
Args:
inp: input tensor.
ord: power of value applied to inp. Default: 2
axis: dimension to reduce. If None, input must be a vector. Default: None
keepdims: whether the output tensor has axis retained or not. Default: False
>>> x = Tensor([[1, 2], [3, 4]])
>>> F.max(x)
Tensor(4, dtype=int32, device=xpux:0)
Returns:
output tensor.
Along an axis:
Examples:
>>> import numpy as np
>>> x = Tensor(np.arange(-3, 3, dtype=np.float32))
>>> out = F.norm(x)
>>> out.numpy().round(decimals=4)
4.3589
>>> F.max(x, axis=0)
Tensor([3 4], dtype=int32, device=xpux:0)
>>> F.max(x, axis=1)
Tensor([2 4], dtype=int32, device=xpux:0)
"""
if
axis
is
None
:
if
inp
.
ndim
!=
1
:
raise
TypeError
(
"axis is required unless input is a vector"
)
if
ord
is
None
:
ord
=
2
if
ord
==
0
:
return
sum
(
inp
!=
0
,
axis
=
axis
,
keepdims
=
keepdims
)
if
ord
==
math
.
inf
:
return
max
(
abs
(
inp
))
if
ord
==
-
math
.
inf
:
return
min
(
abs
(
inp
))
return
sum
(
abs
(
inp
)
**
ord
,
axis
=
axis
,
keepdims
=
keepdims
)
**
(
1.0
/
ord
)
return
inp
.
max
(
axis
=
axis
,
keepdims
=
keepdims
)
# searching functions
def
argmin
(
...
...
@@ -433,31 +542,7 @@ def argmax(
return
result
def
normalize
(
inp
:
Tensor
,
ord
:
float
=
None
,
axis
:
int
=
None
,
eps
:
float
=
1e-12
,
)
->
Tensor
:
r
"""Performs :math:`L_p` normalization of input tensor along
given axis.
For a tensor of shape :math:`(n_0, ..., n_{dim}, ..., n_k)`, each
:math:`n_{dim}` -element vector :math:`v` along dimension :attr:`axis` is transformed as:
.. math::
v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}.
Args:
inp: input tensor.
ord: power of value applied to input tensor. Default: 2
axis: dimension to reduce.If None, input must be a vector. Default: None
eps: a small value to avoid division by zero. Default: 1e-12
Returns:
normalized output tensor.
"""
if
axis
is
None
:
return
inp
/
clip
(
norm
(
inp
,
ord
,
axis
),
lower
=
eps
)
else
:
return
inp
/
clip
(
norm
(
inp
,
ord
,
axis
,
keepdims
=
True
),
lower
=
eps
)
# sorting functions
def
argsort
(
inp
:
Tensor
,
descending
:
bool
=
False
)
->
Tensor
:
...
...
@@ -589,6 +674,9 @@ def topk(
return
tns
,
ind
# linear algebra functions
def
matinv
(
inp
:
Tensor
)
->
Tensor
:
r
"""Computes the inverse of a batch of matrices; input must has shape [..., n, n].
...
...
@@ -725,6 +813,109 @@ def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor:
return
U
,
S
,
Vh
def
norm
(
inp
:
Tensor
,
ord
:
float
=
None
,
axis
:
int
=
None
,
keepdims
=
False
,
):
r
"""Calculates the norm of tensor elements over a given axis.
This function is able to return different matrix norms,
or one of an infinite number of vector norms (described below), depending on the value of the ord parameter.
Args:
inp: input tensor. Should have a numeric data type.
ord: Order of the norm (see table under Notes). If not specified, the default is 2.
axis: Axis along which to compute vector norms.
If axis is an integer, it specifies the axis of inp along which to compute the vector norms.
keepdims: If this is set to ``True``,
the axes which are normed over are left in the result as dimensions with size one.
Returns:
Norm of the matrix or vector(s).
.. note::
Now the following norms can be calculated:
* inf: norm-:math:`\infty` (maximum of absolute values).
* -inf: norm-:math:`-\infty` (minimum of absolute values).
* 2: 2-norm (largest singluar value).
The Frobenius norm is given by to ``sum(abs(x)**ord)**(1./ord)``:
.. math::
\|A\|_F=\left[\sum_{i, j} a b s\left(a_{i, j}\right)^2\right]^{1 / 2}
.. seealso:: :func:`numpy.linalg.norm` / :func:`~.functional.normalize`
Examples:
>>> import math
>>> x = Tensor([1, 2, 3])
>>> F.norm(x, ord=math.inf)
Tensor(3, dtype=int32, device=xpux:0)
>>> F.norm(x, ord=-math.inf)
Tensor(1, dtype=int32, device=xpux:0)
>>> x = Tensor([[1, 2, 3], [4, 5, 6]])
>>> F.norm(x, ord=2, axis=0)
Tensor([4.1231 5.3852 6.7082], device=xpux:0)
>>> F.norm(x, ord=2, axis=1)
Tensor([3.7417 8.775 ], device=xpux:0)
"""
if
axis
is
None
:
if
inp
.
ndim
!=
1
:
raise
TypeError
(
"axis is required unless input is a vector"
)
if
ord
is
None
:
ord
=
2
if
ord
==
0
:
return
sum
(
inp
!=
0
,
axis
=
axis
,
keepdims
=
keepdims
)
if
ord
==
math
.
inf
:
return
max
(
abs
(
inp
))
if
ord
==
-
math
.
inf
:
return
min
(
abs
(
inp
))
return
sum
(
abs
(
inp
)
**
ord
,
axis
=
axis
,
keepdims
=
keepdims
)
**
(
1.0
/
ord
)
def
normalize
(
inp
:
Tensor
,
ord
:
float
=
None
,
axis
:
int
=
None
,
eps
:
float
=
1e-12
,
)
->
Tensor
:
r
"""Performs :math:`L_p` normalization of input tensor along given axis.
For a tensor of shape :math:`(n_0, ..., n_{dim}, ..., n_k)`,
each :math:`n_{dim}` -element vector :math:`v` along dimension :attr:`axis` is transformed as:
.. math::
v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}.
Args:
inp: input tensor.
ord: power of value applied to input tensor.
axis: dimension to reduce.If None, input must be a vector.
eps: a small value to avoid division by zero.
Returns:
normalized output tensor.
seealso:: :func:`numpy.linalg.norm` / :func:`~.functional.norm`
Examples:
>>> x = Tensor([[1, 2, 3], [4, 5, 6]])
>>> F.normalize(x, ord=2, axis=0)
Tensor([[0.2425 0.3714 0.4472]
[0.9701 0.9285 0.8944]], device=xpux:0)
>>> F.normalize(x, ord=2, axis=1)
Tensor([[0.2673 0.5345 0.8018]
[0.4558 0.5698 0.6838]], device=xpux:0)
"""
if
axis
is
None
:
return
inp
/
clip
(
norm
(
inp
,
ord
,
axis
),
lower
=
eps
)
else
:
return
inp
/
clip
(
norm
(
inp
,
ord
,
axis
,
keepdims
=
True
),
lower
=
eps
)
def
_check_non_finite
(
inps
:
Iterable
[
Tensor
],
scale
=
1.0
)
->
Tensor
:
r
"""Check whether input contains infinite or nan value.
...
...
imperative/python/megengine/functional/tensor.py
浏览文件 @
22286579
...
...
@@ -1126,17 +1126,28 @@ def roll(
def
cumsum
(
inp
:
Tensor
,
axis
:
int
):
r
"""C
omputes the cumulative sum of elements along
given axis.
r
"""C
alculates the cumulative sum of tensor elements over a
given axis.
Args:
inp: input tensor.
axis: axis along which cumsum is performed.
inp: input tensor. Should have a numeric data type.
axis: axis along which cumulative sums must be computed.
Returns:
a tensor containing the cumulative sums.
Examples:
>>> x = Tensor([[1, 2, 3], [4, 5, 6]], "int32")
>>> F.cumsum(x, 1)
If :math:`x_i` is ``NaN``, the cumulative sums is ``NaN`` (i.e., ``NaN`` values propagate).
Examples:
>>> x = Tensor([[1, 2, 3], [4, 5, 6]])
>>> F.cumsum(x, axis = 0)
Tensor([[1 2 3]
[5 7 9]], dtype=int32, device=xpux:0)
>>> F.cumsum(x, axis = 1)
Tensor([[ 1 3 6]
[ 4 9 15]], dtype=int32, device=xpux:0)
"""
assert
isinstance
(
inp
,
Tensor
),
"input of cumsum must be type of Tensor"
op
=
builtin
.
Cumsum
(
axis
=
axis
,
exclusive
=
False
,
reverse
=
False
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录