From d47a97db7026d98be0f4580f177767b8a09ff9f0 Mon Sep 17 00:00:00 2001 From: XiangGao Date: Tue, 7 Sep 2021 18:15:28 +0800 Subject: [PATCH] fix trace op stack overflow (#35419) Co-authored-by: root --- python/paddle/tensor/math.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index a828199873..e73c97ee0f 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -1657,12 +1657,6 @@ def trace(x, offset=0, axis1=0, axis2=1, name=None): data2 = paddle.trace(case2, offset=1, axis1=1, axis2=2) # data2.shape = [3] data3 = paddle.trace(case3, offset=-3, axis1=1, axis2=-1) # data2.shape = [3, 5] """ - if in_dygraph_mode(): - return _C_ops.trace(x, 'offset', offset, 'axis1', axis1, 'axis2', axis2) - - inputs = {'Input': [x]} - attrs = {'offset': offset, 'axis1': axis1, 'axis2': axis2} - def __check_input(input, offset, dim1, dim2): check_dtype(x.dtype, 'Input', ['int32', 'int64', 'float16', 'float32', 'float64'], @@ -1677,11 +1671,11 @@ def trace(x, offset=0, axis1=0, axis2=1, name=None): axis1_ = axis1 if axis1 >= 0 else len(input_shape) + axis1 axis2_ = axis2 if axis2 >= 0 else len(input_shape) + axis2 - assert axis1_ < len(input_shape), \ + assert ((0 <= axis1_) and (axis1_ < len(input_shape))), \ "The argument axis1 is out of range (expected to be in range of [%d, %d], but got %d).\n" \ % (-(len(input_shape)), len(input_shape) - 1, axis1) - assert axis2_ < len(input_shape), \ + assert ((0 <= axis2_) and (axis2_ < len(input_shape))), \ "The argument axis2 is out of range (expected to be in range of [%d, %d], but got %d).\n" \ % (-(len(input_shape)), len(input_shape) - 1, axis2) @@ -1691,6 +1685,11 @@ def trace(x, offset=0, axis1=0, axis2=1, name=None): "But received axis1 = %d, axis2 = %d\n"%(axis1, axis2) __check_input(input, offset, axis1, axis2) + if in_dygraph_mode(): + return _C_ops.trace(x, 'offset', offset, 'axis1', axis1, 'axis2', axis2) + + inputs = {'Input': [x]} + attrs = {'offset': offset, 'axis1': axis1, 'axis2': axis2} helper = LayerHelper('trace', **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) -- GitLab