提交 6596320f 编写于 作者: H Hongsheng Zeng 提交者: Bo Zhou

fix bug in _get_parameter_names function (#159)

上级 2ddf4c11
......@@ -273,33 +273,33 @@ class Model(ModelBase):
set_value(param_name, weight, is_gpu_available)
def _get_parameter_names(self, obj):
    """ Recursively get parameter names in an object.

    Dispatches on the type of ``obj`` itself (not on its attributes):
    a ``Model`` is walked attribute-by-attribute, a ``LayerFunc``
    contributes its held parameter names, and containers (tuple/list/
    dict) are recursed into element-by-element. Any other type (e.g.
    set, int, str) contributes nothing.

    Args:
        obj (Object): any object

    Returns:
        parameter_names (list): all parameter names in this object.
    """
    parameter_names = []
    if isinstance(obj, Model):
        # Sort attribute names so the resulting order is deterministic
        # across runs (dict iteration order of older Pythons is not).
        for attr in sorted(obj.__dict__.keys()):
            val = getattr(obj, attr)
            parameter_names.extend(self._get_parameter_names(val))
    elif isinstance(obj, LayerFunc):
        for attr in obj.attr_holder.sorted():
            if attr:
                parameter_names.append(attr.name)
    elif isinstance(obj, (tuple, list)):
        for x in obj:
            parameter_names.extend(self._get_parameter_names(x))
    elif isinstance(obj, dict):
        for x in list(obj.values()):
            parameter_names.extend(self._get_parameter_names(x))
    else:
        # for any other type, won't be handled. E.g. set
        pass
    return parameter_names
def _get_parameter_pairs(self, src, target):
......
......@@ -84,6 +84,42 @@ class TestModel4(parl.Model):
return out
class TestModel6(parl.Model):
    # Fixture model whose parameters live inside nested containers
    # (a nested tuple and a nested dict, including non-layer filler
    # values) to exercise recursive parameter-name collection.
    def __init__(self):
        self.fc1 = layers.fc(
            size=256,
            act=None,
            param_attr=ParamAttr(name='fc1.w'),
            bias_attr=ParamAttr(name='fc1.b'))
        # Layers buried in a nested tuple, alongside plain ints.
        inner_fc = layers.fc(
            size=1,
            act=None,
            param_attr=ParamAttr(name='fc3.w'),
            bias_attr=ParamAttr(name='fc3.b'))
        outer_fc = layers.fc(
            size=128,
            act=None,
            param_attr=ParamAttr(name='fc2.w'),
            bias_attr=ParamAttr(name='fc2.b'))
        self.fc_tuple = (outer_fc, (inner_fc, 10), 10)
        # Layers buried in a nested dict, alongside a plain int value.
        self.fc_dict = {
            'k1':
            layers.fc(
                size=128,
                act=None,
                param_attr=ParamAttr(name='fc4.w'),
                bias_attr=ParamAttr(name='fc4.b')),
            'k2': {
                'k22':
                layers.fc(
                    size=1,
                    act=None,
                    param_attr=ParamAttr(name='fc5.w'),
                    bias_attr=ParamAttr(name='fc5.b'))
            },
            'k3':
            1,
        }
class ModelBaseTest(unittest.TestCase):
def setUp(self):
self.model = TestModel()
......@@ -139,6 +175,14 @@ class ModelBaseTest(unittest.TestCase):
set(self.model.parameters()),
set(['fc1.w', 'fc1.b', 'fc2.w', 'fc2.b', 'fc3.w', 'fc3.b']))
model2 = TestModel6()
self.assertSetEqual(
set(model2.parameters()),
set([
'fc1.w', 'fc1.b', 'fc2.w', 'fc2.b', 'fc3.w', 'fc3.b', 'fc4.w',
'fc4.b', 'fc5.w', 'fc5.b'
]))
def test_sync_weights_in_one_program(self):
pred_program = fluid.Program()
with fluid.program_guard(pred_program):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册