未验证 提交 d47a97db 编写于 作者: X XiangGao 提交者: GitHub

fix trace op stack overflow (#35419)

Co-authored-by: Nroot <root@bjyz-sys-gpu-kongming9.bjyz.baidu.com>
上级 cec36ea6
......@@ -1657,12 +1657,6 @@ def trace(x, offset=0, axis1=0, axis2=1, name=None):
data2 = paddle.trace(case2, offset=1, axis1=1, axis2=2) # data2.shape = [3]
data3 = paddle.trace(case3, offset=-3, axis1=1, axis2=-1) # data2.shape = [3, 5]
"""
if in_dygraph_mode():
return _C_ops.trace(x, 'offset', offset, 'axis1', axis1, 'axis2', axis2)
inputs = {'Input': [x]}
attrs = {'offset': offset, 'axis1': axis1, 'axis2': axis2}
def __check_input(input, offset, dim1, dim2):
check_dtype(x.dtype, 'Input',
['int32', 'int64', 'float16', 'float32', 'float64'],
......@@ -1677,11 +1671,11 @@ def trace(x, offset=0, axis1=0, axis2=1, name=None):
axis1_ = axis1 if axis1 >= 0 else len(input_shape) + axis1
axis2_ = axis2 if axis2 >= 0 else len(input_shape) + axis2
assert axis1_ < len(input_shape), \
assert ((0 <= axis1_) and (axis1_ < len(input_shape))), \
"The argument axis1 is out of range (expected to be in range of [%d, %d], but got %d).\n" \
% (-(len(input_shape)), len(input_shape) - 1, axis1)
assert axis2_ < len(input_shape), \
assert ((0 <= axis2_) and (axis2_ < len(input_shape))), \
"The argument axis2 is out of range (expected to be in range of [%d, %d], but got %d).\n" \
% (-(len(input_shape)), len(input_shape) - 1, axis2)
......@@ -1691,6 +1685,11 @@ def trace(x, offset=0, axis1=0, axis2=1, name=None):
"But received axis1 = %d, axis2 = %d\n"%(axis1, axis2)
__check_input(input, offset, axis1, axis2)
if in_dygraph_mode():
return _C_ops.trace(x, 'offset', offset, 'axis1', axis1, 'axis2', axis2)
inputs = {'Input': [x]}
attrs = {'offset': offset, 'axis1': axis1, 'axis2': axis2}
helper = LayerHelper('trace', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册