diff --git a/python/paddle/fluid/tests/unittests/test_diag.py b/python/paddle/fluid/tests/unittests/test_diag.py index 8bf40459902e09f19a5badce62084841a0a23619..780d57b53310bb5f385a131d4ad52dd6f5e695f0 100644 --- a/python/paddle/fluid/tests/unittests/test_diag.py +++ b/python/paddle/fluid/tests/unittests/test_diag.py @@ -83,6 +83,14 @@ class TestDiagV2Error(unittest.TestCase): self.assertRaises(TypeError, test_diag_v2_type) + x = paddle.static.data('data', [3, 3]) + self.assertRaises(TypeError, paddle.diag, x, offset=2.5) + + self.assertRaises(TypeError, paddle.diag, x, padding_value=[9]) + + x = paddle.static.data('data2', [3, 3, 3]) + self.assertRaises(ValueError, paddle.diag, x) + class TestDiagV2API(unittest.TestCase): def setUp(self): diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 1911d8ccc25e01ee6419fd26126881304ab61f01..cb3caf0656e8fd4aba905feed92f10238d1fc9d0 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -978,6 +978,13 @@ def diag(x, offset=0, padding_value=0, name=None): check_type(x, 'x', (Variable), 'diag_v2') check_dtype(x.dtype, 'x', ['float32', 'float64', 'int32', 'int64'], 'diag_v2') + check_type(offset, 'offset', (int), 'diag_v2') + check_type(padding_value, 'padding_value', (int, float), 'diag_v2') + if len(x.shape) != 1 and len(x.shape) != 2: + raise ValueError( + "The dimension of input x must be either 1 or 2, but received {}". + format(len(x.shape))) + helper = LayerHelper("diag_v2", **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype)