未验证 提交 c31386ef 编写于 作者: R Ruibiao Chen 提交者: GitHub

Add yaml for eye OP (#41476)

上级 516160a4
...@@ -1724,10 +1724,12 @@ def eye(num_rows, ...@@ -1724,10 +1724,12 @@ def eye(num_rows,
else: else:
num_columns = num_rows num_columns = num_rows
if _non_static_mode(): if in_dygraph_mode():
out = _C_ops.final_state_eye(num_rows, num_columns, dtype,
_current_expected_place())
elif _in_legacy_dygraph():
out = _C_ops.eye('dtype', dtype, 'num_rows', num_rows, 'num_columns', out = _C_ops.eye('dtype', dtype, 'num_rows', num_rows, 'num_columns',
num_columns) num_columns)
else: else:
helper = LayerHelper("eye", **locals()) helper = LayerHelper("eye", **locals())
check_dtype(dtype, 'dtype', check_dtype(dtype, 'dtype',
......
...@@ -28,6 +28,7 @@ class TestEyeOp(OpTest): ...@@ -28,6 +28,7 @@ class TestEyeOp(OpTest):
''' '''
Test eye op with specified shape Test eye op with specified shape
''' '''
self.python_api = paddle.eye
self.op_type = "eye" self.op_type = "eye"
self.inputs = {} self.inputs = {}
...@@ -39,7 +40,7 @@ class TestEyeOp(OpTest): ...@@ -39,7 +40,7 @@ class TestEyeOp(OpTest):
self.outputs = {'Out': np.eye(219, 319, dtype=np.int32)} self.outputs = {'Out': np.eye(219, 319, dtype=np.int32)}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_eager=True)
class TestEyeOp1(OpTest): class TestEyeOp1(OpTest):
...@@ -47,6 +48,7 @@ class TestEyeOp1(OpTest): ...@@ -47,6 +48,7 @@ class TestEyeOp1(OpTest):
''' '''
Test eye op with default parameters Test eye op with default parameters
''' '''
self.python_api = paddle.eye
self.op_type = "eye" self.op_type = "eye"
self.inputs = {} self.inputs = {}
...@@ -54,7 +56,7 @@ class TestEyeOp1(OpTest): ...@@ -54,7 +56,7 @@ class TestEyeOp1(OpTest):
self.outputs = {'Out': np.eye(50, dtype=float)} self.outputs = {'Out': np.eye(50, dtype=float)}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_eager=True)
class TestEyeOp2(OpTest): class TestEyeOp2(OpTest):
...@@ -62,6 +64,7 @@ class TestEyeOp2(OpTest): ...@@ -62,6 +64,7 @@ class TestEyeOp2(OpTest):
''' '''
Test eye op with specified shape Test eye op with specified shape
''' '''
self.python_api = paddle.eye
self.op_type = "eye" self.op_type = "eye"
self.inputs = {} self.inputs = {}
...@@ -69,7 +72,7 @@ class TestEyeOp2(OpTest): ...@@ -69,7 +72,7 @@ class TestEyeOp2(OpTest):
self.outputs = {'Out': np.eye(99, 1, dtype=float)} self.outputs = {'Out': np.eye(99, 1, dtype=float)}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_eager=True)
class API_TestTensorEye(unittest.TestCase): class API_TestTensorEye(unittest.TestCase):
......
...@@ -627,6 +627,18 @@ ...@@ -627,6 +627,18 @@
func : expm1 func : expm1
backward : expm1_grad backward : expm1_grad
- api : eye
args : (int64_t num_rows, int64_t num_columns, DataType dtype=DataType::FLOAT32, Place place={})
output : Tensor(out)
infer_meta :
func : EyeInferMeta
param : [num_rows, num_columns, dtype]
kernel :
func : eye
param : [num_rows, num_columns, dtype]
data_type : dtype
backend : place
- api : flatten - api : flatten
args : (Tensor x, int start_axis, int stop_axis) args : (Tensor x, int start_axis, int stop_axis)
output : Tensor(out), Tensor(xshape) output : Tensor(out), Tensor(xshape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册