From e9205c38e88a259d447f6732f79ad2102a46472b Mon Sep 17 00:00:00 2001 From: wopeizl Date: Wed, 9 Oct 2019 14:52:35 +0800 Subject: [PATCH] add more checks to create_parameter test=develop (#20059) * add more checks to create_parameter test=develop --- python/paddle/fluid/framework.py | 15 ++++++++++----- .../fluid/tests/unittests/test_parameter.py | 15 +++++++++++++++ 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index b1f2b6df091..cf5c58556b7 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -4117,15 +4117,20 @@ class Parameter(Variable): """ def __init__(self, block, shape, dtype, **kwargs): - if shape is None or dtype is None: - raise ValueError("Parameter must set shape and dtype") + if shape is None: + raise ValueError("The shape of Parameter should not be None") + if dtype is None: + raise ValueError("The dtype of Parameter should not be None") + if len(shape) == 0: - raise ValueError("Parameter shape cannot be empty") + raise ValueError( + "The dimensions of shape for Parameter must be greater than 0") for each in shape: if each < 0: - raise ValueError("Parameter shape should not be related with " - "batch-size") + raise ValueError( + "Each dimension of shape for Parameter must be greater than 0, but received %s" + % list(shape)) Variable.__init__( self, block, persistable=True, shape=shape, dtype=dtype, **kwargs) diff --git a/python/paddle/fluid/tests/unittests/test_parameter.py b/python/paddle/fluid/tests/unittests/test_parameter.py index df42e6cb9a0..fc7427dcbfd 100644 --- a/python/paddle/fluid/tests/unittests/test_parameter.py +++ b/python/paddle/fluid/tests/unittests/test_parameter.py @@ -46,6 +46,21 @@ class TestParameter(unittest.TestCase): p = io.get_parameter_value_by_name('fc.w', exe, main_program) self.assertTrue(np.allclose(np.array(p), np.ones(shape) * val)) + def test_exceptions(self): + b = main_program.global_block() + with self.assertRaises(ValueError): + b.create_parameter( + name='test', shape=None, dtype='float32', initializer=None) + with self.assertRaises(ValueError): + b.create_parameter( + name='test', shape=[1], dtype=None, initializer=None) + with self.assertRaises(ValueError): + b.create_parameter( + name='test', shape=[], dtype='float32', initializer=None) + with self.assertRaises(ValueError): + b.create_parameter( + name='test', shape=[-1], dtype='float32', initializer=None) + if __name__ == '__main__': unittest.main() -- GitLab