未验证 提交 f41449e7 编写于 作者: S songyouwei 提交者: GitHub

cherry-pick #22418 (#22642)

test=release/1.7, test=develop
上级 1cb19bf1
......@@ -150,15 +150,11 @@ class Layer(core.Layer):
Returns:
list of :ref:`api_guide_Variable_en` : a list of Parameters.
"""
ret = [p for p in self._parameters.values()]
parameters_set = set(ret)
if include_sublayers:
for l in self._sub_layers.values():
for p in l.parameters(include_sublayers):
if p in parameters_set:
continue
parameters_set.add(p)
ret.append(p)
ret = [
param
for _, param in self.named_parameters(
include_sublayers=include_sublayers)
]
return ret
def sublayers(self, include_sublayers=True):
......@@ -170,11 +166,11 @@ class Layer(core.Layer):
Returns:
list of Layer : a list of sub layers.
"""
ret = [l for l in self._sub_layers.values()]
if include_sublayers:
for l in self._sub_layers.values():
for sub_l in l.sublayers(include_sublayers):
ret.append(sub_l)
ret = [
layer
for _, layer in self.named_sublayers(
include_sublayers=include_sublayers)
]
return ret
def named_parameters(self, prefix='', include_sublayers=True):
......@@ -349,7 +345,12 @@ class Layer(core.Layer):
Returns:
Parameter: the parameter passed in.
"""
assert isinstance(parameter, framework.Parameter)
if parameter is None:
self._parameters[name] = None
elif not isinstance(parameter, framework.Parameter):
raise TypeError(
"parameter assignment requires Parameter or None, but got '{}'"
.format(type(parameter).__name__))
if len(self._loaddict_holder) > 0:
assert parameter.name in self._loaddict_holder, "Parameter not found, Can't not find [ {} ] in stat_dict".format(
......@@ -376,8 +377,8 @@ class Layer(core.Layer):
if isinstance(getattr(type(self), name, None), property):
object.__setattr__(self, name, value)
if isinstance(value, framework.Parameter):
params = self.__dict__.get('_parameters', None)
if isinstance(value, framework.Parameter):
if params is None:
raise ValueError(
"super(YourLayer, self).__init__() should be called first")
......@@ -389,14 +390,28 @@ class Layer(core.Layer):
_remove_if_exist(self.__dict__, self._sub_layers)
params[name] = value
elif isinstance(value, core.Layer):
elif params is not None and name in params:
if value is not None:
raise TypeError(
"assignment to parameter '{}' should be of type Parameter or None, but got '{}'"
.format(name, type(value).__name__))
params[name] = None
else:
layers = self.__dict__.get('_sub_layers', None)
if isinstance(value, core.Layer):
if layers is None:
raise ValueError(
"super(YourLayer, self).__init__() should be called first")
"super(YourLayer, self).__init__() should be called first"
)
_remove_if_exist(self.__dict__, self._parameters)
layers[name] = value
elif layers is not None and name in layers:
if value is not None:
raise TypeError(
"assignment to sublayer '{}' should be of type Layer or None, but got '{}'"
.format(name, type(value).__name__))
layers[name] = None
else:
object.__setattr__(self, name, value)
......
......@@ -497,6 +497,19 @@ class TestImperative(unittest.TestCase):
self.assertTrue(hasattr(layer, "test_attr"))
self.assertEqual(layer.test_attr, 1)
my_layer = MyLayer()
my_layer.w1 = my_layer.create_parameter([3, 3])
my_layer.add_parameter('w2', None)
self.assertEqual(len(my_layer.parameters()), 1)
self.assertRaises(TypeError, my_layer.__setattr__, 'w1', 'str')
my_layer.w1 = None
self.assertEqual(len(my_layer.parameters()), 0)
my_layer.l1 = fluid.dygraph.Linear(3, 3)
self.assertEqual(len(my_layer.sublayers()), 1)
self.assertRaises(TypeError, my_layer.__setattr__, 'l1', 'str')
my_layer.l1 = None
self.assertEqual(len(my_layer.sublayers()), 0)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册