提交 6596320f 编写于 作者: H Hongsheng Zeng 提交者: Bo Zhou

fix bug in _get_parameter_names function (#159)

上级 2ddf4c11
...@@ -273,33 +273,33 @@ class Model(ModelBase): ...@@ -273,33 +273,33 @@ class Model(ModelBase):
set_value(param_name, weight, is_gpu_available) set_value(param_name, weight, is_gpu_available)
def _get_parameter_names(self, obj): def _get_parameter_names(self, obj):
""" Recursively get parameter names in a model and its child attributes. """ Recursively get parameter names in an object.
Args: Args:
obj (``parl.Model``): an instance of ``Model`` obj (Object): any object
Returns: Returns:
parameter_names (list): all parameter names in this model. parameter_names (list): all parameter names in this object.
""" """
parameter_names = [] parameter_names = []
for attr in sorted(obj.__dict__.keys()): if isinstance(obj, Model):
val = getattr(obj, attr) for attr in sorted(obj.__dict__.keys()):
if isinstance(val, Model): val = getattr(obj, attr)
parameter_names.extend(self._get_parameter_names(val)) parameter_names.extend(self._get_parameter_names(val))
elif isinstance(val, LayerFunc): elif isinstance(obj, LayerFunc):
for attr in val.attr_holder.sorted(): for attr in obj.attr_holder.sorted():
if attr: if attr:
parameter_names.append(attr.name) parameter_names.append(attr.name)
elif isinstance(val, tuple) or isinstance(val, list): elif isinstance(obj, tuple) or isinstance(obj, list):
for x in val: for x in obj:
parameter_names.extend(self._get_parameter_names(x)) parameter_names.extend(self._get_parameter_names(x))
elif isinstance(val, dict): elif isinstance(obj, dict):
for x in list(val.values()): for x in list(obj.values()):
parameter_names.extend(self._get_parameter_names(x)) parameter_names.extend(self._get_parameter_names(x))
else: else:
# for any other type, won't be handled. E.g. set # for any other type, won't be handled. E.g. set
pass pass
return parameter_names return parameter_names
def _get_parameter_pairs(self, src, target): def _get_parameter_pairs(self, src, target):
......
...@@ -84,6 +84,42 @@ class TestModel4(parl.Model): ...@@ -84,6 +84,42 @@ class TestModel4(parl.Model):
return out return out
class TestModel6(parl.Model):
def __init__(self):
self.fc1 = layers.fc(
size=256,
act=None,
param_attr=ParamAttr(name='fc1.w'),
bias_attr=ParamAttr(name='fc1.b'))
self.fc_tuple = (layers.fc(
size=128,
act=None,
param_attr=ParamAttr(name='fc2.w'),
bias_attr=ParamAttr(name='fc2.b')), (layers.fc(
size=1,
act=None,
param_attr=ParamAttr(name='fc3.w'),
bias_attr=ParamAttr(name='fc3.b')), 10), 10)
self.fc_dict = {
'k1':
layers.fc(
size=128,
act=None,
param_attr=ParamAttr(name='fc4.w'),
bias_attr=ParamAttr(name='fc4.b')),
'k2': {
'k22':
layers.fc(
size=1,
act=None,
param_attr=ParamAttr(name='fc5.w'),
bias_attr=ParamAttr(name='fc5.b'))
},
'k3':
1,
}
class ModelBaseTest(unittest.TestCase): class ModelBaseTest(unittest.TestCase):
def setUp(self): def setUp(self):
self.model = TestModel() self.model = TestModel()
...@@ -139,6 +175,14 @@ class ModelBaseTest(unittest.TestCase): ...@@ -139,6 +175,14 @@ class ModelBaseTest(unittest.TestCase):
set(self.model.parameters()), set(self.model.parameters()),
set(['fc1.w', 'fc1.b', 'fc2.w', 'fc2.b', 'fc3.w', 'fc3.b'])) set(['fc1.w', 'fc1.b', 'fc2.w', 'fc2.b', 'fc3.w', 'fc3.b']))
model2 = TestModel6()
self.assertSetEqual(
set(model2.parameters()),
set([
'fc1.w', 'fc1.b', 'fc2.w', 'fc2.b', 'fc3.w', 'fc3.b', 'fc4.w',
'fc4.b', 'fc5.w', 'fc5.b'
]))
def test_sync_weights_in_one_program(self): def test_sync_weights_in_one_program(self):
pred_program = fluid.Program() pred_program = fluid.Program()
with fluid.program_guard(pred_program): with fluid.program_guard(pred_program):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册