未验证 提交 719309a8 编写于 作者: Y yaoxuefeng 提交者: GitHub

add note of large offset of trace test=document_fix (#27693)

上级 cefb49ab
......@@ -1476,8 +1476,7 @@ def clip(x, min=None, max=None, name=None):
def trace(x, offset=0, axis1=0, axis2=1, name=None):
"""
:alias_main: paddle.trace
:alias: paddle.trace,paddle.tensor.trace,paddle.tensor.math.trace
**trace**
This OP computes the sum along diagonals of the input tensor x.
......@@ -1492,32 +1491,26 @@ def trace(x, offset=0, axis1=0, axis2=1, name=None):
- If offset = 0, it is the main diagonal.
- If offset > 0, it is above the main diagonal.
- If offset < 0, it is below the main diagonal.
- Note that if offset is out of input's shape indicated by axis1 and axis2, 0 will be returned.
Args:
x(Variable): The input tensor x. Must be at least 2-dimensional. The input data type should be float32, float64, int32, int64.
x(Tensor): The input tensor x. Must be at least 2-dimensional. The input data type should be float32, float64, int32, int64.
offset(int, optional): Which diagonals in input tensor x will be taken. Default: 0 (main diagonals).
axis1(int, optional): The first axis with respect to take diagonal. Default: 0.
axis2(int, optional): The second axis with respect to take diagonal. Default: 1.
name (str, optional): Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Default: None.
Returns:
Variable: the output data type is the same as input data type.
Tensor: the output data type is the same as input data type.
Examples:
.. code-block:: python
import paddle
import numpy as np
case1 = np.random.randn(2, 3).astype('float32')
case2 = np.random.randn(3, 10, 10).astype('float32')
case3 = np.random.randn(3, 10, 5, 10).astype('float32')
paddle.disable_static()
case1 = paddle.to_tensor(case1)
case2 = paddle.to_tensor(case2)
case3 = paddle.to_tensor(case3)
case1 = paddle.randn([2, 3])
case2 = paddle.randn([3, 10, 10])
case3 = paddle.randn([3, 10, 5, 10])
data1 = paddle.trace(case1) # data1.shape = [1]
data2 = paddle.trace(case2, offset=1, axis1=1, axis2=2) # data2.shape = [3]
data3 = paddle.trace(case3, offset=-3, axis1=1, axis2=-1) # data2.shape = [3, 5]
......@@ -1552,6 +1545,9 @@ def trace(x, offset=0, axis1=0, axis2=1, name=None):
"axis1 and axis2 cannot be the same axis." \
"But received axis1 = %d, axis2 = %d\n"%(axis1, axis2)
if in_dygraph_mode():
return core.ops.trace(x, 'offset', offset, 'axis1', axis1, 'axis2', axis2)
if not in_dygraph_mode():
__check_input(input, offset, axis1, axis2)
helper = LayerHelper('trace', **locals())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册