From 0a068653bd988a2dac72dd53967f30e36b6a14b6 Mon Sep 17 00:00:00 2001 From: LI Yunxiang <39279048+Banmahhhh@users.noreply.github.com> Date: Mon, 27 Apr 2020 10:10:26 +0800 Subject: [PATCH] remove version 1.3 warnings (#252) * remove version 1.3 warnings * update * yapf * add algorithms test * Update algs_test.py * Update algs_test.py add SAC DDPG TD3 tests * yapf --- docs/new_alg.rst | 1 - examples/offline-Q-learning/dqn.py | 36 +- parl/algorithms/fluid/a3c.py | 14 +- parl/algorithms/fluid/ddpg.py | 53 +- parl/algorithms/fluid/ddqn.py | 7 +- parl/algorithms/fluid/dqn.py | 41 +- parl/algorithms/fluid/impala/impala.py | 33 +- parl/algorithms/fluid/policy_gradient.py | 29 +- parl/algorithms/fluid/ppo.py | 86 +-- parl/algorithms/fluid/tests/algs_test.py | 699 +++++++++++++++++++++++ parl/algorithms/torch/a2c.py | 2 +- parl/core/fluid/agent.py | 30 +- parl/core/fluid/algorithm.py | 41 +- parl/core/fluid/model.py | 60 -- parl/framework/__init__.py | 25 - parl/framework/agent_base.py | 24 - parl/framework/algorithm_base.py | 24 - parl/framework/model_base.py | 24 - parl/framework/policy_distribution.py | 24 - parl/layers/__init__.py | 24 - parl/plutils/__init__.py | 19 - parl/plutils/common.py | 19 - 22 files changed, 753 insertions(+), 562 deletions(-) create mode 100644 parl/algorithms/fluid/tests/algs_test.py delete mode 100644 parl/framework/__init__.py delete mode 100644 parl/framework/agent_base.py delete mode 100644 parl/framework/algorithm_base.py delete mode 100644 parl/framework/model_base.py delete mode 100644 parl/framework/policy_distribution.py delete mode 100644 parl/layers/__init__.py delete mode 100644 parl/plutils/__init__.py delete mode 100644 parl/plutils/common.py diff --git a/docs/new_alg.rst b/docs/new_alg.rst index 973c062..1acf097 100644 --- a/docs/new_alg.rst +++ b/docs/new_alg.rst @@ -59,7 +59,6 @@ Within class ``DQN(Algorithm)``, we define the following methods: Args: model (parl.Model): model defining forward network of Q function - hyperparas (dict): (deprecated) dict of hyper parameters. act_dim (int): dimension of the action space gamma (float): discounted factor for reward computation. lr (float): learning rate. diff --git a/examples/offline-Q-learning/dqn.py b/examples/offline-Q-learning/dqn.py index feedf7d..d761d2f 100644 --- a/examples/offline-Q-learning/dqn.py +++ b/examples/offline-Q-learning/dqn.py @@ -19,23 +19,16 @@ import copy import paddle.fluid as fluid from parl.core.fluid.algorithm import Algorithm from parl.core.fluid import layers -from parl.utils.deprecation import deprecated __all__ = ['DQN'] class DQN(Algorithm): - def __init__(self, - model, - hyperparas=None, - act_dim=None, - gamma=None, - lr=None): + def __init__(self, model, act_dim=None, gamma=None, lr=None): """ DQN algorithm Args: model (parl.Model): model defining forward network of Q function - hyperparas (dict): (deprecated) dict of hyper parameters. act_dim (int): dimension of the action space gamma (float): discounted factor for reward computation. lr (float): learning rate. 
@@ -43,20 +36,12 @@ class DQN(Algorithm): self.model = model self.target_model = copy.deepcopy(model) - if hyperparas is not None: - warnings.warn( - "the `hyperparas` argument of `__init__` function in `parl.Algorithms.DQN` is deprecated since version 1.2 and will be removed in version 1.3.", - DeprecationWarning, - stacklevel=2) - self.act_dim = hyperparas['action_dim'] - self.gamma = hyperparas['gamma'] - else: - assert isinstance(act_dim, int) - assert isinstance(gamma, float) - assert isinstance(lr, float) - self.act_dim = act_dim - self.gamma = gamma - self.lr = lr + assert isinstance(act_dim, int) + assert isinstance(gamma, float) + assert isinstance(lr, float) + self.act_dim = act_dim + self.gamma = gamma + self.lr = lr def predict(self, obs): """ use value model self.model to predict the action value @@ -100,12 +85,7 @@ class DQN(Algorithm): cost = layers.reduce_mean(cost) return cost - def sync_target(self, gpu_id=None): + def sync_target(self): """ sync weights of self.model to self.target_model """ - if gpu_id is not None: - warnings.warn( - "the `gpu_id` argument of `sync_target` function in `parl.Algorithms.DQN` is deprecated since version 1.2 and will be removed in version 1.3.", - DeprecationWarning, - stacklevel=2) self.model.sync_weights_to(self.target_model) diff --git a/parl/algorithms/fluid/a3c.py b/parl/algorithms/fluid/a3c.py index 9b9f57e..2786eb6 100644 --- a/parl/algorithms/fluid/a3c.py +++ b/parl/algorithms/fluid/a3c.py @@ -24,25 +24,17 @@ __all__ = ['A3C'] class A3C(Algorithm): - def __init__(self, model, hyperparas=None, vf_loss_coeff=None): + def __init__(self, model, vf_loss_coeff=None): """ A3C/A2C algorithm Args: model (parl.Model): forward network of policy and value - hyperparas (dict): (deprecated) dict of hyper parameters. vf_loss_coeff (float): coefficient of the value function loss """ self.model = model - if hyperparas is not None: - warnings.warn( - "the `hyperparas` argument of `__init__` function in `parl.Algorithms.A3C` is deprecated since version 1.2 and will be removed in version 1.3.", - DeprecationWarning, - stacklevel=2) - self.vf_loss_coeff = hyperparas['vf_loss_coeff'] - else: - assert isinstance(vf_loss_coeff, (int, float)) - self.vf_loss_coeff = vf_loss_coeff + assert isinstance(vf_loss_coeff, (int, float)) + self.vf_loss_coeff = vf_loss_coeff def learn(self, obs, actions, advantages, target_values, learning_rate, entropy_coeff): diff --git a/parl/algorithms/fluid/ddpg.py b/parl/algorithms/fluid/ddpg.py index c127109..70992ee 100644 --- a/parl/algorithms/fluid/ddpg.py +++ b/parl/algorithms/fluid/ddpg.py @@ -19,7 +19,6 @@ from parl.core.fluid import layers from copy import deepcopy from paddle import fluid from parl.core.fluid.algorithm import Algorithm -from parl.utils.deprecation import deprecated __all__ = ['DDPG'] @@ -27,7 +26,6 @@ __all__ = ['DDPG'] class DDPG(Algorithm): def __init__(self, model, - hyperparas=None, gamma=None, tau=None, actor_lr=None, @@ -37,53 +35,28 @@ class DDPG(Algorithm): Args: model (parl.Model): forward network of actor and critic. The function get_actor_params() of model should be implemented. - hyperparas (dict): (deprecated) dict of hyper parameters. gamma (float): discounted factor for reward computation. 
tau (float): decay coefficient when updating the weights of self.target_model with self.model actor_lr (float): learning rate of the actor model critic_lr (float): learning rate of the critic model """ - if hyperparas is not None: - warnings.warn( - "the `hyperparas` argument of `__init__` function in `parl.Algorithms.DDPG` is deprecated since version 1.2 and will be removed in version 1.3.", - DeprecationWarning, - stacklevel=2) - self.gamma = hyperparas['gamma'] - self.tau = hyperparas['tau'] - self.actor_lr = hyperparas['actor_lr'] - self.critic_lr = hyperparas['critic_lr'] - else: - assert isinstance(gamma, float) - assert isinstance(tau, float) - assert isinstance(actor_lr, float) - assert isinstance(critic_lr, float) - self.gamma = gamma - self.tau = tau - self.actor_lr = actor_lr - self.critic_lr = critic_lr + assert isinstance(gamma, float) + assert isinstance(tau, float) + assert isinstance(actor_lr, float) + assert isinstance(critic_lr, float) + self.gamma = gamma + self.tau = tau + self.actor_lr = actor_lr + self.critic_lr = critic_lr self.model = model self.target_model = deepcopy(model) - @deprecated( - deprecated_in='1.2', removed_in='1.3', replace_function='predict') - def define_predict(self, obs): - """ use actor model of self.model to predict the action - """ - return self.predict(obs) - def predict(self, obs): """ use actor model of self.model to predict the action """ return self.model.policy(obs) - @deprecated( - deprecated_in='1.2', removed_in='1.3', replace_function='learn') - def define_learn(self, obs, action, reward, next_obs, terminal): - """ update actor and critic model with DDPG algorithm - """ - return self.learn(obs, action, reward, next_obs, terminal) - def learn(self, obs, action, reward, next_obs, terminal): """ update actor and critic model with DDPG algorithm """ @@ -115,15 +88,7 @@ class DDPG(Algorithm): optimizer.minimize(cost) return cost - def sync_target(self, - gpu_id=None, - decay=None, - share_vars_parallel_executor=None): - if gpu_id is not None: - warnings.warn( - "the `gpu_id` argument of `sync_target` function in `parl.Algorithms.DDPG` is deprecated since version 1.2 and will be removed in version 1.3.", - DeprecationWarning, - stacklevel=2) + def sync_target(self, decay=None, share_vars_parallel_executor=None): if decay is None: decay = 1.0 - self.tau self.model.sync_weights_to( diff --git a/parl/algorithms/fluid/ddqn.py b/parl/algorithms/fluid/ddqn.py index 5ccd4aa..f81ba60 100644 --- a/parl/algorithms/fluid/ddqn.py +++ b/parl/algorithms/fluid/ddqn.py @@ -85,12 +85,7 @@ class DDQN(Algorithm): optimizer.minimize(cost) return cost - def sync_target(self, gpu_id=None): + def sync_target(self): """ sync weights of self.model to self.target_model """ - if gpu_id is not None: - warnings.warn( - "the `gpu_id` argument of `sync_target` function in `parl.Algorithms.DQN` is deprecated since version 1.2 and will be removed in version 1.3.", - DeprecationWarning, - stacklevel=2) self.model.sync_weights_to(self.target_model) diff --git a/parl/algorithms/fluid/dqn.py b/parl/algorithms/fluid/dqn.py index e6e9757..ed5f907 100644 --- a/parl/algorithms/fluid/dqn.py +++ b/parl/algorithms/fluid/dqn.py @@ -19,18 +19,16 @@ import copy import paddle.fluid as fluid from parl.core.fluid.algorithm import Algorithm from parl.core.fluid import layers -from parl.utils.deprecation import deprecated __all__ = ['DQN'] class DQN(Algorithm): - def __init__(self, model, hyperparas=None, act_dim=None, gamma=None): + def __init__(self, model, act_dim=None, gamma=None): """ 
DQN algorithm Args: model (parl.Model): model defining forward network of Q function - hyperparas (dict): (deprecated) dict of hyper parameters. act_dim (int): dimension of the action space gamma (float): discounted factor for reward computation. lr (float): learning rate. @@ -38,38 +36,16 @@ class DQN(Algorithm): self.model = model self.target_model = copy.deepcopy(model) - if hyperparas is not None: - warnings.warn( - "the `hyperparas` argument of `__init__` function in `parl.Algorithms.DQN` is deprecated since version 1.2 and will be removed in version 1.3.", - DeprecationWarning, - stacklevel=2) - self.act_dim = hyperparas['action_dim'] - self.gamma = hyperparas['gamma'] - else: - assert isinstance(act_dim, int) - assert isinstance(gamma, float) - self.act_dim = act_dim - self.gamma = gamma - - @deprecated( - deprecated_in='1.2', removed_in='1.3', replace_function='predict') - def define_predict(self, obs): - """ use value model self.model to predict the action value - """ - return self.predict(obs) + assert isinstance(act_dim, int) + assert isinstance(gamma, float) + self.act_dim = act_dim + self.gamma = gamma def predict(self, obs): """ use value model self.model to predict the action value """ return self.model.value(obs) - @deprecated( - deprecated_in='1.2', removed_in='1.3', replace_function='learn') - def define_learn(self, obs, action, reward, next_obs, terminal, - learning_rate): - return self.learn(obs, action, reward, next_obs, terminal, - learning_rate) - def learn(self, obs, action, reward, next_obs, terminal, learning_rate): """ update value model self.model with DQN algorithm """ @@ -92,12 +68,7 @@ class DQN(Algorithm): optimizer.minimize(cost) return cost - def sync_target(self, gpu_id=None): + def sync_target(self): """ sync weights of self.model to self.target_model """ - if gpu_id is not None: - warnings.warn( - "the `gpu_id` argument of `sync_target` function in `parl.Algorithms.DQN` is deprecated since version 1.2 and will be removed in version 1.3.", - DeprecationWarning, - stacklevel=2) self.model.sync_weights_to(self.target_model) diff --git a/parl/algorithms/fluid/impala/impala.py b/parl/algorithms/fluid/impala/impala.py index 025f96f..0007a9c 100644 --- a/parl/algorithms/fluid/impala/impala.py +++ b/parl/algorithms/fluid/impala/impala.py @@ -85,7 +85,6 @@ class VTraceLoss(object): class IMPALA(Algorithm): def __init__(self, model, - hyperparas=None, sample_batch_steps=None, gamma=None, vf_loss_coeff=None, @@ -95,34 +94,22 @@ class IMPALA(Algorithm): Args: model (parl.Model): forward network of policy and value - hyperparas (dict): (deprecated) dict of hyper parameters. sample_batch_steps (int): steps of each environment sampling. gamma (float): discounted factor for reward computation. vf_loss_coeff (float): coefficient of the value function loss. clip_rho_threshold (float): clipping threshold for importance weights (rho). clip_pg_rho_threshold (float): clipping threshold on rho_s in \rho_s \delta log \pi(a|x) (r + \gamma v_{s+1} - V(x_s)). 
""" - if hyperparas is not None: - warnings.warn( - "the `hyperparas` argument of `__init__` function in `parl.Algorithms.IMPALA` is deprecated since version 1.2 and will be removed in version 1.3.", - DeprecationWarning, - stacklevel=2) - self.sample_batch_steps = hyperparas['sample_batch_steps'] - self.gamma = hyperparas['gamma'] - self.vf_loss_coeff = hyperparas['vf_loss_coeff'] - self.clip_rho_threshold = hyperparas['clip_rho_threshold'] - self.clip_pg_rho_threshold = hyperparas['clip_pg_rho_threshold'] - else: - assert isinstance(sample_batch_steps, int) - assert isinstance(gamma, float) - assert isinstance(vf_loss_coeff, float) - assert isinstance(clip_rho_threshold, float) - assert isinstance(clip_pg_rho_threshold, float) - self.sample_batch_steps = sample_batch_steps - self.gamma = gamma - self.vf_loss_coeff = vf_loss_coeff - self.clip_rho_threshold = clip_rho_threshold - self.clip_pg_rho_threshold = clip_pg_rho_threshold + assert isinstance(sample_batch_steps, int) + assert isinstance(gamma, float) + assert isinstance(vf_loss_coeff, float) + assert isinstance(clip_rho_threshold, float) + assert isinstance(clip_pg_rho_threshold, float) + self.sample_batch_steps = sample_batch_steps + self.gamma = gamma + self.vf_loss_coeff = vf_loss_coeff + self.clip_rho_threshold = clip_rho_threshold + self.clip_pg_rho_threshold = clip_pg_rho_threshold self.model = model diff --git a/parl/algorithms/fluid/policy_gradient.py b/parl/algorithms/fluid/policy_gradient.py index b1b901f..d37083f 100644 --- a/parl/algorithms/fluid/policy_gradient.py +++ b/parl/algorithms/fluid/policy_gradient.py @@ -18,51 +18,28 @@ warnings.simplefilter('default') import paddle.fluid as fluid from parl.core.fluid.algorithm import Algorithm from parl.core.fluid import layers -from parl.utils.deprecation import deprecated __all__ = ['PolicyGradient'] class PolicyGradient(Algorithm): - def __init__(self, model, hyperparas=None, lr=None): + def __init__(self, model, lr=None): """ Policy Gradient algorithm Args: model (parl.Model): forward network of the policy. - hyperparas (dict): (deprecated) dict of hyper parameters. lr (float): learning rate of the policy model. 
""" self.model = model - if hyperparas is not None: - warnings.warn( - "the `hyperparas` argument of `__init__` function in `parl.Algorithms.PolicyGradient` is deprecated since version 1.2 and will be removed in version 1.3.", - DeprecationWarning, - stacklevel=2) - self.lr = hyperparas['lr'] - else: - assert isinstance(lr, float) - self.lr = lr - - @deprecated( - deprecated_in='1.2', removed_in='1.3', replace_function='predict') - def define_predict(self, obs): - """ use policy model self.model to predict the action probability - """ - return self.predict(obs) + assert isinstance(lr, float) + self.lr = lr def predict(self, obs): """ use policy model self.model to predict the action probability """ return self.model(obs) - @deprecated( - deprecated_in='1.2', removed_in='1.3', replace_function='learn') - def define_learn(self, obs, action, reward): - """ update policy model self.model with policy gradient algorithm - """ - return self.learn(obs, action, reward) - def learn(self, obs, action, reward): """ update policy model self.model with policy gradient algorithm """ diff --git a/parl/algorithms/fluid/ppo.py b/parl/algorithms/fluid/ppo.py index 002ab27..2cd88f4 100644 --- a/parl/algorithms/fluid/ppo.py +++ b/parl/algorithms/fluid/ppo.py @@ -20,7 +20,6 @@ from copy import deepcopy from paddle import fluid from parl.core.fluid import layers from parl.core.fluid.algorithm import Algorithm -from parl.utils.deprecation import deprecated __all__ = ['PPO'] @@ -28,7 +27,6 @@ __all__ = ['PPO'] class PPO(Algorithm): def __init__(self, model, - hyperparas=None, act_dim=None, policy_lr=None, value_lr=None, @@ -37,7 +35,6 @@ class PPO(Algorithm): Args: model (parl.Model): model defining forward network of policy and value. - hyperparas (dict): (deprecated) dict of hyper parameters. act_dim (float): dimension of the action space. policy_lr (float): learning rate of the policy model. value_lr (float): learning rate of the value model. 
@@ -47,27 +44,14 @@ class PPO(Algorithm): # Used to calculate probability of action in old policy self.old_policy_model = deepcopy(model.policy_model) - if hyperparas is not None: - warnings.warn( - "the `hyperparas` argument of `__init__` function in `parl.Algorithms.PPO` is deprecated since version 1.2 and will be removed in version 1.3.", - DeprecationWarning, - stacklevel=2) - self.act_dim = hyperparas['act_dim'] - self.policy_lr = hyperparas['policy_lr'] - self.value_lr = hyperparas['value_lr'] - if 'epsilon' in hyperparas: - self.epsilon = hyperparas['epsilon'] - else: - self.epsilon = 0.2 # default - else: - assert isinstance(act_dim, int) - assert isinstance(policy_lr, float) - assert isinstance(value_lr, float) - assert isinstance(epsilon, float) - self.act_dim = act_dim - self.policy_lr = policy_lr - self.value_lr = value_lr - self.epsilon = epsilon + assert isinstance(act_dim, int) + assert isinstance(policy_lr, float) + assert isinstance(value_lr, float) + assert isinstance(epsilon, float) + self.act_dim = act_dim + self.policy_lr = policy_lr + self.value_lr = value_lr + self.epsilon = epsilon def _calc_logprob(self, actions, means, logvars): """ Calculate log probabilities of actions, when given means and logvars @@ -111,49 +95,18 @@ class PPO(Algorithm): log_det_cov_new - log_det_cov_old) + tr_old_new - self.act_dim) return kl - @deprecated( - deprecated_in='1.2', removed_in='1.3', replace_function='predict') - def define_predict(self, obs): - """ Use policy model of self.model to predict means and logvars of actions - """ - return self.predict(obs) - def predict(self, obs): """ Use the policy model of self.model to predict means and logvars of actions """ means, logvars = self.model.policy(obs) return means - @deprecated( - deprecated_in='1.2', removed_in='1.3', replace_function='sample') - def define_sample(self, obs): - """ Use the policy model of self.model to sample actions - """ - return self.sample(obs) - def sample(self, obs): """ Use the policy model of self.model to sample actions """ sampled_act = self.model.policy_sample(obs) return sampled_act - @deprecated( - deprecated_in='1.2', removed_in='1.3', replace_function='policy_learn') - def define_policy_learn(self, obs, actions, advantages, beta=None): - """ Learn policy model with: - 1. CLIP loss: Clipped Surrogate Objective - 2. KLPEN loss: Adaptive KL Penalty Objective - See: https://arxiv.org/pdf/1707.02286.pdf - - Args: - obs: Tensor, (batch_size, obs_dim) - actions: Tensor, (batch_size, act_dim) - advantages: Tensor (batch_size, ) - beta: Tensor (1) or None - if None, use CLIP Loss; else, use KLPEN loss. - """ - return self.policy_learn(obs, actions, advantages, beta) - def policy_learn(self, obs, actions, advantages, beta=None): """ Learn policy model with: 1. 
CLIP loss: Clipped Surrogate Objective @@ -196,27 +149,11 @@ class PPO(Algorithm): optimizer.minimize(loss) return loss, kl - @deprecated( - deprecated_in='1.2', - removed_in='1.3', - replace_function='value_predict') - def define_value_predict(self, obs): - """ Use value model of self.model to predict value of obs - """ - return self.value_predict(obs) - def value_predict(self, obs): """ Use value model of self.model to predict value of obs """ return self.model.value(obs) - @deprecated( - deprecated_in='1.2', removed_in='1.3', replace_function='value_learn') - def define_value_learn(self, obs, val): - """ Learn value model with square error cost - """ - return self.value_learn(obs, val) - def value_learn(self, obs, val): """ Learn the value model with square error cost """ @@ -227,12 +164,7 @@ class PPO(Algorithm): optimizer.minimize(loss) return loss - def sync_old_policy(self, gpu_id=None): + def sync_old_policy(self): """ Synchronize weights of self.model.policy_model to self.old_policy_model """ - if gpu_id is not None: - warnings.warn( - "the `gpu_id` argument of `sync_old_policy` function in `parl.Algorithms.PPO` is deprecated since version 1.2 and will be removed in version 1.3.", - DeprecationWarning, - stacklevel=2) self.model.policy_model.sync_weights_to(self.old_policy_model) diff --git a/parl/algorithms/fluid/tests/algs_test.py b/parl/algorithms/fluid/tests/algs_test.py new file mode 100644 index 0000000..6d272b8 --- /dev/null +++ b/parl/algorithms/fluid/tests/algs_test.py @@ -0,0 +1,699 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +import numpy as np +import paddle.fluid as fluid +import parl +from parl import layers + + +class DQNModel(parl.Model): + def __init__(self): + self.fc1 = layers.fc(size=32, act='relu') + self.fc2 = layers.fc(size=2) + + def value(self, obs): + x = self.fc1(obs) + act = self.fc2(x) + return act + + +class DQNAgent(parl.Agent): + def __init__(self, algorithm): + super(DQNAgent, self).__init__(algorithm) + self.alg = algorithm + + def build_program(self): + self.pred_program = fluid.Program() + self.learn_program = fluid.Program() + + with fluid.program_guard(self.pred_program): + obs = layers.data(name='obs', shape=[4], dtype='float32') + self.value = self.alg.predict(obs) + + with fluid.program_guard(self.learn_program): + obs = layers.data(name='obs', shape=[4], dtype='float32') + action = layers.data(name='act', shape=[1], dtype='int32') + reward = layers.data(name='reward', shape=[], dtype='float32') + next_obs = layers.data(name='next_obs', shape=[4], dtype='float32') + lr = layers.data( + name='lr', shape=[1], dtype='float32', append_batch_size=False) + terminal = layers.data(name='terminal', shape=[], dtype='bool') + self.cost = self.alg.learn(obs, action, reward, next_obs, terminal, + lr) + + def predict(self, obs): + obs = np.expand_dims(obs, axis=0) + pred_Q = self.fluid_executor.run( + self.pred_program, + feed={'obs': obs.astype('float32')}, + fetch_list=[self.value])[0] + pred_Q = np.squeeze(pred_Q, axis=0) + act = np.argmax(pred_Q) + return act + + def learn(self, obs, act, reward, next_obs, terminal): + lr = 3e-4 + + obs = np.expand_dims(obs, axis=0) + next_obs = np.expand_dims(next_obs, axis=0) + act = np.expand_dims(act, -1) + feed = { + 'obs': obs.astype('float32'), + 'act': act.astype('int32'), + 'reward': reward, + 'next_obs': next_obs.astype('float32'), + 'terminal': terminal, + 'lr': np.float32(lr) + } + cost = self.fluid_executor.run( + self.learn_program, feed=feed, fetch_list=[self.cost])[0] + return cost + + +class A3CModel(parl.Model): + def __init__(self): + self.fc = layers.fc(size=32, act='relu') + + self.policy_fc = layers.fc(size=2) + self.value_fc = layers.fc(size=1) + + def policy(self, obs): + x = self.fc(obs) + policy_logits = self.policy_fc(x) + + return policy_logits + + def value(self, obs): + x = self.fc(obs) + values = self.value_fc(x) + values = layers.squeeze(values, axes=[1]) + + return values + + def policy_and_value(self, obs): + x = self.fc(obs) + policy_logits = self.policy_fc(x) + values = self.value_fc(x) + values = layers.squeeze(values, axes=[1]) + + return policy_logits, values + + +class A3CAgent(parl.Agent): + def __init__(self, algorithm): + super(A3CAgent, self).__init__(algorithm) + self.alg = algorithm + + def build_program(self): + self.predict_program = fluid.Program() + self.value_program = fluid.Program() + self.learn_program = fluid.Program() + + with fluid.program_guard(self.predict_program): + obs = layers.data(name='obs', shape=[4], dtype='float32') + self.predict_actions = self.alg.predict(obs) + + with fluid.program_guard(self.value_program): + obs = layers.data(name='obs', shape=[4], dtype='float32') + self.values = self.alg.value(obs) + + with fluid.program_guard(self.learn_program): + obs = layers.data(name='obs', shape=[4], dtype='float32') + actions = layers.data(name='actions', shape=[], dtype='int64') + advantages = layers.data( + name='advantages', shape=[], dtype='float32') + target_values = layers.data( + name='target_values', shape=[], dtype='float32') + lr = layers.data( + name='lr', 
shape=[1], dtype='float32', append_batch_size=False) + entropy_coeff = layers.data( + name='entropy_coeff', + shape=[1], + dtype='float32', + append_batch_size=False) + + total_loss, pi_loss, vf_loss, entropy = self.alg.learn( + obs, actions, advantages, target_values, lr, entropy_coeff) + self.learn_outputs = [total_loss, pi_loss, vf_loss, entropy] + + def predict(self, obs_np): + obs_np = obs_np.astype('float32') + + predict_actions = self.fluid_executor.run( + self.predict_program, + feed={'obs': obs_np}, + fetch_list=[self.predict_actions])[0] + return predict_actions + + def value(self, obs_np): + obs_np = obs_np.astype('float32') + + values = self.fluid_executor.run( + self.value_program, feed={'obs': obs_np}, + fetch_list=[self.values])[0] + return values + + def learn(self, obs_np, actions_np, advantages_np, target_values_np): + obs_np = obs_np.astype('float32') + actions_np = actions_np.astype('int64') + advantages_np = advantages_np.astype('float32') + target_values_np = target_values_np.astype('float32') + + lr = 3e-4 + entropy_coeff = 0. + + total_loss, pi_loss, vf_loss, entropy = self.fluid_executor.run( + self.learn_program, + feed={ + 'obs': obs_np, + 'actions': actions_np, + 'advantages': advantages_np, + 'target_values': target_values_np, + 'lr': np.array([lr], dtype='float32'), + 'entropy_coeff': np.array([entropy_coeff], dtype='float32') + }, + fetch_list=self.learn_outputs) + return total_loss, pi_loss, vf_loss, entropy, lr, entropy_coeff + + +class IMPALAModel(parl.Model): + def __init__(self): + self.fc = layers.fc(size=32, act='relu') + + self.policy_fc = layers.fc(size=2) + self.value_fc = layers.fc(size=1) + + def policy(self, obs): + x = self.fc(obs) + policy_logits = self.policy_fc(x) + + return policy_logits + + def value(self, obs): + x = self.fc(obs) + values = self.value_fc(x) + values = layers.squeeze(values, axes=[1]) + + return values + + +class IMPALAAgent(parl.Agent): + def __init__(self, algorithm): + super(IMPALAAgent, self).__init__(algorithm) + self.alg = algorithm + + def build_program(self): + self.predict_program = fluid.Program() + self.learn_program = fluid.Program() + + with fluid.program_guard(self.predict_program): + obs = layers.data(name='obs', shape=[4], dtype='float32') + self.predict_actions = self.alg.predict(obs) + + with fluid.program_guard(self.learn_program): + obs = layers.data(name='obs', shape=[4], dtype='float32') + actions = layers.data(name='actions', shape=[], dtype='int64') + behaviour_logits = layers.data( + name='behaviour_logits', shape=[2], dtype='float32') + rewards = layers.data(name='rewards', shape=[], dtype='float32') + dones = layers.data(name='dones', shape=[], dtype='float32') + lr = layers.data( + name='lr', shape=[1], dtype='float32', append_batch_size=False) + entropy_coeff = layers.data( + name='entropy_coeff', + shape=[1], + dtype='float32', + append_batch_size=False) + + vtrace_loss, kl = self.alg.learn(obs, actions, behaviour_logits, + rewards, dones, lr, entropy_coeff) + self.learn_outputs = [ + vtrace_loss.total_loss, vtrace_loss.pi_loss, + vtrace_loss.vf_loss, vtrace_loss.entropy, kl + ] + + def predict(self, obs_np): + obs_np = obs_np.astype('float32') + + predict_actions = self.fluid_executor.run( + self.predict_program, + feed={'obs': obs_np}, + fetch_list=[self.predict_actions])[0] + return predict_actions + + def learn(self, obs, actions, behaviour_logits, rewards, dones, lr, + entropy_coeff): + total_loss, pi_loss, vf_loss, entropy, kl = self.fluid_executor.run( + self.learn_program, + feed={ + 
'obs': obs, + 'actions': actions, + 'behaviour_logits': behaviour_logits, + 'rewards': rewards, + 'dones': dones, + 'lr': np.array([lr], dtype='float32'), + 'entropy_coeff': np.array([entropy_coeff], dtype='float32') + }, + fetch_list=self.learn_outputs) + return total_loss, pi_loss, vf_loss, entropy, kl + + +class SACActor(parl.Model): + def __init__(self): + self.mean_linear = layers.fc(size=1) + self.log_std_linear = layers.fc(size=1) + + def policy(self, obs): + means = self.mean_linear(obs) + log_std = self.log_std_linear(obs) + + return means, log_std + + +class SACCritic(parl.Model): + def __init__(self): + self.fc1 = layers.fc(size=1) + self.fc2 = layers.fc(size=1) + + def value(self, obs, act): + concat = layers.concat([obs, act], axis=1) + Q1 = self.fc1(concat) + Q2 = self.fc2(concat) + Q1 = layers.squeeze(Q1, axes=[1]) + Q2 = layers.squeeze(Q2, axes=[1]) + return Q1, Q2 + + +class SACAgent(parl.Agent): + def __init__(self, algorithm): + super(SACAgent, self).__init__(algorithm) + self.alg.sync_target(decay=0) + + def build_program(self): + self.pred_program = fluid.Program() + self.sample_program = fluid.Program() + self.learn_program = fluid.Program() + + with fluid.program_guard(self.pred_program): + obs = layers.data(name='obs', shape=[4], dtype='float32') + self.pred_act = self.alg.predict(obs) + + with fluid.program_guard(self.sample_program): + obs = layers.data(name='obs', shape=[4], dtype='float32') + self.sample_act, _ = self.alg.sample(obs) + + with fluid.program_guard(self.learn_program): + obs = layers.data(name='obs', shape=[4], dtype='float32') + act = layers.data(name='act', shape=[1], dtype='float32') + reward = layers.data(name='reward', shape=[], dtype='float32') + next_obs = layers.data(name='next_obs', shape=[4], dtype='float32') + terminal = layers.data(name='terminal', shape=[], dtype='bool') + self.critic_cost, self.actor_cost = self.alg.learn( + obs, act, reward, next_obs, terminal) + + def predict(self, obs): + obs = np.expand_dims(obs, axis=0) + act = self.fluid_executor.run( + self.pred_program, feed={'obs': obs}, + fetch_list=[self.pred_act])[0] + return act + + def sample(self, obs): + obs = np.expand_dims(obs, axis=0) + act = self.fluid_executor.run( + self.sample_program, + feed={'obs': obs}, + fetch_list=[self.sample_act])[0] + return act + + def learn(self, obs, act, reward, next_obs, terminal): + feed = { + 'obs': obs, + 'act': act, + 'reward': reward, + 'next_obs': next_obs, + 'terminal': terminal + } + [critic_cost, actor_cost] = self.fluid_executor.run( + self.learn_program, + feed=feed, + fetch_list=[self.critic_cost, self.actor_cost]) + return critic_cost[0], actor_cost[0] + + +class DDPGModel(parl.Model): + def __init__(self): + self.policy_fc = layers.fc(size=1) + self.value_fc = layers.fc(size=1) + + def policy(self, obs): + act = self.policy_fc(obs) + return act + + def value(self, obs, act): + concat = layers.concat([obs, act], axis=1) + Q = self.value_fc(concat) + Q = layers.squeeze(Q, axes=[1]) + return Q + + def get_actor_params(self): + return self.parameters()[:2] + + +class DDPGAgent(parl.Agent): + def __init__(self, algorithm): + super(DDPGAgent, self).__init__(algorithm) + self.alg.sync_target(decay=0) + + def build_program(self): + self.pred_program = fluid.Program() + self.learn_program = fluid.Program() + + with fluid.program_guard(self.pred_program): + obs = layers.data(name='obs', shape=[4], dtype='float32') + self.pred_act = self.alg.predict(obs) + + with fluid.program_guard(self.learn_program): + obs = 
layers.data(name='obs', shape=[4], dtype='float32') + act = layers.data(name='act', shape=[1], dtype='float32') + reward = layers.data(name='reward', shape=[], dtype='float32') + next_obs = layers.data(name='next_obs', shape=[4], dtype='float32') + terminal = layers.data(name='terminal', shape=[], dtype='bool') + _, self.critic_cost = self.alg.learn(obs, act, reward, next_obs, + terminal) + + def predict(self, obs): + obs = np.expand_dims(obs, axis=0) + act = self.fluid_executor.run( + self.pred_program, feed={'obs': obs}, + fetch_list=[self.pred_act])[0] + return act + + def learn(self, obs, act, reward, next_obs, terminal): + feed = { + 'obs': obs, + 'act': act, + 'reward': reward, + 'next_obs': next_obs, + 'terminal': terminal + } + critic_cost = self.fluid_executor.run( + self.learn_program, feed=feed, fetch_list=[self.critic_cost])[0] + self.alg.sync_target() + return critic_cost + + +class TD3Model(parl.Model): + def __init__(self): + self.actor_fc = layers.fc(size=1) + self.q1 = layers.fc(size=1) + self.q2 = layers.fc(size=1) + + def policy(self, obs): + return self.actor_fc(obs) + + def value(self, obs, act): + concat = layers.concat([obs, act], axis=1) + Q1 = self.q1(concat) + Q1 = layers.squeeze(Q1, axes=[1]) + Q2 = self.q2(concat) + Q2 = layers.squeeze(Q2, axes=[1]) + return Q1, Q2 + + def Q1(self, obs, act): + concat = layers.concat([obs, act], axis=1) + Q1 = self.q1(concat) + Q1 = layers.squeeze(Q1, axes=[1]) + return Q1 + + def get_actor_params(self): + return self.parameters()[:2] + + +class TD3Agent(parl.Agent): + def __init__(self, algorithm): + super(TD3Agent, self).__init__(algorithm) + self.alg.sync_target(decay=0) + + def build_program(self): + self.pred_program = fluid.Program() + self.actor_learn_program = fluid.Program() + self.critic_learn_program = fluid.Program() + + with fluid.program_guard(self.pred_program): + obs = layers.data(name='obs', shape=[4], dtype='float32') + self.pred_act = self.alg.predict(obs) + + with fluid.program_guard(self.actor_learn_program): + obs = layers.data(name='obs', shape=[4], dtype='float32') + self.actor_cost = self.alg.actor_learn(obs) + + with fluid.program_guard(self.critic_learn_program): + obs = layers.data(name='obs', shape=[4], dtype='float32') + act = layers.data(name='act', shape=[1], dtype='float32') + reward = layers.data(name='reward', shape=[], dtype='float32') + next_obs = layers.data(name='next_obs', shape=[4], dtype='float32') + terminal = layers.data(name='terminal', shape=[], dtype='bool') + self.critic_cost = self.alg.critic_learn(obs, act, reward, + next_obs, terminal) + + def predict(self, obs): + obs = np.expand_dims(obs, axis=0) + act = self.fluid_executor.run( + self.pred_program, feed={'obs': obs}, + fetch_list=[self.pred_act])[0] + return act + + def learn(self, obs, act, reward, next_obs, terminal): + feed = { + 'obs': obs, + 'act': act, + 'reward': reward, + 'next_obs': next_obs, + 'terminal': terminal + } + critic_cost = self.fluid_executor.run( + self.critic_learn_program, + feed=feed, + fetch_list=[self.critic_cost])[0] + + actor_cost = self.fluid_executor.run( + self.actor_learn_program, + feed={'obs': obs}, + fetch_list=[self.actor_cost])[0] + self.alg.sync_target() + return actor_cost, critic_cost + + +class PARLtest(unittest.TestCase): + def setUp(self): + # set up DQN test + DQN_model = DQNModel() + DQN_alg = parl.algorithms.DQN(DQN_model, act_dim=2, gamma=0.9) + self.DQN_agent = DQNAgent(DQN_alg) + + # set up A3C test + A3C_model = A3CModel() + A3C_alg = parl.algorithms.A3C(A3C_model, 
vf_loss_coeff=0.) + self.A3C_agent = A3CAgent(A3C_alg) + + # set up IMPALA test + IMPALA_model = IMPALAModel() + IMPALA_alg = parl.algorithms.IMPALA( + IMPALA_model, + sample_batch_steps=4, + gamma=0.9, + vf_loss_coeff=0., + clip_rho_threshold=1., + clip_pg_rho_threshold=1.) + self.IMPALA_agent = IMPALAAgent(IMPALA_alg) + + # set up SAC test + SAC_actor = SACActor() + SAC_critic = SACCritic() + SAC_alg = parl.algorithms.SAC( + SAC_actor, + SAC_critic, + max_action=1., + gamma=0.99, + tau=0.005, + actor_lr=1e-3, + critic_lr=1e-3) + self.SAC_agent = SACAgent(SAC_alg) + + # set up DDPG test + DDPG_model = DDPGModel() + DDPG_alg = parl.algorithms.DDPG( + DDPG_model, gamma=0.99, tau=0.001, actor_lr=3e-4, critic_lr=3e-4) + self.DDPG_agent = DDPGAgent(DDPG_alg) + + # set up TD3 test + TD3_model = TD3Model() + TD3_alg = parl.algorithms.TD3( + TD3_model, + 1., + gamma=0.99, + tau=0.005, + actor_lr=3e-4, + critic_lr=3e-4) + self.TD3_agent = TD3Agent(TD3_alg) + + def test_DQN_predict(self): + """Test APIs in PARL DQN predict + """ + obs = np.array([-0.02394919, 0.03114079, 0.01136446, 0.00324496]) + + act = self.DQN_agent.predict(obs) + + def test_DQN_learn(self): + """Test APIs in PARL DQN learn + """ + obs = np.array([-0.02394919, 0.03114079, 0.01136446, 0.00324496]) + next_obs = np.array([-0.02332638, -0.16414229, 0.01142936, 0.29949173]) + terminal = np.array([False]).astype('bool') + reward = np.array([1.0]).astype('float32') + act = np.array([0]).astype('int32') + + cost = self.DQN_agent.learn(obs, act, reward, next_obs, terminal) + + def test_A3C_predict(self): + """Test APIs in PARL A3C predict + """ + obs = np.array([-0.02394919, 0.03114079, 0.01136446, 0.00324496]) + obs = np.expand_dims(obs, axis=0) + + logits = self.A3C_agent.predict(obs) + + def test_A3C_value(self): + """Test APIs in PARL A3C predict + """ + obs = np.array([-0.02394919, 0.03114079, 0.01136446, 0.00324496]) + obs = np.expand_dims(obs, axis=0) + + values = self.A3C_agent.value(obs) + + def test_A3C_learn(self): + """Test APIs in PARL A3C learn + """ + obs = np.array([[-0.02394919, 0.03114079, 0.01136446, 0.00324496]]) + action = np.array([0]) + advantages = np.array([-0.02332638]) + target_values = np.array([1.]) + + self.A3C_agent.learn(obs, action, advantages, target_values) + + def test_IMPALA_predict(self): + """Test APIs in PARL IMPALA predict + """ + obs = np.array([[-0.02394919, 0.03114079, 0.01136446, 0.00324496]]) + + policy = self.IMPALA_agent.predict(obs) + + def test_IMPALA_learn(self): + """Test APIs in PARL IMPALA learn + """ + obs = np.array([[-0.02394919, 0.03114079, 0.01136446, 0.00324496], + [-0.02394919, 0.03114079, 0.01136446, 0.00324496], + [-0.02394919, 0.03114079, 0.01136446, 0.00324496], + [-0.02394919, 0.03114079, 0.01136446, + 0.00324496]]).astype('float32') + actions = np.array([1, 1, 1, 1]).astype('int32') + behaviour_logits = np.array([[-1, 1], [-1, 1], [-1, 1], + [-1, 1]]).astype('float32') + rewards = np.array([0, 0, 0, 0]).astype('float32') + dones = np.array([False, False, False, False]).astype('float32') + lr = 3e-4 + entropy_coeff = 0. 
+
+        total_loss, pi_loss, vf_loss, entropy, kl = self.IMPALA_agent.learn(
+            obs, actions, behaviour_logits, rewards, dones, lr, entropy_coeff)
+
+    def test_SAC_predict(self):
+        """Test APIs in PARL SAC predict
+        """
+        obs = np.array([[-0.02394919, 0.03114079, 0.01136446,
+                         0.00324496]]).astype(np.float32)
+        act = self.SAC_agent.predict(obs)
+
+    def test_SAC_sample(self):
+        """Test APIs in PARL SAC sample
+        """
+        obs = np.array([[-0.02394919, 0.03114079, 0.01136446,
+                         0.00324496]]).astype(np.float32)
+        act = self.SAC_agent.sample(obs)
+
+    def test_SAC_learn(self):
+        """Test APIs in PARL SAC learn
+        """
+        obs = np.array([[-0.02394919, 0.03114079, 0.01136446,
+                         0.00324496]]).astype(np.float32)
+        next_obs = np.array(
+            [[-0.02332638, -0.16414229, 0.01142936,
+              0.29949173]]).astype(np.float32)
+        terminal = np.array([False]).astype('bool')
+        reward = np.array([1.0]).astype('float32')
+        act = np.array([[0.]]).astype('float32')
+
+        critic_cost, actor_cost = self.SAC_agent.learn(obs, act, reward,
+                                                       next_obs, terminal)
+
+    def test_DDPG_predict(self):
+        """Test APIs in PARL DDPG predict
+        """
+        obs = np.array([[-0.02394919, 0.03114079, 0.01136446,
+                         0.00324496]]).astype(np.float32)
+        act = self.DDPG_agent.predict(obs)
+
+    def test_DDPG_learn(self):
+        """Test APIs in PARL DDPG learn
+        """
+        obs = np.array([[-0.02394919, 0.03114079, 0.01136446,
+                         0.00324496]]).astype(np.float32)
+        next_obs = np.array(
+            [[-0.02332638, -0.16414229, 0.01142936,
+              0.29949173]]).astype(np.float32)
+        terminal = np.array([False]).astype('bool')
+        reward = np.array([1.0]).astype('float32')
+        act = np.array([[0.]]).astype('float32')
+
+        # DDPGAgent.learn runs the critic update and returns the critic cost only.
+        critic_cost = self.DDPG_agent.learn(obs, act, reward, next_obs,
+                                            terminal)
+
+    def test_TD3_predict(self):
+        """Test APIs in PARL TD3 predict
+        """
+        obs = np.array([[-0.02394919, 0.03114079, 0.01136446,
+                         0.00324496]]).astype(np.float32)
+        act = self.TD3_agent.predict(obs)
+
+    def test_TD3_learn(self):
+        """Test APIs in PARL TD3 learn
+        """
+        obs = np.array([[-0.02394919, 0.03114079, 0.01136446,
+                         0.00324496]]).astype(np.float32)
+        next_obs = np.array(
+            [[-0.02332638, -0.16414229, 0.01142936,
+              0.29949173]]).astype(np.float32)
+        terminal = np.array([False]).astype('bool')
+        reward = np.array([1.0]).astype('float32')
+        act = np.array([[0.]]).astype('float32')
+
+        # TD3Agent.learn returns (actor_cost, critic_cost); unpack in that order.
+        actor_cost, critic_cost = self.TD3_agent.learn(obs, act, reward,
+                                                       next_obs, terminal)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/parl/algorithms/torch/a2c.py b/parl/algorithms/torch/a2c.py
index 3d78ce7..43e3739 100644
--- a/parl/algorithms/torch/a2c.py
+++ b/parl/algorithms/torch/a2c.py
@@ -27,7 +27,7 @@ __all__ = ['A2C']
 
 
 class A2C(parl.Algorithm):
-    def __init__(self, model, config, hyperparas=None):
+    def __init__(self, model, config):
         assert isinstance(config['vf_loss_coeff'], (int, float))
         self.model = model
         self.vf_loss_coeff = config['vf_loss_coeff']
diff --git a/parl/core/fluid/agent.py b/parl/core/fluid/agent.py
index 8972443..40cd633 100644
--- a/parl/core/fluid/agent.py
+++ b/parl/core/fluid/agent.py
@@ -17,7 +17,6 @@ warnings.simplefilter('default')
 
 import paddle.fluid as fluid
 from parl.core.fluid import layers
-from parl.utils.deprecation import deprecated
 from parl.core.agent_base import AgentBase
 from parl.core.fluid.algorithm import Algorithm
 from parl.utils import machine_info
@@ -46,7 +45,6 @@ class Agent(AgentBase):
     This class will initialize the neural network parameters automatically, and provides an executor for users to run the programs (self.fluid_executor).
Attributes: - gpu_id (int): deprecated. specify which GPU to be used. -1 if to use the CPU. fluid_executor (fluid.Executor): executor for running programs of the agent. alg (parl.algorithm): algorithm of this agent. @@ -65,18 +63,12 @@ class Agent(AgentBase): """ - def __init__(self, algorithm, gpu_id=None): + def __init__(self, algorithm): """Build programs by calling the method ``self.build_program()`` and run initialization function of ``fluid.default_startup_program()``. Args: algorithm (parl.Algorithm): an instance of `parl.Algorithm`. This algorithm is then passed to `self.alg`. - gpu_id (int): deprecated. specify which GPU to be used. -1 if to use the CPU. """ - if gpu_id is not None: - warnings.warn( - "the `gpu_id` argument of `__init__` function in `parl.Agent` is deprecated since version 1.2 and will be removed in version 1.3.", - DeprecationWarning, - stacklevel=2) assert isinstance(algorithm, Algorithm) super(Agent, self).__init__(algorithm) @@ -119,26 +111,6 @@ class Agent(AgentBase): """ raise NotImplementedError - @deprecated( - deprecated_in='1.2', removed_in='1.3', replace_function='get_weights') - def get_params(self): - """ Returns a Python dictionary containing the whole parameters of self.alg. - - Returns: - a Python List containing the parameters of self.alg. - """ - return self.algorithm.get_params() - - @deprecated( - deprecated_in='1.2', removed_in='1.3', replace_function='set_weights') - def set_params(self, params): - """Copy parameters from ``get_params()`` into this agent. - - Args: - params(dict): a Python List containing the parameters of self.alg. - """ - self.algorithm.set_params(params) - def learn(self, *args, **kwargs): """The training interface for ``Agent``. This function feeds the training data into the learn_program defined in ``build_program()``. diff --git a/parl/core/fluid/algorithm.py b/parl/core/fluid/algorithm.py index 1a05a99..2267e3b 100644 --- a/parl/core/fluid/algorithm.py +++ b/parl/core/fluid/algorithm.py @@ -17,7 +17,6 @@ warnings.simplefilter('default') from parl.core.algorithm_base import AlgorithmBase from parl.core.fluid.model import Model -from parl.utils.deprecation import deprecated __all__ = ['Algorithm'] @@ -57,47 +56,13 @@ class Algorithm(AlgorithmBase): """ - def __init__(self, model=None, hyperparas=None): + def __init__(self, model=None): """ Args: model(``parl.Model``): a neural network that represents a policy or a Q-value function. - hyperparas(dict): a dict storing the hyper-parameters relative to training. """ - if model is not None: - warnings.warn( - "the `model` argument of `__init__` function in `parl.Algorithm` is deprecated since version 1.2 and will be removed in version 1.3.", - DeprecationWarning, - stacklevel=2) - - assert isinstance(model, Model) - self.model = model - if hyperparas is not None: - warnings.warn( - "the `hyperparas` argument of `__init__` function in `parl.Algorithm` is deprecated since version 1.2 and will be removed in version 1.3.", - DeprecationWarning, - stacklevel=2) - - self.hp = hyperparas - - @deprecated( - deprecated_in='1.2', removed_in='1.3', replace_function='get_weights') - def get_params(self): - """ Get parameters of self.model. - - Returns: - params(dict): a Python List containing the parameters of self.model. - """ - return self.model.get_params() - - @deprecated( - deprecated_in='1.2', removed_in='1.3', replace_function='set_weights') - def set_params(self, params): - """ Set parameters from ``get_params`` to the model. 
- - Args: - params(dict ): a Python List containing the parameters of self.model. - """ - self.model.set_params(params) + assert isinstance(model, Model) + self.model = model def learn(self, *args, **kwargs): """ Define the loss function and create an optimizer to minize the loss. diff --git a/parl/core/fluid/model.py b/parl/core/fluid/model.py index 38d653a..bf7069a 100644 --- a/parl/core/fluid/model.py +++ b/parl/core/fluid/model.py @@ -17,7 +17,6 @@ import paddle.fluid as fluid from parl.core.fluid.layers.layer_wrappers import LayerFunc from parl.core.fluid.plutils import * from parl.core.model_base import ModelBase -from parl.utils.deprecation import deprecated from parl.utils import machine_info __all__ = ['Model'] @@ -67,30 +66,6 @@ class Model(ModelBase): """ - @deprecated( - deprecated_in='1.2', - removed_in='1.3', - replace_function='sync_weights_to') - def sync_params_to(self, - target_net, - gpu_id=None, - decay=0.0, - share_vars_parallel_executor=None): - """Synchronize parameters in the model to another model (target_net). - - target_net_weights = decay * target_net_weights + (1 - decay) * source_net_weights - - Args: - target_model (`parl.Model`): an instance of ``Model`` that has the same neural network architecture as the current model. - decay (float): the rate of decline in copying parameters. 0 if no parameters decay when synchronizing the parameters. - share_vars_parallel_executor (fluid.ParallelExecutor): Optional. If not None, will use fluid.ParallelExecutor - to run program instead of fluid.Executor - """ - self.sync_weights_to( - target_model=target_net, - decay=decay, - share_vars_parallel_executor=share_vars_parallel_executor) - def sync_weights_to(self, target_model, decay=0.0, @@ -181,21 +156,6 @@ class Model(ModelBase): else: self._cached_fluid_executor.run(fetch_list=[]) - @property - @deprecated( - deprecated_in='1.2', removed_in='1.3', replace_function='parameters') - def parameter_names(self): - """Get names of all parameters in this ``Model``. - - Only parameters created by ``parl.layers`` are included. - The order of parameter names is consistent among - different instances of the same `Model`. - - Returns: - param_names(list): list of string containing parameter names of all parameters. - """ - return self.parameters() - def parameters(self): """Get names of all parameters in this ``Model``. @@ -223,26 +183,6 @@ class Model(ModelBase): self._parameter_names = self._get_parameter_names(self) return self._parameter_names - @deprecated( - deprecated_in='1.2', removed_in='1.3', replace_function='get_weights') - def get_params(self): - """ Return a Python list containing parameters of current model. - - Returns: - parameters: a Python list containing parameters of the current model. - """ - return self.get_weights() - - @deprecated( - deprecated_in='1.2', removed_in='1.3', replace_function='set_weights') - def set_params(self, params, gpu_id=None): - """Set parameters in the model with params. - - Args: - params (List): List of numpy array . - """ - self.set_weights(weights=params) - def get_weights(self): """Returns a Python list containing parameters of current model. diff --git a/parl/framework/__init__.py b/parl/framework/__init__.py deleted file mode 100644 index 4e48085..0000000 --- a/parl/framework/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import warnings - -warnings.simplefilter('default') - -warnings.warn( - "import way `import parl.framework` is deprecated since version 1.2 and will be removed in version 1.3.", - DeprecationWarning, - stacklevel=2) - -from parl.core.fluid.model import * -from parl.core.fluid.algorithm import * -from parl.core.fluid.agent import * diff --git a/parl/framework/agent_base.py b/parl/framework/agent_base.py deleted file mode 100644 index 331f93b..0000000 --- a/parl/framework/agent_base.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings - -warnings.simplefilter('default') - -warnings.warn( - "module `parl.framework.agent_base.Agent` is deprecated since version 1.2 and will be removed in version 1.3, please use `parl.Agent` instead.", - DeprecationWarning, - stacklevel=2) - -from parl.core.fluid.agent import * diff --git a/parl/framework/algorithm_base.py b/parl/framework/algorithm_base.py deleted file mode 100644 index 2499c63..0000000 --- a/parl/framework/algorithm_base.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings - -warnings.simplefilter('default') - -warnings.warn( - "module `parl.framework.algorithm_base.Algorithm` is deprecated since version 1.2 and will be removed in version 1.3, please use `parl.Algorithm` instead.", - DeprecationWarning, - stacklevel=2) - -from parl.core.fluid.algorithm import * diff --git a/parl/framework/model_base.py b/parl/framework/model_base.py deleted file mode 100644 index e4057a7..0000000 --- a/parl/framework/model_base.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings - -warnings.simplefilter('default') - -warnings.warn( - "module `parl.framework.model_base.Model` is deprecated since version 1.2 and will be removed in version 1.3, please use `parl.Model` instead.", - DeprecationWarning, - stacklevel=2) - -from parl.core.fluid.model import * diff --git a/parl/framework/policy_distribution.py b/parl/framework/policy_distribution.py deleted file mode 100644 index 60bd6dd..0000000 --- a/parl/framework/policy_distribution.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings - -warnings.simplefilter('default') - -warnings.warn( - "module `parl.framework.policy_distribution` is deprecated since version 1.2 and will be removed in version 1.3, please use `parl.policy_distribution` instead.", - DeprecationWarning, - stacklevel=2) - -from parl.core.fluid.policy_distribution import * diff --git a/parl/layers/__init__.py b/parl/layers/__init__.py deleted file mode 100644 index 3283927..0000000 --- a/parl/layers/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings - -warnings.simplefilter('default') - -warnings.warn( - "import way `import parl.layers` is deprecated since version 1.2 and will be removed in version 1.3, please use `from parl import layers` or `import parl; parl.layers` instead.", - DeprecationWarning, - stacklevel=2) - -from parl.core.fluid.layers import * diff --git a/parl/plutils/__init__.py b/parl/plutils/__init__.py deleted file mode 100644 index 8bac1d7..0000000 --- a/parl/plutils/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -print( - "import way `import parl.plutils` is deprecated since version 1.2 and will be removed in version 1.3, please use `from parl import plutils` or `import parl; parl.plutils` instead." -) - -from parl.core.fluid.plutils.common import * diff --git a/parl/plutils/common.py b/parl/plutils/common.py deleted file mode 100644 index 8bac1d7..0000000 --- a/parl/plutils/common.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -print( - "import way `import parl.plutils` is deprecated since version 1.2 and will be removed in version 1.3, please use `from parl import plutils` or `import parl; parl.plutils` instead." -) - -from parl.core.fluid.plutils.common import * -- GitLab
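Note on the constructor API after this patch: with the `hyperparas` dict gone, each fluid algorithm takes its hyperparameters as explicit keyword arguments, exactly as the new algs_test.py does. A minimal sketch — CartPoleModel is a hypothetical stand-in for any user-defined parl.Model, mirroring DQNModel from the test file:

    import parl
    from parl import layers

    class CartPoleModel(parl.Model):
        def __init__(self, act_dim):
            # two fully connected layers built with the parl.layers wrappers
            self.fc1 = layers.fc(size=32, act='relu')
            self.fc2 = layers.fc(size=act_dim)

        def value(self, obs):
            return self.fc2(self.fc1(obs))

    model = CartPoleModel(act_dim=2)
    # before: parl.algorithms.DQN(model, hyperparas={'action_dim': 2, 'gamma': 0.9})
    # now: explicit keyword arguments, validated by the asserts added in this patch
    alg = parl.algorithms.DQN(model, act_dim=2, gamma=0.9)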
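sync_target keeps its soft-update semantics after the `gpu_id` argument is removed: when decay is left as None, DDPG uses decay = 1.0 - tau and updates target = decay * target + (1 - decay) * source. A short sketch, assuming a model shaped like DDPGModel in algs_test.py (an actor/critic model that implements get_actor_params()):

    import parl

    model = DDPGModel()  # assumed: defined as in algs_test.py
    alg = parl.algorithms.DDPG(
        model, gamma=0.99, tau=0.001, actor_lr=3e-4, critic_lr=3e-4)

    alg.sync_target(decay=0)  # hard copy once at construction, as DDPGAgent.__init__ does
    alg.sync_target()         # later calls default to decay = 1.0 - tau (soft update)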
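The deleted parl.framework, parl.layers and parl.plutils packages were thin re-export shims that emitted the warnings removed here; the modules they re-exported are unchanged. A sketch of the import migration, following the messages in the deleted files:

    # removed (previously deprecated):
    #   from parl.framework.agent_base import Agent
    #   from parl.framework.model_base import Model
    #   import parl.layers as layers
    #   import parl.plutils

    # current equivalents:
    import parl                # parl.Agent, parl.Algorithm, parl.Model
    from parl import layers    # replaces parl.layers
    from parl import plutils   # replaces parl.plutils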