From 0a67c2e52330b637e6557f5fcff5f3d2a4336e97 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Wed, 24 Aug 2022 14:15:41 +0800 Subject: [PATCH] Refactor python api of trace (#45344) * support selected_rows kernel for multiply in dygraph * refine the trace python api * fix check input * fix check input * fix check input --- python/paddle/tensor/math.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 5035e93e699..0c23240cdc1 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -2821,7 +2821,7 @@ def trace(x, offset=0, axis1=0, axis2=1, name=None): data2 = paddle.trace(case2, offset=1, axis1=1, axis2=2) # data2.shape = [3] data3 = paddle.trace(case3, offset=-3, axis1=1, axis2=-1) # data2.shape = [3, 5] """ - def __check_input(input, offset, dim1, dim2): + def __check_input(x, offset, axis1, axis2): check_dtype(x.dtype, 'Input', ['int32', 'int64', 'float16', 'float32', 'float64'], 'trace') @@ -2848,17 +2848,15 @@ def trace(x, offset=0, axis1=0, axis2=1, name=None): "axis1 and axis2 cannot be the same axis." \ "But received axis1 = %d, axis2 = %d\n"%(axis1, axis2) - __check_input(input, offset, axis1, axis2) if in_dygraph_mode(): return _C_ops.final_state_trace( x, offset, axis1, axis2 ) if _in_legacy_dygraph(): return _C_ops.trace(x, 'offset', offset, 'axis1', axis1, 'axis2', axis2) - inputs = {'Input': [x]} - attrs = {'offset': offset, 'axis1': axis1, 'axis2': axis2} - helper = LayerHelper('trace', **locals()) + __check_input(x, offset, axis1, axis2) + helper = LayerHelper('trace', **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) helper.append_op( @@ -2941,7 +2939,7 @@ def diagonal(x, offset=0, axis1=0, axis2=1, name=None): if _in_legacy_dygraph(): return _C_ops.diagonal(x, 'offset', offset, 'axis1', axis1, 'axis2', axis2) - def __check_input(input, offset, dim1, dim2): + def __check_input(x, offset, axis1, axis2): check_dtype(x.dtype, 'Input', ['bool', 'int32', 'int64', 'float16', 'float32', 'float64'], 'diagonal') @@ -2967,7 +2965,7 @@ def diagonal(x, offset=0, axis1=0, axis2=1, name=None): "axis1 and axis2 cannot be the same axis." \ "But received axis1 = %d, axis2 = %d\n"%(axis1, axis2) - __check_input(input, offset, axis1, axis2) + __check_input(x, offset, axis1, axis2) helper = LayerHelper('diagonal', **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) -- GitLab