未验证 提交 7c1ff38e 编写于 作者: C Chen Weihang 提交者: GitHub

Enhance add_parameter check for dygraph layer (#26188)

* enhance add parameter check for layer

* add unittest for coverage

* remove uninit test case

* enrich unittest ccase

* trigger ci check
上级 8c48c7da
......@@ -553,7 +553,10 @@ class Layer(core.Layer):
"The name of buffer should be a string, but received {}.".
format(type(name).__name__))
elif '.' in name:
raise KeyError("The name of buffer can not contain \".\"")
raise KeyError(
"The name of buffer can not contain `.`, "
"because when you access the newly added buffer in the "
"form of `self.**.**`, it will cause AttributeError.")
elif name == '':
raise KeyError("The name of buffer can not be empty.")
elif hasattr(self, name) and name not in self._buffers:
......@@ -736,20 +739,38 @@ class Layer(core.Layer):
Returns:
Parameter: the parameter passed in.
"""
if parameter is None:
self._parameters[name] = None
elif not isinstance(parameter, framework.Parameter):
if '_parameters' not in self.__dict__:
raise RuntimeError(
"super(YourLayer, self).__init__() should be called firstly.")
elif not isinstance(name, six.string_types):
raise TypeError(
"parameter assignment requires Parameter or None, but got '{}'"
.format(type(parameter).__name__))
"The name of parameter should be a string, but received {}.".
format(type(name).__name__))
elif '.' in name:
raise KeyError(
"The name of parameter can not contain `.`, "
"because when you access the newly added parameter in the "
"form of `self.**.**`, it will cause AttributeError.")
elif name == '':
raise KeyError("The name of parameter can not be empty.")
elif hasattr(self, name) and name not in self._parameters:
raise KeyError("The parameter '{}' already exists.".format(name))
elif parameter is not None and not isinstance(parameter,
framework.Parameter):
raise TypeError(
"The parameter to be added should be a Parameter, but received {}.".
format(type(parameter).__name__))
else:
if parameter is None:
self._parameters[name] = None
if len(self._loaddict_holder) > 0:
assert parameter.name in self._loaddict_holder, "Parameter not found, Can't not find [ {} ] in stat_dict".format(
parameter.name)
if len(self._loaddict_holder) > 0:
assert parameter.name in self._loaddict_holder, "Parameter not found, Can't not find [ {} ] in state_dict".format(
parameter.name)
parameter.set_value(self._loaddict_holder[parameter.name])
parameter.set_value(self._loaddict_holder[parameter.name])
self._parameters[name] = parameter
self._parameters[name] = parameter
return parameter
def __getattr__(self, name):
......
......@@ -86,6 +86,31 @@ class TestBaseLayer(unittest.TestCase):
ret = l()
self.assertTrue(np.allclose(ret.numpy(), 0.8 * np.ones([2, 2])))
def test_add_parameter_with_error(self):
with fluid.dygraph.guard():
net = fluid.Layer()
param = net.create_parameter(shape=[1])
with self.assertRaises(TypeError):
net.add_parameter(10, param)
with self.assertRaises(KeyError):
net.add_parameter("param.name", param)
with self.assertRaises(KeyError):
net.add_parameter("", param)
with self.assertRaises(KeyError):
net.test_param = 10
net.add_parameter("test_param", param)
with self.assertRaises(TypeError):
net.add_parameter("no_param", 10)
load_param = net.create_parameter(shape=[1])
net._loaddict_holder[load_param.name] = load_param
net.add_parameter("load_param", load_param)
class BufferLayer(fluid.Layer):
def __init__(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册