未验证 提交 c31386ef 编写于 作者: R Ruibiao Chen 提交者: GitHub

Add yaml for eye OP (#41476)

上级 516160a4
......@@ -1724,10 +1724,12 @@ def eye(num_rows,
else:
num_columns = num_rows
if _non_static_mode():
if in_dygraph_mode():
out = _C_ops.final_state_eye(num_rows, num_columns, dtype,
_current_expected_place())
elif _in_legacy_dygraph():
out = _C_ops.eye('dtype', dtype, 'num_rows', num_rows, 'num_columns',
num_columns)
else:
helper = LayerHelper("eye", **locals())
check_dtype(dtype, 'dtype',
......
......@@ -28,6 +28,7 @@ class TestEyeOp(OpTest):
'''
Test eye op with specified shape
'''
self.python_api = paddle.eye
self.op_type = "eye"
self.inputs = {}
......@@ -39,7 +40,7 @@ class TestEyeOp(OpTest):
self.outputs = {'Out': np.eye(219, 319, dtype=np.int32)}
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
class TestEyeOp1(OpTest):
......@@ -47,6 +48,7 @@ class TestEyeOp1(OpTest):
'''
Test eye op with default parameters
'''
self.python_api = paddle.eye
self.op_type = "eye"
self.inputs = {}
......@@ -54,7 +56,7 @@ class TestEyeOp1(OpTest):
self.outputs = {'Out': np.eye(50, dtype=float)}
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
class TestEyeOp2(OpTest):
......@@ -62,6 +64,7 @@ class TestEyeOp2(OpTest):
'''
Test eye op with specified shape
'''
self.python_api = paddle.eye
self.op_type = "eye"
self.inputs = {}
......@@ -69,7 +72,7 @@ class TestEyeOp2(OpTest):
self.outputs = {'Out': np.eye(99, 1, dtype=float)}
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
class API_TestTensorEye(unittest.TestCase):
......
......@@ -627,6 +627,18 @@
func : expm1
backward : expm1_grad
- api : eye
args : (int64_t num_rows, int64_t num_columns, DataType dtype=DataType::FLOAT32, Place place={})
output : Tensor(out)
infer_meta :
func : EyeInferMeta
param : [num_rows, num_columns, dtype]
kernel :
func : eye
param : [num_rows, num_columns, dtype]
data_type : dtype
backend : place
- api : flatten
args : (Tensor x, int start_axis, int stop_axis)
output : Tensor(out), Tensor(xshape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册