未验证 提交 44952ca6 编写于 作者: W wopeizl 提交者: GitHub

cherry-pick enhance API create_parameter test=develop test=release/1.6 (#20291)

上级 8ea490f1
...@@ -4315,15 +4315,20 @@ class Parameter(Variable): ...@@ -4315,15 +4315,20 @@ class Parameter(Variable):
""" """
def __init__(self, block, shape, dtype, **kwargs): def __init__(self, block, shape, dtype, **kwargs):
if shape is None or dtype is None: if shape is None:
raise ValueError("Parameter must set shape and dtype") raise ValueError("The shape of Parameter should not be None")
if dtype is None:
raise ValueError("The dtype of Parameter should not be None")
if len(shape) == 0: if len(shape) == 0:
raise ValueError("Parameter shape cannot be empty") raise ValueError(
"The dimensions of shape for Parameter must be greater than 0")
for each in shape: for each in shape:
if each < 0: if each < 0:
raise ValueError("Parameter shape should not be related with " raise ValueError(
"batch-size") "Each dimension of shape for Parameter must be greater than 0, but received %s"
% list(shape))
Variable.__init__( Variable.__init__(
self, block, persistable=True, shape=shape, dtype=dtype, **kwargs) self, block, persistable=True, shape=shape, dtype=dtype, **kwargs)
......
...@@ -46,6 +46,21 @@ class TestParameter(unittest.TestCase): ...@@ -46,6 +46,21 @@ class TestParameter(unittest.TestCase):
p = io.get_parameter_value_by_name('fc.w', exe, main_program) p = io.get_parameter_value_by_name('fc.w', exe, main_program)
self.assertTrue(np.allclose(np.array(p), np.ones(shape) * val)) self.assertTrue(np.allclose(np.array(p), np.ones(shape) * val))
def test_exceptions(self):
b = main_program.global_block()
with self.assertRaises(ValueError):
b.create_parameter(
name='test', shape=None, dtype='float32', initializer=None)
with self.assertRaises(ValueError):
b.create_parameter(
name='test', shape=[1], dtype=None, initializer=None)
with self.assertRaises(ValueError):
b.create_parameter(
name='test', shape=[], dtype='float32', initializer=None)
with self.assertRaises(ValueError):
b.create_parameter(
name='test', shape=[-1], dtype='float32', initializer=None)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册