diff --git a/parl/core/fluid/model.py b/parl/core/fluid/model.py index 7c23eda633bf7487ac577de9d50590c2e15626e9..38d653ad20275d281a1bca4cf63d1198475a8696 100644 --- a/parl/core/fluid/model.py +++ b/parl/core/fluid/model.py @@ -273,33 +273,33 @@ class Model(ModelBase): set_value(param_name, weight, is_gpu_available) def _get_parameter_names(self, obj): - """ Recursively get parameter names in a model and its child attributes. + """ Recursively get parameter names in an object. Args: - obj (``parl.Model``): an instance of ``Model`` + obj (Object): any object Returns: - parameter_names (list): all parameter names in this model. + parameter_names (list): all parameter names in this object. """ parameter_names = [] - for attr in sorted(obj.__dict__.keys()): - val = getattr(obj, attr) - if isinstance(val, Model): + if isinstance(obj, Model): + for attr in sorted(obj.__dict__.keys()): + val = getattr(obj, attr) parameter_names.extend(self._get_parameter_names(val)) - elif isinstance(val, LayerFunc): - for attr in val.attr_holder.sorted(): - if attr: - parameter_names.append(attr.name) - elif isinstance(val, tuple) or isinstance(val, list): - for x in val: - parameter_names.extend(self._get_parameter_names(x)) - elif isinstance(val, dict): - for x in list(val.values()): - parameter_names.extend(self._get_parameter_names(x)) - else: - # for any other type, won't be handled. E.g. set - pass + elif isinstance(obj, LayerFunc): + for attr in obj.attr_holder.sorted(): + if attr: + parameter_names.append(attr.name) + elif isinstance(obj, tuple) or isinstance(obj, list): + for x in obj: + parameter_names.extend(self._get_parameter_names(x)) + elif isinstance(obj, dict): + for x in list(obj.values()): + parameter_names.extend(self._get_parameter_names(x)) + else: + # for any other type, won't be handled. E.g. set + pass return parameter_names def _get_parameter_pairs(self, src, target): diff --git a/parl/core/fluid/tests/model_base_test_.py b/parl/core/fluid/tests/model_base_test_.py index dbab50abaf4218907c1c021ca36f8f90e686ef21..82adbfbe3eb3fe5ce339c6ccef7bac697a795055 100644 --- a/parl/core/fluid/tests/model_base_test_.py +++ b/parl/core/fluid/tests/model_base_test_.py @@ -84,6 +84,42 @@ class TestModel4(parl.Model): return out +class TestModel6(parl.Model): + def __init__(self): + self.fc1 = layers.fc( + size=256, + act=None, + param_attr=ParamAttr(name='fc1.w'), + bias_attr=ParamAttr(name='fc1.b')) + self.fc_tuple = (layers.fc( + size=128, + act=None, + param_attr=ParamAttr(name='fc2.w'), + bias_attr=ParamAttr(name='fc2.b')), (layers.fc( + size=1, + act=None, + param_attr=ParamAttr(name='fc3.w'), + bias_attr=ParamAttr(name='fc3.b')), 10), 10) + self.fc_dict = { + 'k1': + layers.fc( + size=128, + act=None, + param_attr=ParamAttr(name='fc4.w'), + bias_attr=ParamAttr(name='fc4.b')), + 'k2': { + 'k22': + layers.fc( + size=1, + act=None, + param_attr=ParamAttr(name='fc5.w'), + bias_attr=ParamAttr(name='fc5.b')) + }, + 'k3': + 1, + } + + class ModelBaseTest(unittest.TestCase): def setUp(self): self.model = TestModel() @@ -139,6 +175,14 @@ class ModelBaseTest(unittest.TestCase): set(self.model.parameters()), set(['fc1.w', 'fc1.b', 'fc2.w', 'fc2.b', 'fc3.w', 'fc3.b'])) + model2 = TestModel6() + self.assertSetEqual( + set(model2.parameters()), + set([ + 'fc1.w', 'fc1.b', 'fc2.w', 'fc2.b', 'fc3.w', 'fc3.b', 'fc4.w', + 'fc4.b', 'fc5.w', 'fc5.b' + ])) + def test_sync_weights_in_one_program(self): pred_program = fluid.Program() with fluid.program_guard(pred_program):