diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index b1f2b6df091ab12ab65aa673d0645f44012c64ae..cf5c58556b7e49d612918409adfc70f561c903fe 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -4117,15 +4117,20 @@ class Parameter(Variable): """ def __init__(self, block, shape, dtype, **kwargs): - if shape is None or dtype is None: - raise ValueError("Parameter must set shape and dtype") + if shape is None: + raise ValueError("The shape of Parameter should not be None") + if dtype is None: + raise ValueError("The dtype of Parameter should not be None") + if len(shape) == 0: - raise ValueError("Parameter shape cannot be empty") + raise ValueError( + "The dimensions of shape for Parameter must be greater than 0") for each in shape: if each < 0: - raise ValueError("Parameter shape should not be related with " - "batch-size") + raise ValueError( + "Each dimension of shape for Parameter must be greater than 0, but received %s" + % list(shape)) Variable.__init__( self, block, persistable=True, shape=shape, dtype=dtype, **kwargs) diff --git a/python/paddle/fluid/tests/unittests/test_parameter.py b/python/paddle/fluid/tests/unittests/test_parameter.py index df42e6cb9a050b76099b4a53fdd08d2852284d1f..fc7427dcbfd6998598ad95b70d245f6c8c1b28ae 100644 --- a/python/paddle/fluid/tests/unittests/test_parameter.py +++ b/python/paddle/fluid/tests/unittests/test_parameter.py @@ -46,6 +46,21 @@ class TestParameter(unittest.TestCase): p = io.get_parameter_value_by_name('fc.w', exe, main_program) self.assertTrue(np.allclose(np.array(p), np.ones(shape) * val)) + def test_exceptions(self): + b = main_program.global_block() + with self.assertRaises(ValueError): + b.create_parameter( + name='test', shape=None, dtype='float32', initializer=None) + with self.assertRaises(ValueError): + b.create_parameter( + name='test', shape=[1], dtype=None, initializer=None) + with self.assertRaises(ValueError): + b.create_parameter( + name='test', shape=[], dtype='float32', initializer=None) + with self.assertRaises(ValueError): + b.create_parameter( + name='test', shape=[-1], dtype='float32', initializer=None) + if __name__ == '__main__': unittest.main()