From 7d4e21588d13c7e31e9b21d3adf5db0a318dc591 Mon Sep 17 00:00:00 2001 From: LutaoChu <30695251+LutaoChu@users.noreply.github.com> Date: Fri, 28 Aug 2020 11:54:15 +0800 Subject: [PATCH] add parameters check in static mode for diag op add parameters check in static mode for diag op --- python/paddle/fluid/tests/unittests/test_diag.py | 8 ++++++++ python/paddle/tensor/creation.py | 7 +++++++ 2 files changed, 15 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_diag.py b/python/paddle/fluid/tests/unittests/test_diag.py index 8bf40459902..780d57b5331 100644 --- a/python/paddle/fluid/tests/unittests/test_diag.py +++ b/python/paddle/fluid/tests/unittests/test_diag.py @@ -83,6 +83,14 @@ class TestDiagV2Error(unittest.TestCase): self.assertRaises(TypeError, test_diag_v2_type) + x = paddle.static.data('data', [3, 3]) + self.assertRaises(TypeError, paddle.diag, x, offset=2.5) + + self.assertRaises(TypeError, paddle.diag, x, padding_value=[9]) + + x = paddle.static.data('data2', [3, 3, 3]) + self.assertRaises(ValueError, paddle.diag, x) + class TestDiagV2API(unittest.TestCase): def setUp(self): diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 1911d8ccc25..cb3caf0656e 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -978,6 +978,13 @@ def diag(x, offset=0, padding_value=0, name=None): check_type(x, 'x', (Variable), 'diag_v2') check_dtype(x.dtype, 'x', ['float32', 'float64', 'int32', 'int64'], 'diag_v2') + check_type(offset, 'offset', (int), 'diag_v2') + check_type(padding_value, 'padding_value', (int, float), 'diag_v2') + if len(x.shape) != 1 and len(x.shape) != 2: + raise ValueError( + "The dimension of input x must be either 1 or 2, but received {}". + format(len(x.shape))) + helper = LayerHelper("diag_v2", **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) -- GitLab