diff --git a/parl/framework/algorithm_base.py b/parl/framework/algorithm_base.py
index 110976b6954546355d86aec92b1f10fa3663892a..ca8a3a3b14eaf4402e13b273d1aca85370b33fcc 100644
--- a/parl/framework/algorithm_base.py
+++ b/parl/framework/algorithm_base.py
@@ -14,8 +14,8 @@
 import paddle.fluid as fluid
 import parl.layers as layers
-from parl.framework.model_base import Network, Model
 from abc import ABCMeta, abstractmethod
+from parl.framework.model_base import Network, Model

 __all__ = ['Algorithm']

diff --git a/parl/framework/model_base.py b/parl/framework/model_base.py
index 94d9f03d0605e519eb2690b406b7666564623b95..6382c63be881351cb97e9fcfa7336b318e92a225 100644
--- a/parl/framework/model_base.py
+++ b/parl/framework/model_base.py
@@ -15,8 +15,9 @@
 Base class to define an Algorithm.
 """

+import hashlib
+import paddle.fluid as fluid
 from abc import ABCMeta, abstractmethod
-from parl.utils.utils import has_func

 __all__ = ['Network', 'Model']

@@ -26,36 +27,64 @@ class Network(object):
     A Network is an unordered set of LayerFuncs or Networks.
     """

-    def sync_paras_to(self, target_net):
-        assert not target_net is self, "cannot copy between identical networks"
-        assert isinstance(target_net, Network)
-        assert self.__class__.__name__ == target_net.__class__.__name__, \
-            "must be the same class for para syncing!"
-
-        for attr in self.__dict__:
-            if not attr in target_net.__dict__:
-                continue
-            val = getattr(self, attr)
-            target_val = getattr(target_net, attr)
-
-            assert type(val) == type(target_val), \
-                "[Error]sync_paras_to failed, \
-                ensure that the destination model is generated by deep copied from source model"
-
-            ### TODO: sync paras recursively
-            if has_func(val, 'sync_paras_to'):
-                val.sync_paras_to(target_val)
-            elif isinstance(val, tuple) or isinstance(val, list) or isinstance(
-                    val, set):
-                for v, tv in zip(val, target_val):
-                    v.sync_paras_to(tv)
-            elif isinstance(val, dict):
-                for k in val.keys():
-                    assert k in target_val
-                    val[k].sync_paras_to(target_val[k])
-            else:
-                # for any other type, we do not copy
-                pass
+    def sync_params_to(self, target_net, gpu_id=0, decay=0.0):
+        """
+        Args:
+            target_net: Network object deep-copied from the source network
+            gpu_id: gpu id of target_net
+            decay: Float. The decay to use.
+                target_net_weights = decay * target_net_weights + (1 - decay) * source_net_weights
+        """
+        args_hash_id = hashlib.md5('{}_{}_{}'.format(
+            id(target_net), gpu_id, decay).encode('utf-8')).hexdigest()
+        has_cached = False
+        try:
+            if self._cached_id == args_hash_id:
+                has_cached = True
+        except AttributeError:
+            has_cached = False
+
+        if not has_cached:
+            # Cannot reuse the cached program; need to create a new one
+            self._cached_id = args_hash_id
+
+            assert not target_net is self, "cannot copy between identical networks"
+            assert isinstance(target_net, Network)
+            assert self.__class__.__name__ == target_net.__class__.__name__, \
+                "must be the same class for param syncing!"
+            assert (decay >= 0 and decay <= 1)
+
+            # Resolve Circular Imports
+            from parl.plutils import get_parameter_pairs, fetch_framework_var
+
+            param_pairs = get_parameter_pairs(self, target_net)
+
+            place = fluid.CPUPlace() if gpu_id < 0 \
+                else fluid.CUDAPlace(gpu_id)
+            self._cached_fluid_executor = fluid.Executor(place)
+            self._cached_sync_params_program = fluid.Program()
+
+            with fluid.program_guard(self._cached_sync_params_program):
+                for (src_var_name, target_var_name, is_bias) in param_pairs:
+                    src_var = fetch_framework_var(src_var_name, is_bias)
+                    target_var = fetch_framework_var(target_var_name, is_bias)
+                    fluid.layers.assign(
+                        decay * target_var + (1 - decay) * src_var, target_var)
+
+        self._cached_fluid_executor.run(self._cached_sync_params_program)
+
+    @property
+    def parameter_names(self):
+        """ param_attr names of all parameters in this Network;
+            only parameters created by parl.layers are included.
+
+        Returns:
+            list of string, param_attr names of all parameters
+        """
+
+        # Resolve Circular Imports
+        from parl.plutils import get_parameter_names
+        return get_parameter_names(self)


 class Model(Network):
@@ -80,7 +109,7 @@ class Model(Network):
     Note that it's the model structure that is copied from initial actor,
     parameters in initial model havn't been copied to target model.

-    To copy parameters, you must explicitly use sync_paras_to function after the program is initialized.
+    To copy parameters, you must explicitly use sync_params_to function after the program is initialized.
     """

     __metaclass__ = ABCMeta
diff --git a/parl/framework/policy_distribution.py b/parl/framework/policy_distribution.py
index 471419d5bb071943bc17f10004e280e2530b2ce3..524aa7a460d22dd8f37407b2e59c5f2fa4f562ae 100644
--- a/parl/framework/policy_distribution.py
+++ b/parl/framework/policy_distribution.py
@@ -13,10 +13,10 @@
 # limitations under the License.

 import parl.layers as layers
+from abc import ABCMeta, abstractmethod
 from paddle.fluid.framework import Variable
 from parl.layers import common_functions as comf
 from paddle.fluid.framework import convert_np_dtype_to_dtype_
-from abc import ABCMeta, abstractmethod


 class PolicyDistribution(object):
diff --git a/parl/framework/tests/algorithm_test.py b/parl/framework/tests/algorithm_test.py
deleted file mode 100644
index 0e5b05841b1b434538e9565ffa04c8f9a211cdad..0000000000000000000000000000000000000000
--- a/parl/framework/tests/algorithm_test.py
+++ /dev/null
@@ -1,126 +0,0 @@
-# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import paddle.fluid as fluid
-import parl.layers as layers
-from parl.framework.model_base import Model
-from parl.framework.algorithm_base import Algorithm
-from copy import deepcopy
-import numpy as np
-import unittest
-import sys
-
-
-class Value(Model):
-    def __init__(self, obs_dim, act_dim):
-        self.obs_dim = obs_dim
-        self.act_dim = act_dim
-
-        self.fc1 = layers.fc(size=256, act='relu')
-        self.fc2 = layers.fc(size=128, act='relu')
-        self.fc3 = layers.fc(size=self.act_dim)
-
-    def value(self, obs):
-        out = self.fc1(obs)
-        out = self.fc2(out)
-        value = self.fc3(out)
-        return value
-
-
-class QLearning(Algorithm):
-    def __init__(self, critic_model):
-        self.critic_model = critic_model
-        self.target_model = deepcopy(critic_model)
-
-    def define_predict(self, obs):
-        self.q_value = self.critic_model.value(obs)
-        self.q_target_value = self.target_model.value(obs)
-
-
-class AlgorithmBaseTest(unittest.TestCase):
-    def test_sync_paras_in_one_program(self):
-        critic_model = Value(obs_dim=4, act_dim=1)
-        dqn = QLearning(critic_model)
-        pred_program = fluid.Program()
-        with fluid.program_guard(pred_program):
-            obs = layers.data(name='obs', shape=[4], dtype='float32')
-            dqn.define_predict(obs)
-        place = fluid.CUDAPlace(0)
-        executor = fluid.Executor(place)
-        executor.run(fluid.default_startup_program())
-
-        N = 10
-        random_obs = np.random.random(size=(N, 4)).astype('float32')
-        for i in range(N):
-            x = np.expand_dims(random_obs[i], axis=0)
-            outputs = executor.run(
-                pred_program,
-                feed={'obs': x},
-                fetch_list=[dqn.q_value, dqn.q_target_value])
-            self.assertNotEqual(outputs[0].flatten(), outputs[1].flatten())
-        critic_model.sync_paras_to(dqn.target_model)
-
-        random_obs = np.random.random(size=(N, 4)).astype('float32')
-        for i in range(N):
-            x = np.expand_dims(random_obs[i], axis=0)
-            outputs = executor.run(
-                pred_program,
-                feed={'obs': x},
-                fetch_list=[dqn.q_value, dqn.q_target_value])
-            self.assertEqual(outputs[0].flatten(), outputs[1].flatten())
-
-    def test_sync_paras_among_programs(self):
-        critic_model = Value(obs_dim=4, act_dim=1)
-        dqn = QLearning(critic_model)
-        dqn_2 = deepcopy(dqn)
-        pred_program = fluid.Program()
-        pred_program_2 = fluid.Program()
-        with fluid.program_guard(pred_program):
-            obs = layers.data(name='obs', shape=[4], dtype='float32')
-            dqn.define_predict(obs)
-
-        # algorithm #2
-        with fluid.program_guard(pred_program_2):
-            obs_2 = layers.data(name='obs_2', shape=[4], dtype='float32')
-            dqn_2.define_predict(obs_2)
-
-        place = fluid.CUDAPlace(0)
-        executor = fluid.Executor(place)
-        executor.run(fluid.default_startup_program())
-
-        N = 10
-        random_obs = np.random.random(size=(N, 4)).astype('float32')
-        for i in range(N):
-            x = np.expand_dims(random_obs[i], axis=0)
-            outputs = executor.run(
-                pred_program, feed={'obs': x}, fetch_list=[dqn.q_value])
-
-            outputs_2 = executor.run(
-                pred_program_2, feed={'obs_2': x}, fetch_list=[dqn_2.q_value])
-            self.assertNotEqual(outputs[0].flatten(), outputs_2[0].flatten())
-        dqn.critic_model.sync_paras_to(dqn_2.critic_model)
-
-        random_obs = np.random.random(size=(N, 4)).astype('float32')
-        for i in range(N):
-            x = np.expand_dims(random_obs[i], axis=0)
-            outputs = executor.run(
-                pred_program, feed={'obs': x}, fetch_list=[dqn.q_value])
-
-            outputs_2 = executor.run(
-                pred_program_2, feed={'obs_2': x}, fetch_list=[dqn_2.q_value])
-            self.assertEqual(outputs[0].flatten(), outputs_2[0].flatten())
-
-
-if __name__ == '__main__':
-    unittest.main()
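The deleted algorithm_test.py above is superseded by the rewritten model_base_test.py below, which exercises the new sync_params_to API directly on a Model. A minimal usage sketch of that API (illustrative only: CriticModel, the layer sizes, and the CPU-only place chosen via gpu_id=-1 are assumptions, not part of this patch):

import paddle.fluid as fluid
import parl.layers as layers
from copy import deepcopy
from parl.framework.model_base import Model

class CriticModel(Model):
    def __init__(self):
        self.fc1 = layers.fc(size=64, act='relu')
        self.fc2 = layers.fc(size=1)

    def value(self, obs):
        return self.fc2(self.fc1(obs))

model = CriticModel()
target_model = deepcopy(model)  # copies the structure, not the parameter values

program = fluid.Program()
with fluid.program_guard(program):
    obs = layers.data(name='obs', shape=[4], dtype='float32')
    q = model.value(obs)
    target_q = target_model.value(obs)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

# Hard copy once (decay defaults to 0.0), then a soft update:
# target = 0.9 * target + 0.1 * source
model.sync_params_to(target_model, gpu_id=-1)
model.sync_params_to(target_model, gpu_id=-1, decay=0.9)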
diff --git a/parl/framework/tests/model_base_test.py b/parl/framework/tests/model_base_test.py
index c75b15c2d66e18eb0763c41f90cdb299b88a2f29..668230e575320b00c7ed0c7eb69c414bb4217976 100644
--- a/parl/framework/tests/model_base_test.py
+++ b/parl/framework/tests/model_base_test.py
@@ -12,31 +12,405 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import numpy as np
 import paddle.fluid as fluid
 import parl.layers as layers
-from parl.framework.model_base import Model
-from copy import deepcopy
 import unittest
+from copy import deepcopy
+from paddle.fluid import ParamAttr
+from parl.framework.model_base import Model
+from parl.utils import get_gpu_count
+from parl.plutils import fetch_value


-class Value(Model):
-    def __init__(self, obs_dim, act_dim):
-        self.obs_dim = obs_dim
-        self.act_dim = act_dim
+class TestModel(Model):
+    def __init__(self):
+        self.fc1 = layers.fc(
+            size=256,
+            act=None,
+            param_attr=ParamAttr(name='fc1.w'),
+            bias_attr=ParamAttr(name='fc1.b'))
+        self.fc2 = layers.fc(
+            size=128,
+            act=None,
+            param_attr=ParamAttr(name='fc2.w'),
+            bias_attr=ParamAttr(name='fc2.b'))
+        self.fc3 = layers.fc(
+            size=1,
+            act=None,
+            param_attr=ParamAttr(name='fc3.w'),
+            bias_attr=ParamAttr(name='fc3.b'))

-        self.fc1 = layers.fc(size=256, act='relu')
-        self.fc2 = layers.fc(size=128, act='relu')
+    def predict(self, obs):
+        out = self.fc1(obs)
+        out = self.fc2(out)
+        out = self.fc3(out)
+        return out


 class ModelBaseTest(unittest.TestCase):
+    def setUp(self):
+        self.model = TestModel()
+        self.target_model = deepcopy(self.model)
+        self.target_model2 = deepcopy(self.model)
+
+        gpu_count = get_gpu_count()
+        if gpu_count > 0:
+            place = fluid.CUDAPlace(0)
+        else:
+            place = fluid.CPUPlace()
+        self.executor = fluid.Executor(place)
+
     def test_network_copy(self):
-        value = Value(obs_dim=2, act_dim=1)
-        target_value = deepcopy(value)
-        self.assertNotEqual(value.fc1.param_name, target_value.fc1.param_name)
-        self.assertNotEqual(value.fc1.bias_name, target_value.fc1.bias_name)
+        self.assertNotEqual(self.model.fc1.param_name,
+                            self.target_model.fc1.param_name)
+        self.assertNotEqual(self.model.fc1.bias_name,
+                            self.target_model.fc1.bias_name)
+
+        self.assertNotEqual(self.model.fc2.param_name,
+                            self.target_model.fc2.param_name)
+        self.assertNotEqual(self.model.fc2.bias_name,
+                            self.target_model.fc2.bias_name)
+
+        self.assertNotEqual(self.model.fc3.param_name,
+                            self.target_model.fc3.param_name)
+        self.assertNotEqual(self.model.fc3.bias_name,
+                            self.target_model.fc3.bias_name)
+
+    def test_network_copy_with_multi_copy(self):
+        self.assertNotEqual(self.target_model.fc1.param_name,
+                            self.target_model2.fc1.param_name)
+        self.assertNotEqual(self.target_model.fc1.bias_name,
+                            self.target_model2.fc1.bias_name)
+
+        self.assertNotEqual(self.target_model.fc2.param_name,
+                            self.target_model2.fc2.param_name)
+        self.assertNotEqual(self.target_model.fc2.bias_name,
+                            self.target_model2.fc2.bias_name)
+
+        self.assertNotEqual(self.target_model.fc3.param_name,
+                            self.target_model2.fc3.param_name)
+        self.assertNotEqual(self.target_model.fc3.bias_name,
+                            self.target_model2.fc3.bias_name)
+
+    def test_network_parameter_names(self):
+        self.assertSetEqual(
+            set(self.model.parameter_names),
+            set(['fc1.w', 'fc1.b', 'fc2.w', 'fc2.b', 'fc3.w', 'fc3.b']))
+
+    def test_sync_params_in_one_program(self):
+        pred_program = fluid.Program()
+        with fluid.program_guard(pred_program):
+            obs = layers.data(name='obs', shape=[4], dtype='float32')
+            model_output = self.model.predict(obs)
+            target_model_output = self.target_model.predict(obs)
+
+        self.executor.run(fluid.default_startup_program())
+
+        N = 10
+        random_obs = np.random.random(size=(N, 4)).astype('float32')
+        for i in range(N):
+            x = np.expand_dims(random_obs[i], axis=0)
+            outputs = self.executor.run(
+                pred_program,
+                feed={'obs': x},
+                fetch_list=[model_output, target_model_output])
+            self.assertNotEqual(outputs[0].flatten(), outputs[1].flatten())
+
+        self.model.sync_params_to(self.target_model)
+
+        random_obs = np.random.random(size=(N, 4)).astype('float32')
+        for i in range(N):
+            x = np.expand_dims(random_obs[i], axis=0)
+            outputs = self.executor.run(
+                pred_program,
+                feed={'obs': x},
+                fetch_list=[model_output, target_model_output])
+            self.assertEqual(outputs[0].flatten(), outputs[1].flatten())
+
+    def test_sync_params_among_programs(self):
+        pred_program = fluid.Program()
+        pred_program_2 = fluid.Program()
+        with fluid.program_guard(pred_program):
+            obs = layers.data(name='obs', shape=[4], dtype='float32')
+            model_output = self.model.predict(obs)
+
+        # program 2
+        with fluid.program_guard(pred_program_2):
+            obs = layers.data(name='obs', shape=[4], dtype='float32')
+            target_model_output = self.target_model.predict(obs)
+
+        self.executor.run(fluid.default_startup_program())
+
+        N = 10
+        random_obs = np.random.random(size=(N, 4)).astype('float32')
+        for i in range(N):
+            x = np.expand_dims(random_obs[i], axis=0)
+            outputs = self.executor.run(
+                pred_program, feed={'obs': x}, fetch_list=[model_output])
+
+            outputs_2 = self.executor.run(
+                pred_program_2,
+                feed={'obs': x},
+                fetch_list=[target_model_output])
+            self.assertNotEqual(outputs[0].flatten(), outputs_2[0].flatten())
+
+        self.model.sync_params_to(self.target_model)
+
+        random_obs = np.random.random(size=(N, 4)).astype('float32')
+        for i in range(N):
+            x = np.expand_dims(random_obs[i], axis=0)
+            outputs = self.executor.run(
+                pred_program, feed={'obs': x}, fetch_list=[model_output])
+
+            outputs_2 = self.executor.run(
+                pred_program_2,
+                feed={'obs': x},
+                fetch_list=[target_model_output])
+            self.assertEqual(outputs[0].flatten(), outputs_2[0].flatten())
+
+    def _numpy_update(self, target_model, decay):
+        model_fc1_w = fetch_value('fc1.w')
+        model_fc1_b = fetch_value('fc1.b')
+        model_fc2_w = fetch_value('fc2.w')
+        model_fc2_b = fetch_value('fc2.b')
+        model_fc3_w = fetch_value('fc3.w')
+        model_fc3_b = fetch_value('fc3.b')
+
+        unique_id = target_model.parameter_names[0].split('_')[-1]
+        target_model_fc1_w = fetch_value(
+            'PARL_target_fc1.w_{}'.format(unique_id))
+        target_model_fc1_b = fetch_value(
+            'PARL_target_fc1.b_{}'.format(unique_id))
+        target_model_fc2_w = fetch_value(
+            'PARL_target_fc2.w_{}'.format(unique_id))
+        target_model_fc2_b = fetch_value(
+            'PARL_target_fc2.b_{}'.format(unique_id))
+        target_model_fc3_w = fetch_value(
+            'PARL_target_fc3.w_{}'.format(unique_id))
+        target_model_fc3_b = fetch_value(
+            'PARL_target_fc3.b_{}'.format(unique_id))
+
+        # update self.target_model parameter values in the numpy way
+        target_model_fc1_w = decay * target_model_fc1_w + (
+            1 - decay) * model_fc1_w
+        target_model_fc1_b = decay * target_model_fc1_b + (
+            1 - decay) * model_fc1_b
+        target_model_fc2_w = decay * target_model_fc2_w + (
+            1 - decay) * model_fc2_w
+        target_model_fc2_b = decay * target_model_fc2_b + (
+            1 - decay) * model_fc2_b
+        target_model_fc3_w = decay * target_model_fc3_w + (
+            1 - decay) * model_fc3_w
+        target_model_fc3_b = decay * target_model_fc3_b + (
+            1 - decay) * model_fc3_b
+
+        return (target_model_fc1_w, target_model_fc1_b, target_model_fc2_w,
+                target_model_fc2_b, target_model_fc3_w, target_model_fc3_b)
+
+    def test_sync_params_with_decay(self):
+        pred_program = fluid.Program()
+        with fluid.program_guard(pred_program):
+            obs = layers.data(name='obs', shape=[4], dtype='float32')
+            model_output = self.model.predict(obs)
+            target_model_output = self.target_model.predict(obs)
+
+        self.executor.run(fluid.default_startup_program())
+
+        decay = 0.9
+        # update in numpy way
+        (target_model_fc1_w, target_model_fc1_b, target_model_fc2_w,
+         target_model_fc2_b, target_model_fc3_w,
+         target_model_fc3_b) = self._numpy_update(self.target_model, decay)
+
+        self.model.sync_params_to(self.target_model, decay=decay)
+
+        N = 10
+        random_obs = np.random.random(size=(N, 4)).astype('float32')
+        for i in range(N):
+            x = np.expand_dims(random_obs[i], axis=0)
+            real_target_outputs = self.executor.run(
+                pred_program,
+                feed={'obs': x},
+                fetch_list=[target_model_output])[0]
+
+            # Ideal target output
+            out_np = np.dot(x, target_model_fc1_w) + target_model_fc1_b
+            out_np = np.dot(out_np, target_model_fc2_w) + target_model_fc2_b
+            out_np = np.dot(out_np, target_model_fc3_w) + target_model_fc3_b
+
+            self.assertLess(float(np.abs(real_target_outputs - out_np)), 1e-5)
+
+    def test_sync_params_with_decay_with_multi_sync(self):
+        pred_program = fluid.Program()
+        with fluid.program_guard(pred_program):
+            obs = layers.data(name='obs', shape=[4], dtype='float32')
+            model_output = self.model.predict(obs)
+            target_model_output = self.target_model.predict(obs)
+
+        self.executor.run(fluid.default_startup_program())
+
+        decay = 0.9
+        # update in numpy way
+        (target_model_fc1_w, target_model_fc1_b, target_model_fc2_w,
+         target_model_fc2_b, target_model_fc3_w,
+         target_model_fc3_b) = self._numpy_update(self.target_model, decay)
+
+        self.model.sync_params_to(self.target_model, decay=decay)
+
+        N = 10
+        random_obs = np.random.random(size=(N, 4)).astype('float32')
+        for i in range(N):
+            x = np.expand_dims(random_obs[i], axis=0)
+            real_target_outputs = self.executor.run(
+                pred_program,
+                feed={'obs': x},
+                fetch_list=[target_model_output])[0]
+
+            # Ideal target output
+            out_np = np.dot(x, target_model_fc1_w) + target_model_fc1_b
+            out_np = np.dot(out_np, target_model_fc2_w) + target_model_fc2_b
+            out_np = np.dot(out_np, target_model_fc3_w) + target_model_fc3_b
+
+            self.assertLess(float(np.abs(real_target_outputs - out_np)), 1e-5)
+
+        decay = 0.9
+        # update in numpy way
+        (target_model_fc1_w, target_model_fc1_b, target_model_fc2_w,
+         target_model_fc2_b, target_model_fc3_w,
+         target_model_fc3_b) = self._numpy_update(self.target_model, decay)
+
+        self.model.sync_params_to(self.target_model, decay=decay)
+
+        N = 10
+        random_obs = np.random.random(size=(N, 4)).astype('float32')
+        for i in range(N):
+            x = np.expand_dims(random_obs[i], axis=0)
+            real_target_outputs = self.executor.run(
+                pred_program,
+                feed={'obs': x},
+                fetch_list=[target_model_output])[0]
+
+            # Ideal target output
+            out_np = np.dot(x, target_model_fc1_w) + target_model_fc1_b
+            out_np = np.dot(out_np, target_model_fc2_w) + target_model_fc2_b
+            out_np = np.dot(out_np, target_model_fc3_w) + target_model_fc3_b
+
+            self.assertLess(float(np.abs(real_target_outputs - out_np)), 1e-5)
+
+    def test_sync_params_with_different_decay(self):
+        pred_program = fluid.Program()
+        with fluid.program_guard(pred_program):
+            obs = layers.data(name='obs', shape=[4], dtype='float32')
+            model_output = self.model.predict(obs)
+            target_model_output = self.target_model.predict(obs)
+
+        self.executor.run(fluid.default_startup_program())
+
+        decay = 0.9
+        # update in numpy way
+        (target_model_fc1_w, target_model_fc1_b, target_model_fc2_w,
+         target_model_fc2_b, target_model_fc3_w,
+         target_model_fc3_b) = self._numpy_update(self.target_model, decay)
+
+        self.model.sync_params_to(self.target_model, decay=decay)
+
+        N = 10
+        random_obs = np.random.random(size=(N, 4)).astype('float32')
+        for i in range(N):
+            x = np.expand_dims(random_obs[i], axis=0)
+            real_target_outputs = self.executor.run(
+                pred_program,
+                feed={'obs': x},
+                fetch_list=[target_model_output])[0]
+
+            # Ideal target output
+            out_np = np.dot(x, target_model_fc1_w) + target_model_fc1_b
+            out_np = np.dot(out_np, target_model_fc2_w) + target_model_fc2_b
+            out_np = np.dot(out_np, target_model_fc3_w) + target_model_fc3_b
+
+            self.assertLess(float(np.abs(real_target_outputs - out_np)), 1e-5)
+
+        decay = 0.8
+        # update in numpy way
+        (target_model_fc1_w, target_model_fc1_b, target_model_fc2_w,
+         target_model_fc2_b, target_model_fc3_w,
+         target_model_fc3_b) = self._numpy_update(self.target_model, decay)
+
+        self.model.sync_params_to(self.target_model, decay=decay)
+
+        N = 10
+        random_obs = np.random.random(size=(N, 4)).astype('float32')
+        for i in range(N):
+            x = np.expand_dims(random_obs[i], axis=0)
+            real_target_outputs = self.executor.run(
+                pred_program,
+                feed={'obs': x},
+                fetch_list=[target_model_output])[0]
+
+            # Ideal target output
+            out_np = np.dot(x, target_model_fc1_w) + target_model_fc1_b
+            out_np = np.dot(out_np, target_model_fc2_w) + target_model_fc2_b
+            out_np = np.dot(out_np, target_model_fc3_w) + target_model_fc3_b
+
+            self.assertLess(float(np.abs(real_target_outputs - out_np)), 1e-5)
+
+    def test_sync_params_with_multi_target_model(self):
+        pred_program = fluid.Program()
+        with fluid.program_guard(pred_program):
+            obs = layers.data(name='obs', shape=[4], dtype='float32')
+            model_output = self.model.predict(obs)
+            target_model_output = self.target_model.predict(obs)
+            target_model_output2 = self.target_model2.predict(obs)
+
+        self.executor.run(fluid.default_startup_program())
+
+        decay = 0.9
+        # update in numpy way
+        (target_model_fc1_w, target_model_fc1_b, target_model_fc2_w,
+         target_model_fc2_b, target_model_fc3_w,
+         target_model_fc3_b) = self._numpy_update(self.target_model, decay)
+
+        self.model.sync_params_to(self.target_model, decay=decay)
+
+        N = 10
+        random_obs = np.random.random(size=(N, 4)).astype('float32')
+        for i in range(N):
+            x = np.expand_dims(random_obs[i], axis=0)
+            real_target_outputs = self.executor.run(
+                pred_program,
+                feed={'obs': x},
+                fetch_list=[target_model_output])[0]
+
+            # Ideal target output
+            out_np = np.dot(x, target_model_fc1_w) + target_model_fc1_b
+            out_np = np.dot(out_np, target_model_fc2_w) + target_model_fc2_b
+            out_np = np.dot(out_np, target_model_fc3_w) + target_model_fc3_b
+
+            self.assertLess(float(np.abs(real_target_outputs - out_np)), 1e-5)
+
+        decay = 0.8
+        # update in numpy way
+        (target_model_fc1_w, target_model_fc1_b, target_model_fc2_w,
+         target_model_fc2_b, target_model_fc3_w,
+         target_model_fc3_b) = self._numpy_update(self.target_model2, decay)
+
+        self.model.sync_params_to(self.target_model2, decay=decay)
+
+        N = 10
+        random_obs = np.random.random(size=(N, 4)).astype('float32')
+        for i in range(N):
+            x = np.expand_dims(random_obs[i], axis=0)
+            real_target_outputs = self.executor.run(
+                pred_program,
+                feed={'obs': x},
+                fetch_list=[target_model_output2])[0]
+
+            # Ideal target output
+            out_np = np.dot(x, target_model_fc1_w) + target_model_fc1_b
+            out_np = np.dot(out_np, target_model_fc2_w) + target_model_fc2_b
+            out_np = np.dot(out_np, target_model_fc3_w) + target_model_fc3_b

-        self.assertNotEqual(value.fc2.param_name, target_value.fc2.param_name)
-        self.assertNotEqual(value.fc2.param_name, target_value.fc2.param_name)
+            self.assertLess(float(np.abs(real_target_outputs - out_np)), 1e-5)


 if __name__ == '__main__':
diff --git a/parl/layers/layer_wrappers.py b/parl/layers/layer_wrappers.py
index 88fb4437db7bb0a6122396ee8978fef0de2bcbae..338fb06e907e4e86ab626f33d10bb11b0221d2f0 100644
--- a/parl/layers/layer_wrappers.py
+++ b/parl/layers/layer_wrappers.py
@@ -15,15 +15,16 @@
 Wrappers for fluid.layers so that the layers can share parameters conveniently.
 """

-from paddle.fluid.executor import _fetch_var
-import paddle.fluid as fluid
-from paddle.fluid.layers import *
-from paddle.fluid.param_attr import ParamAttr
-from paddle.fluid.framework import Variable
+import inspect
 import paddle.fluid.layers as layers
 import paddle.fluid.unique_name as unique_name
+import paddle.fluid as fluid
+import six
 from copy import deepcopy
-import inspect
+from paddle.fluid.executor import _fetch_var
+from paddle.fluid.framework import Variable
+from paddle.fluid.layers import *
+from paddle.fluid.param_attr import ParamAttr
 from parl.framework.model_base import Network


@@ -62,28 +63,6 @@ class LayerFunc(object):
         self.param_attr = param_attr
         self.bias_attr = bias_attr

-    def sync_paras_to(self, target_layer, gpu_id=0):
-        """
-        Copy the paras from self to a target layer
-        """
-        ## isinstance can handle subclass
-        assert isinstance(target_layer, LayerFunc)
-        src_attrs = [self.param_attr, self.bias_attr]
-        target_attrs = [target_layer.param_attr, target_layer.bias_attr]
-
-        place = fluid.CPUPlace() if gpu_id < 0 \
-            else fluid.CUDAPlace(gpu_id)
-
-        for i, attrs in enumerate(zip(src_attrs, target_attrs)):
-            src_attr, target_attr = attrs
-            assert (src_attr and target_attr) \
-                or (not src_attr and not target_attr)
-            if not src_attr:
-                continue
-            src_var = _fetch_var(src_attr.name)
-            target_var = _fetch_var(target_attr.name, return_numpy=False)
-            target_var.set(src_var, place)
-
     def __deepcopy__(self, memo):
         cls = self.__class__
         ## __new__ won't init the class, we need to do that ourselves
@@ -92,15 +71,14 @@ class LayerFunc(object):
         memo[id(self)] = copied

         ## first copy all content
-        for k, v in self.__dict__.iteritems():
+        for k, v in six.iteritems(self.__dict__):
             setattr(copied, k, deepcopy(v, memo))

         ## then we need to create new para names for self.param_attr and self.bias_attr
         def create_new_para_name(attr):
             if attr:
                 assert attr.name, "attr should have a name already!"
-                ## remove the last number id but keep the name key
-                name_key = "_".join(attr.name.split("_")[:-1])
+                name_key = 'PARL_target_' + attr.name
                 attr.name = unique_name.generate(name_key)

         create_new_para_name(copied.param_attr)
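The __deepcopy__ change above renames copied parameters with a 'PARL_target_' prefix plus a unique suffix instead of stripping the trailing id; the 'PARL_target_fc1.w_{}' lookups in _numpy_update above rely on exactly this scheme. A small sketch of the naming behaviour (the exact numeric suffixes are an assumption; unique_name simply appends a fresh counter per key):

import paddle.fluid.unique_name as unique_name

# Mirrors create_new_para_name(): every deep copy of a layer gets a fresh
# 'PARL_target_<original>_<counter>' name, so repeated copies never collide.
print(unique_name.generate('PARL_target_' + 'fc1.w'))  # e.g. PARL_target_fc1.w_0
print(unique_name.generate('PARL_target_' + 'fc1.w'))  # e.g. PARL_target_fc1.w_1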
diff --git a/parl/layers/tests/param_name_test.py b/parl/layers/tests/param_name_test.py
index 8415cb9903faf6a382dc17976fa5007f32974cf2..f3ea4ebc8eb9cd6c63a0e9d1514a5edb7932e88a 100644
--- a/parl/layers/tests/param_name_test.py
+++ b/parl/layers/tests/param_name_test.py
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import unittest
 import parl.layers as layers
+import unittest
 from parl.framework.model_base import Network

diff --git a/parl/layers/tests/param_sharing_test.py b/parl/layers/tests/param_sharing_test.py
index e2e346b6139f72df84a98db897f6e4cb43370d52..7eac7bb7f90b4abeaffb75ff36fa34a77e9683c5 100644
--- a/parl/layers/tests/param_sharing_test.py
+++ b/parl/layers/tests/param_sharing_test.py
@@ -12,11 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import unittest
+import numpy as np
+import paddle.fluid as fluid
 import parl.layers as layers
+import unittest
 from parl.framework.model_base import Network
-import paddle.fluid as fluid
-import numpy as np


 class MyNetWork(Network):
diff --git a/parl/plutils/__init__.py b/parl/plutils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..36082a233b32ab4bc2730250dfd59f323df2afe6
--- /dev/null
+++ b/parl/plutils/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from parl.plutils.common import *
diff --git a/parl/plutils/common.py b/parl/plutils/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b08dc79b68d8b744edf86b5753f9cf49e41c975
--- /dev/null
+++ b/parl/plutils/common.py
@@ -0,0 +1,134 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Common functions of the PARL framework
+"""
+
+import paddle.fluid as fluid
+from paddle.fluid.executor import _fetch_var
+from parl.layers.layer_wrappers import LayerFunc
+from parl.framework.model_base import Network
+
+__all__ = [
+    'fetch_framework_var', 'fetch_value', 'get_parameter_pairs',
+    'get_parameter_names'
+]
+
+
+def fetch_framework_var(attr_name, is_bias):
+    """ Fetch a framework variable according to the given attr_name.
+        Returns a reusable framework variable created via create_parameter.
+
+    Args:
+        attr_name: string, attr name of the parameter
+        is_bias: bool, whether the parameter is a bias
+
+    Returns:
+        framework_var: framework.Variable
+    """
+
+    scope = fluid.executor.global_scope()
+    core_var = scope.find_var(attr_name)
+    shape = core_var.get_tensor().shape()
+    framework_var = fluid.layers.create_parameter(
+        shape=shape,
+        dtype='float32',
+        attr=fluid.ParamAttr(name=attr_name),
+        is_bias=is_bias)
+    return framework_var
+
+
+def fetch_value(attr_name):
+    """ Given the name of a ParamAttr, fetch the numpy value of the parameter in global_scope
+
+    Args:
+        attr_name: ParamAttr name of the parameter
+
+    Returns:
+        numpy.ndarray
+    """
+    return _fetch_var(attr_name, return_numpy=True)
+
+
+def get_parameter_pairs(src, target):
+    """ Recursively get pairs of parameter names between src and target
+
+    Args:
+        src: parl.Network/parl.LayerFunc/list/tuple/set/dict
+        target: parl.Network/parl.LayerFunc/list/tuple/set/dict
+
+    Returns:
+        param_pairs: list of all tuple(src_param_name, target_param_name, is_bias)
+                     between src and target
+    """
+
+    param_pairs = []
+    if isinstance(src, Network):
+        for attr in src.__dict__:
+            if not attr in target.__dict__:
+                continue
+            src_var = getattr(src, attr)
+            target_var = getattr(target, attr)
+            param_pairs.extend(get_parameter_pairs(src_var, target_var))
+    elif isinstance(src, LayerFunc):
+        param_pairs.append((src.param_attr.name, target.param_attr.name,
+                            False))
+        if src.bias_attr:
+            param_pairs.append((src.bias_attr.name, target.bias_attr.name,
+                                True))
+    elif isinstance(src, tuple) or isinstance(src, list) or isinstance(
+            src, set):
+        for src_var, target_var in zip(src, target):
+            param_pairs.extend(get_parameter_pairs(src_var, target_var))
+    elif isinstance(src, dict):
+        for k in src.keys():
+            assert k in target
+            param_pairs.extend(get_parameter_pairs(src[k], target[k]))
+    else:
+        # any other type will not be handled
+        pass
+    return param_pairs
+
+
+def get_parameter_names(obj):
+    """ Recursively get parameter names in obj,
+        mainly used to get the parameter names of a parl.Network
+
+    Args:
+        obj: parl.Network/parl.LayerFunc/list/tuple/set/dict
+
+    Returns:
+        parameter_names: list of string, all parameter names in obj
+    """
+
+    parameter_names = []
+    for attr in obj.__dict__:
+        val = getattr(obj, attr)
+        if isinstance(val, Network):
+            parameter_names.extend(get_parameter_names(val))
+        elif isinstance(val, LayerFunc):
+            parameter_names.append(val.param_name)
+            if val.bias_name is not None:
+                parameter_names.append(val.bias_name)
+        elif isinstance(val, tuple) or isinstance(val, list) or isinstance(
+                val, set):
+            for x in val:
+                parameter_names.extend(get_parameter_names(x))
+        elif isinstance(val, dict):
+            for x in list(val.values()):
+                parameter_names.extend(get_parameter_names(x))
+        else:
+            # any other type will not be handled
+            pass
+    return parameter_names
diff --git a/parl/utils/__init__.py b/parl/utils/__init__.py
index ffd635d38c1f2a4ec935a823e64fcb5a18c33ea3..51b6b5c064daca1853da60d5779d9c806be9315c 100644
--- a/parl/utils/__init__.py
+++ b/parl/utils/__init__.py
@@ -13,3 +13,4 @@
 # limitations under the License.

 from parl.utils.utils import *
+from parl.utils.gputils import *
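The plutils helpers above can also be used directly; a short sketch (assuming a parl Model such as the CriticModel from the earlier sketch, with its layers already built inside a program and the startup program already run):

from parl.plutils import fetch_value

# parameter_names is the new Network property backed by get_parameter_names();
# fetch_value() reads the current numpy value of any parameter in global_scope.
names = model.parameter_names
first_weight = fetch_value(names[0])
print(names, first_weight.shape)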
diff --git a/parl/utils/gputils.py b/parl/utils/gputils.py
new file mode 100644
index 0000000000000000000000000000000000000000..60b6b84860ea79b229a3220f6ca413acbcdb8d0c
--- /dev/null
+++ b/parl/utils/gputils.py
@@ -0,0 +1,45 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import subprocess
+from parl.utils import logger
+
+__all__ = ['get_gpu_count']
+
+
+def get_gpu_count():
+    """ Get the count of available GPUs.
+
+    Returns:
+        gpu_count: int
+    """
+
+    gpu_count = 0
+
+    env_cuda_devices = os.environ.get('CUDA_VISIBLE_DEVICES', None)
+    if env_cuda_devices is not None:
+        assert isinstance(env_cuda_devices, str)
+        gpu_count = len(env_cuda_devices.split(','))
+        logger.info(
+            'CUDA_VISIBLE_DEVICES found gpu count: {}'.format(gpu_count))
+    else:
+        try:
+            gpu_count = str(subprocess.check_output(["nvidia-smi",
+                                                     "-L"])).count('UUID')
+            logger.info('nvidia-smi -L found gpu count: {}'.format(gpu_count))
+        except Exception as e:
+            logger.warn(str(e))
+            gpu_count = 0
+    return gpu_count
diff --git a/parl/utils/logger.py b/parl/utils/logger.py
index 63e1e7f5410f18f5ecb73bc7166954ce99b0804c..53a595f57781f94c992a1acea3ffe73e642b429b 100644
--- a/parl/utils/logger.py
+++ b/parl/utils/logger.py
@@ -12,12 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import errno
 import logging
 import os
-import errno
 import os.path
-from termcolor import colored
 import sys
+from termcolor import colored

 __all__ = ['set_dir', 'get_dir', 'set_level']
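get_gpu_count() above is what the rewritten model_base_test.py uses in setUp to pick an execution place; the same pattern works in user code (a small sketch, assuming parl and paddle are installed):

import paddle.fluid as fluid
from parl.utils import get_gpu_count

# Run on the first GPU when one is visible, otherwise fall back to CPU.
place = fluid.CUDAPlace(0) if get_gpu_count() > 0 else fluid.CPUPlace()
executor = fluid.Executor(place)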