未验证 提交 0a67c2e5 编写于 作者: Z zyfncg 提交者: GitHub

Refactor python api of trace (#45344)

* support selected_rows kernel for multiply in dygraph

* refine the trace python api

* fix check input

* fix check input

* fix check input
上级 12917c8c
......@@ -2821,7 +2821,7 @@ def trace(x, offset=0, axis1=0, axis2=1, name=None):
data2 = paddle.trace(case2, offset=1, axis1=1, axis2=2) # data2.shape = [3]
data3 = paddle.trace(case3, offset=-3, axis1=1, axis2=-1) # data2.shape = [3, 5]
"""
def __check_input(input, offset, dim1, dim2):
def __check_input(x, offset, axis1, axis2):
check_dtype(x.dtype, 'Input',
['int32', 'int64', 'float16', 'float32', 'float64'],
'trace')
......@@ -2848,17 +2848,15 @@ def trace(x, offset=0, axis1=0, axis2=1, name=None):
"axis1 and axis2 cannot be the same axis." \
"But received axis1 = %d, axis2 = %d\n"%(axis1, axis2)
__check_input(input, offset, axis1, axis2)
if in_dygraph_mode():
return _C_ops.final_state_trace( x, offset, axis1, axis2 )
if _in_legacy_dygraph():
return _C_ops.trace(x, 'offset', offset, 'axis1', axis1, 'axis2', axis2)
inputs = {'Input': [x]}
attrs = {'offset': offset, 'axis1': axis1, 'axis2': axis2}
helper = LayerHelper('trace', **locals())
__check_input(x, offset, axis1, axis2)
helper = LayerHelper('trace', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
......@@ -2941,7 +2939,7 @@ def diagonal(x, offset=0, axis1=0, axis2=1, name=None):
if _in_legacy_dygraph():
return _C_ops.diagonal(x, 'offset', offset, 'axis1', axis1, 'axis2', axis2)
def __check_input(input, offset, dim1, dim2):
def __check_input(x, offset, axis1, axis2):
check_dtype(x.dtype, 'Input',
['bool', 'int32', 'int64', 'float16', 'float32', 'float64'],
'diagonal')
......@@ -2967,7 +2965,7 @@ def diagonal(x, offset=0, axis1=0, axis2=1, name=None):
"axis1 and axis2 cannot be the same axis." \
"But received axis1 = %d, axis2 = %d\n"%(axis1, axis2)
__check_input(input, offset, axis1, axis2)
__check_input(x, offset, axis1, axis2)
helper = LayerHelper('diagonal', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册