未验证 提交 d9f0c9f5 编写于 作者: S songyouwei 提交者: GitHub

support set param with None value (#22418)

* support reset param with None value

* add unittest
test=develop

* update unittest
test=develop
上级 19211072
...@@ -150,15 +150,11 @@ class Layer(core.Layer): ...@@ -150,15 +150,11 @@ class Layer(core.Layer):
Returns: Returns:
list of :ref:`api_guide_Variable_en` : a list of Parameters. list of :ref:`api_guide_Variable_en` : a list of Parameters.
""" """
ret = [p for p in self._parameters.values()] ret = [
parameters_set = set(ret) param
if include_sublayers: for _, param in self.named_parameters(
for l in self._sub_layers.values(): include_sublayers=include_sublayers)
for p in l.parameters(include_sublayers): ]
if p in parameters_set:
continue
parameters_set.add(p)
ret.append(p)
return ret return ret
def sublayers(self, include_sublayers=True): def sublayers(self, include_sublayers=True):
...@@ -170,11 +166,11 @@ class Layer(core.Layer): ...@@ -170,11 +166,11 @@ class Layer(core.Layer):
Returns: Returns:
list of Layer : a list of sub layers. list of Layer : a list of sub layers.
""" """
ret = [l for l in self._sub_layers.values()] ret = [
if include_sublayers: layer
for l in self._sub_layers.values(): for _, layer in self.named_sublayers(
for sub_l in l.sublayers(include_sublayers): include_sublayers=include_sublayers)
ret.append(sub_l) ]
return ret return ret
def named_parameters(self, prefix='', include_sublayers=True): def named_parameters(self, prefix='', include_sublayers=True):
...@@ -349,7 +345,12 @@ class Layer(core.Layer): ...@@ -349,7 +345,12 @@ class Layer(core.Layer):
Returns: Returns:
Parameter: the parameter passed in. Parameter: the parameter passed in.
""" """
assert isinstance(parameter, framework.Parameter) if parameter is None:
self._parameters[name] = None
elif not isinstance(parameter, framework.Parameter):
raise TypeError(
"parameter assignment requires Parameter or None, but got '{}'"
.format(type(parameter).__name__))
if len(self._loaddict_holder) > 0: if len(self._loaddict_holder) > 0:
assert parameter.name in self._loaddict_holder, "Parameter not found, Can't not find [ {} ] in stat_dict".format( assert parameter.name in self._loaddict_holder, "Parameter not found, Can't not find [ {} ] in stat_dict".format(
...@@ -376,8 +377,8 @@ class Layer(core.Layer): ...@@ -376,8 +377,8 @@ class Layer(core.Layer):
if isinstance(getattr(type(self), name, None), property): if isinstance(getattr(type(self), name, None), property):
object.__setattr__(self, name, value) object.__setattr__(self, name, value)
params = self.__dict__.get('_parameters', None)
if isinstance(value, framework.Parameter): if isinstance(value, framework.Parameter):
params = self.__dict__.get('_parameters', None)
if params is None: if params is None:
raise ValueError( raise ValueError(
"super(YourLayer, self).__init__() should be called first") "super(YourLayer, self).__init__() should be called first")
...@@ -389,16 +390,30 @@ class Layer(core.Layer): ...@@ -389,16 +390,30 @@ class Layer(core.Layer):
_remove_if_exist(self.__dict__, self._sub_layers) _remove_if_exist(self.__dict__, self._sub_layers)
params[name] = value params[name] = value
elif isinstance(value, core.Layer): elif params is not None and name in params:
layers = self.__dict__.get('_sub_layers', None) if value is not None:
if layers is None: raise TypeError(
raise ValueError( "assignment to parameter '{}' should be of type Parameter or None, but got '{}'"
"super(YourLayer, self).__init__() should be called first") .format(name, type(value).__name__))
params[name] = None
_remove_if_exist(self.__dict__, self._parameters)
layers[name] = value
else: else:
object.__setattr__(self, name, value) layers = self.__dict__.get('_sub_layers', None)
if isinstance(value, core.Layer):
if layers is None:
raise ValueError(
"super(YourLayer, self).__init__() should be called first"
)
_remove_if_exist(self.__dict__, self._parameters)
layers[name] = value
elif layers is not None and name in layers:
if value is not None:
raise TypeError(
"assignment to sublayer '{}' should be of type Layer or None, but got '{}'"
.format(name, type(value).__name__))
layers[name] = None
else:
object.__setattr__(self, name, value)
def __delattr__(self, name): def __delattr__(self, name):
if name in self._parameters: if name in self._parameters:
......
...@@ -497,6 +497,19 @@ class TestImperative(unittest.TestCase): ...@@ -497,6 +497,19 @@ class TestImperative(unittest.TestCase):
self.assertTrue(hasattr(layer, "test_attr")) self.assertTrue(hasattr(layer, "test_attr"))
self.assertEqual(layer.test_attr, 1) self.assertEqual(layer.test_attr, 1)
my_layer = MyLayer()
my_layer.w1 = my_layer.create_parameter([3, 3])
my_layer.add_parameter('w2', None)
self.assertEqual(len(my_layer.parameters()), 1)
self.assertRaises(TypeError, my_layer.__setattr__, 'w1', 'str')
my_layer.w1 = None
self.assertEqual(len(my_layer.parameters()), 0)
my_layer.l1 = fluid.dygraph.Linear(3, 3)
self.assertEqual(len(my_layer.sublayers()), 1)
self.assertRaises(TypeError, my_layer.__setattr__, 'l1', 'str')
my_layer.l1 = None
self.assertEqual(len(my_layer.sublayers()), 0)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册