未验证 提交 0a068653 编写于 作者: L LI Yunxiang 提交者: GitHub

remove version 1.3 warnings (#252)

* remove version 1.3 warnings

* update

* yapf

* add algorithms test

* Update algs_test.py

* Update algs_test.py

add SAC DDPG TD3 tests

* yapf
上级 91e9814a
...@@ -59,7 +59,6 @@ Within class ``DQN(Algorithm)``, we define the following methods: ...@@ -59,7 +59,6 @@ Within class ``DQN(Algorithm)``, we define the following methods:
Args: Args:
model (parl.Model): model defining forward network of Q function model (parl.Model): model defining forward network of Q function
hyperparas (dict): (deprecated) dict of hyper parameters.
act_dim (int): dimension of the action space act_dim (int): dimension of the action space
gamma (float): discounted factor for reward computation. gamma (float): discounted factor for reward computation.
lr (float): learning rate. lr (float): learning rate.
......
...@@ -19,23 +19,16 @@ import copy ...@@ -19,23 +19,16 @@ import copy
import paddle.fluid as fluid import paddle.fluid as fluid
from parl.core.fluid.algorithm import Algorithm from parl.core.fluid.algorithm import Algorithm
from parl.core.fluid import layers from parl.core.fluid import layers
from parl.utils.deprecation import deprecated
__all__ = ['DQN'] __all__ = ['DQN']
class DQN(Algorithm): class DQN(Algorithm):
def __init__(self, def __init__(self, model, act_dim=None, gamma=None, lr=None):
model,
hyperparas=None,
act_dim=None,
gamma=None,
lr=None):
""" DQN algorithm """ DQN algorithm
Args: Args:
model (parl.Model): model defining forward network of Q function model (parl.Model): model defining forward network of Q function
hyperparas (dict): (deprecated) dict of hyper parameters.
act_dim (int): dimension of the action space act_dim (int): dimension of the action space
gamma (float): discounted factor for reward computation. gamma (float): discounted factor for reward computation.
lr (float): learning rate. lr (float): learning rate.
...@@ -43,20 +36,12 @@ class DQN(Algorithm): ...@@ -43,20 +36,12 @@ class DQN(Algorithm):
self.model = model self.model = model
self.target_model = copy.deepcopy(model) self.target_model = copy.deepcopy(model)
if hyperparas is not None: assert isinstance(act_dim, int)
warnings.warn( assert isinstance(gamma, float)
"the `hyperparas` argument of `__init__` function in `parl.Algorithms.DQN` is deprecated since version 1.2 and will be removed in version 1.3.", assert isinstance(lr, float)
DeprecationWarning, self.act_dim = act_dim
stacklevel=2) self.gamma = gamma
self.act_dim = hyperparas['action_dim'] self.lr = lr
self.gamma = hyperparas['gamma']
else:
assert isinstance(act_dim, int)
assert isinstance(gamma, float)
assert isinstance(lr, float)
self.act_dim = act_dim
self.gamma = gamma
self.lr = lr
def predict(self, obs): def predict(self, obs):
""" use value model self.model to predict the action value """ use value model self.model to predict the action value
...@@ -100,12 +85,7 @@ class DQN(Algorithm): ...@@ -100,12 +85,7 @@ class DQN(Algorithm):
cost = layers.reduce_mean(cost) cost = layers.reduce_mean(cost)
return cost return cost
def sync_target(self, gpu_id=None): def sync_target(self):
""" sync weights of self.model to self.target_model """ sync weights of self.model to self.target_model
""" """
if gpu_id is not None:
warnings.warn(
"the `gpu_id` argument of `sync_target` function in `parl.Algorithms.DQN` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
self.model.sync_weights_to(self.target_model) self.model.sync_weights_to(self.target_model)
...@@ -24,25 +24,17 @@ __all__ = ['A3C'] ...@@ -24,25 +24,17 @@ __all__ = ['A3C']
class A3C(Algorithm): class A3C(Algorithm):
def __init__(self, model, hyperparas=None, vf_loss_coeff=None): def __init__(self, model, vf_loss_coeff=None):
""" A3C/A2C algorithm """ A3C/A2C algorithm
Args: Args:
model (parl.Model): forward network of policy and value model (parl.Model): forward network of policy and value
hyperparas (dict): (deprecated) dict of hyper parameters.
vf_loss_coeff (float): coefficient of the value function loss vf_loss_coeff (float): coefficient of the value function loss
""" """
self.model = model self.model = model
if hyperparas is not None: assert isinstance(vf_loss_coeff, (int, float))
warnings.warn( self.vf_loss_coeff = vf_loss_coeff
"the `hyperparas` argument of `__init__` function in `parl.Algorithms.A3C` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
self.vf_loss_coeff = hyperparas['vf_loss_coeff']
else:
assert isinstance(vf_loss_coeff, (int, float))
self.vf_loss_coeff = vf_loss_coeff
def learn(self, obs, actions, advantages, target_values, learning_rate, def learn(self, obs, actions, advantages, target_values, learning_rate,
entropy_coeff): entropy_coeff):
......
...@@ -19,7 +19,6 @@ from parl.core.fluid import layers ...@@ -19,7 +19,6 @@ from parl.core.fluid import layers
from copy import deepcopy from copy import deepcopy
from paddle import fluid from paddle import fluid
from parl.core.fluid.algorithm import Algorithm from parl.core.fluid.algorithm import Algorithm
from parl.utils.deprecation import deprecated
__all__ = ['DDPG'] __all__ = ['DDPG']
...@@ -27,7 +26,6 @@ __all__ = ['DDPG'] ...@@ -27,7 +26,6 @@ __all__ = ['DDPG']
class DDPG(Algorithm): class DDPG(Algorithm):
def __init__(self, def __init__(self,
model, model,
hyperparas=None,
gamma=None, gamma=None,
tau=None, tau=None,
actor_lr=None, actor_lr=None,
...@@ -37,53 +35,28 @@ class DDPG(Algorithm): ...@@ -37,53 +35,28 @@ class DDPG(Algorithm):
Args: Args:
model (parl.Model): forward network of actor and critic. model (parl.Model): forward network of actor and critic.
The function get_actor_params() of model should be implemented. The function get_actor_params() of model should be implemented.
hyperparas (dict): (deprecated) dict of hyper parameters.
gamma (float): discounted factor for reward computation. gamma (float): discounted factor for reward computation.
tau (float): decay coefficient when updating the weights of self.target_model with self.model tau (float): decay coefficient when updating the weights of self.target_model with self.model
actor_lr (float): learning rate of the actor model actor_lr (float): learning rate of the actor model
critic_lr (float): learning rate of the critic model critic_lr (float): learning rate of the critic model
""" """
if hyperparas is not None: assert isinstance(gamma, float)
warnings.warn( assert isinstance(tau, float)
"the `hyperparas` argument of `__init__` function in `parl.Algorithms.DDPG` is deprecated since version 1.2 and will be removed in version 1.3.", assert isinstance(actor_lr, float)
DeprecationWarning, assert isinstance(critic_lr, float)
stacklevel=2) self.gamma = gamma
self.gamma = hyperparas['gamma'] self.tau = tau
self.tau = hyperparas['tau'] self.actor_lr = actor_lr
self.actor_lr = hyperparas['actor_lr'] self.critic_lr = critic_lr
self.critic_lr = hyperparas['critic_lr']
else:
assert isinstance(gamma, float)
assert isinstance(tau, float)
assert isinstance(actor_lr, float)
assert isinstance(critic_lr, float)
self.gamma = gamma
self.tau = tau
self.actor_lr = actor_lr
self.critic_lr = critic_lr
self.model = model self.model = model
self.target_model = deepcopy(model) self.target_model = deepcopy(model)
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='predict')
def define_predict(self, obs):
""" use actor model of self.model to predict the action
"""
return self.predict(obs)
def predict(self, obs): def predict(self, obs):
""" use actor model of self.model to predict the action """ use actor model of self.model to predict the action
""" """
return self.model.policy(obs) return self.model.policy(obs)
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='learn')
def define_learn(self, obs, action, reward, next_obs, terminal):
""" update actor and critic model with DDPG algorithm
"""
return self.learn(obs, action, reward, next_obs, terminal)
def learn(self, obs, action, reward, next_obs, terminal): def learn(self, obs, action, reward, next_obs, terminal):
""" update actor and critic model with DDPG algorithm """ update actor and critic model with DDPG algorithm
""" """
...@@ -115,15 +88,7 @@ class DDPG(Algorithm): ...@@ -115,15 +88,7 @@ class DDPG(Algorithm):
optimizer.minimize(cost) optimizer.minimize(cost)
return cost return cost
def sync_target(self, def sync_target(self, decay=None, share_vars_parallel_executor=None):
gpu_id=None,
decay=None,
share_vars_parallel_executor=None):
if gpu_id is not None:
warnings.warn(
"the `gpu_id` argument of `sync_target` function in `parl.Algorithms.DDPG` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
if decay is None: if decay is None:
decay = 1.0 - self.tau decay = 1.0 - self.tau
self.model.sync_weights_to( self.model.sync_weights_to(
......
...@@ -85,12 +85,7 @@ class DDQN(Algorithm): ...@@ -85,12 +85,7 @@ class DDQN(Algorithm):
optimizer.minimize(cost) optimizer.minimize(cost)
return cost return cost
def sync_target(self, gpu_id=None): def sync_target(self):
""" sync weights of self.model to self.target_model """ sync weights of self.model to self.target_model
""" """
if gpu_id is not None:
warnings.warn(
"the `gpu_id` argument of `sync_target` function in `parl.Algorithms.DQN` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
self.model.sync_weights_to(self.target_model) self.model.sync_weights_to(self.target_model)
...@@ -19,18 +19,16 @@ import copy ...@@ -19,18 +19,16 @@ import copy
import paddle.fluid as fluid import paddle.fluid as fluid
from parl.core.fluid.algorithm import Algorithm from parl.core.fluid.algorithm import Algorithm
from parl.core.fluid import layers from parl.core.fluid import layers
from parl.utils.deprecation import deprecated
__all__ = ['DQN'] __all__ = ['DQN']
class DQN(Algorithm): class DQN(Algorithm):
def __init__(self, model, hyperparas=None, act_dim=None, gamma=None): def __init__(self, model, act_dim=None, gamma=None):
""" DQN algorithm """ DQN algorithm
Args: Args:
model (parl.Model): model defining forward network of Q function model (parl.Model): model defining forward network of Q function
hyperparas (dict): (deprecated) dict of hyper parameters.
act_dim (int): dimension of the action space act_dim (int): dimension of the action space
gamma (float): discounted factor for reward computation. gamma (float): discounted factor for reward computation.
lr (float): learning rate. lr (float): learning rate.
...@@ -38,38 +36,16 @@ class DQN(Algorithm): ...@@ -38,38 +36,16 @@ class DQN(Algorithm):
self.model = model self.model = model
self.target_model = copy.deepcopy(model) self.target_model = copy.deepcopy(model)
if hyperparas is not None: assert isinstance(act_dim, int)
warnings.warn( assert isinstance(gamma, float)
"the `hyperparas` argument of `__init__` function in `parl.Algorithms.DQN` is deprecated since version 1.2 and will be removed in version 1.3.", self.act_dim = act_dim
DeprecationWarning, self.gamma = gamma
stacklevel=2)
self.act_dim = hyperparas['action_dim']
self.gamma = hyperparas['gamma']
else:
assert isinstance(act_dim, int)
assert isinstance(gamma, float)
self.act_dim = act_dim
self.gamma = gamma
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='predict')
def define_predict(self, obs):
""" use value model self.model to predict the action value
"""
return self.predict(obs)
def predict(self, obs): def predict(self, obs):
""" use value model self.model to predict the action value """ use value model self.model to predict the action value
""" """
return self.model.value(obs) return self.model.value(obs)
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='learn')
def define_learn(self, obs, action, reward, next_obs, terminal,
learning_rate):
return self.learn(obs, action, reward, next_obs, terminal,
learning_rate)
def learn(self, obs, action, reward, next_obs, terminal, learning_rate): def learn(self, obs, action, reward, next_obs, terminal, learning_rate):
""" update value model self.model with DQN algorithm """ update value model self.model with DQN algorithm
""" """
...@@ -92,12 +68,7 @@ class DQN(Algorithm): ...@@ -92,12 +68,7 @@ class DQN(Algorithm):
optimizer.minimize(cost) optimizer.minimize(cost)
return cost return cost
def sync_target(self, gpu_id=None): def sync_target(self):
""" sync weights of self.model to self.target_model """ sync weights of self.model to self.target_model
""" """
if gpu_id is not None:
warnings.warn(
"the `gpu_id` argument of `sync_target` function in `parl.Algorithms.DQN` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
self.model.sync_weights_to(self.target_model) self.model.sync_weights_to(self.target_model)
...@@ -85,7 +85,6 @@ class VTraceLoss(object): ...@@ -85,7 +85,6 @@ class VTraceLoss(object):
class IMPALA(Algorithm): class IMPALA(Algorithm):
def __init__(self, def __init__(self,
model, model,
hyperparas=None,
sample_batch_steps=None, sample_batch_steps=None,
gamma=None, gamma=None,
vf_loss_coeff=None, vf_loss_coeff=None,
...@@ -95,34 +94,22 @@ class IMPALA(Algorithm): ...@@ -95,34 +94,22 @@ class IMPALA(Algorithm):
Args: Args:
model (parl.Model): forward network of policy and value model (parl.Model): forward network of policy and value
hyperparas (dict): (deprecated) dict of hyper parameters.
sample_batch_steps (int): steps of each environment sampling. sample_batch_steps (int): steps of each environment sampling.
gamma (float): discounted factor for reward computation. gamma (float): discounted factor for reward computation.
vf_loss_coeff (float): coefficient of the value function loss. vf_loss_coeff (float): coefficient of the value function loss.
clip_rho_threshold (float): clipping threshold for importance weights (rho). clip_rho_threshold (float): clipping threshold for importance weights (rho).
clip_pg_rho_threshold (float): clipping threshold on rho_s in \rho_s \delta log \pi(a|x) (r + \gamma v_{s+1} - V(x_s)). clip_pg_rho_threshold (float): clipping threshold on rho_s in \rho_s \delta log \pi(a|x) (r + \gamma v_{s+1} - V(x_s)).
""" """
if hyperparas is not None: assert isinstance(sample_batch_steps, int)
warnings.warn( assert isinstance(gamma, float)
"the `hyperparas` argument of `__init__` function in `parl.Algorithms.IMPALA` is deprecated since version 1.2 and will be removed in version 1.3.", assert isinstance(vf_loss_coeff, float)
DeprecationWarning, assert isinstance(clip_rho_threshold, float)
stacklevel=2) assert isinstance(clip_pg_rho_threshold, float)
self.sample_batch_steps = hyperparas['sample_batch_steps'] self.sample_batch_steps = sample_batch_steps
self.gamma = hyperparas['gamma'] self.gamma = gamma
self.vf_loss_coeff = hyperparas['vf_loss_coeff'] self.vf_loss_coeff = vf_loss_coeff
self.clip_rho_threshold = hyperparas['clip_rho_threshold'] self.clip_rho_threshold = clip_rho_threshold
self.clip_pg_rho_threshold = hyperparas['clip_pg_rho_threshold'] self.clip_pg_rho_threshold = clip_pg_rho_threshold
else:
assert isinstance(sample_batch_steps, int)
assert isinstance(gamma, float)
assert isinstance(vf_loss_coeff, float)
assert isinstance(clip_rho_threshold, float)
assert isinstance(clip_pg_rho_threshold, float)
self.sample_batch_steps = sample_batch_steps
self.gamma = gamma
self.vf_loss_coeff = vf_loss_coeff
self.clip_rho_threshold = clip_rho_threshold
self.clip_pg_rho_threshold = clip_pg_rho_threshold
self.model = model self.model = model
......
...@@ -18,51 +18,28 @@ warnings.simplefilter('default') ...@@ -18,51 +18,28 @@ warnings.simplefilter('default')
import paddle.fluid as fluid import paddle.fluid as fluid
from parl.core.fluid.algorithm import Algorithm from parl.core.fluid.algorithm import Algorithm
from parl.core.fluid import layers from parl.core.fluid import layers
from parl.utils.deprecation import deprecated
__all__ = ['PolicyGradient'] __all__ = ['PolicyGradient']
class PolicyGradient(Algorithm): class PolicyGradient(Algorithm):
def __init__(self, model, hyperparas=None, lr=None): def __init__(self, model, lr=None):
""" Policy Gradient algorithm """ Policy Gradient algorithm
Args: Args:
model (parl.Model): forward network of the policy. model (parl.Model): forward network of the policy.
hyperparas (dict): (deprecated) dict of hyper parameters.
lr (float): learning rate of the policy model. lr (float): learning rate of the policy model.
""" """
self.model = model self.model = model
if hyperparas is not None: assert isinstance(lr, float)
warnings.warn( self.lr = lr
"the `hyperparas` argument of `__init__` function in `parl.Algorithms.PolicyGradient` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
self.lr = hyperparas['lr']
else:
assert isinstance(lr, float)
self.lr = lr
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='predict')
def define_predict(self, obs):
""" use policy model self.model to predict the action probability
"""
return self.predict(obs)
def predict(self, obs): def predict(self, obs):
""" use policy model self.model to predict the action probability """ use policy model self.model to predict the action probability
""" """
return self.model(obs) return self.model(obs)
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='learn')
def define_learn(self, obs, action, reward):
""" update policy model self.model with policy gradient algorithm
"""
return self.learn(obs, action, reward)
def learn(self, obs, action, reward): def learn(self, obs, action, reward):
""" update policy model self.model with policy gradient algorithm """ update policy model self.model with policy gradient algorithm
""" """
......
...@@ -20,7 +20,6 @@ from copy import deepcopy ...@@ -20,7 +20,6 @@ from copy import deepcopy
from paddle import fluid from paddle import fluid
from parl.core.fluid import layers from parl.core.fluid import layers
from parl.core.fluid.algorithm import Algorithm from parl.core.fluid.algorithm import Algorithm
from parl.utils.deprecation import deprecated
__all__ = ['PPO'] __all__ = ['PPO']
...@@ -28,7 +27,6 @@ __all__ = ['PPO'] ...@@ -28,7 +27,6 @@ __all__ = ['PPO']
class PPO(Algorithm): class PPO(Algorithm):
def __init__(self, def __init__(self,
model, model,
hyperparas=None,
act_dim=None, act_dim=None,
policy_lr=None, policy_lr=None,
value_lr=None, value_lr=None,
...@@ -37,7 +35,6 @@ class PPO(Algorithm): ...@@ -37,7 +35,6 @@ class PPO(Algorithm):
Args: Args:
model (parl.Model): model defining forward network of policy and value. model (parl.Model): model defining forward network of policy and value.
hyperparas (dict): (deprecated) dict of hyper parameters.
act_dim (float): dimension of the action space. act_dim (float): dimension of the action space.
policy_lr (float): learning rate of the policy model. policy_lr (float): learning rate of the policy model.
value_lr (float): learning rate of the value model. value_lr (float): learning rate of the value model.
...@@ -47,27 +44,14 @@ class PPO(Algorithm): ...@@ -47,27 +44,14 @@ class PPO(Algorithm):
# Used to calculate probability of action in old policy # Used to calculate probability of action in old policy
self.old_policy_model = deepcopy(model.policy_model) self.old_policy_model = deepcopy(model.policy_model)
if hyperparas is not None: assert isinstance(act_dim, int)
warnings.warn( assert isinstance(policy_lr, float)
"the `hyperparas` argument of `__init__` function in `parl.Algorithms.PPO` is deprecated since version 1.2 and will be removed in version 1.3.", assert isinstance(value_lr, float)
DeprecationWarning, assert isinstance(epsilon, float)
stacklevel=2) self.act_dim = act_dim
self.act_dim = hyperparas['act_dim'] self.policy_lr = policy_lr
self.policy_lr = hyperparas['policy_lr'] self.value_lr = value_lr
self.value_lr = hyperparas['value_lr'] self.epsilon = epsilon
if 'epsilon' in hyperparas:
self.epsilon = hyperparas['epsilon']
else:
self.epsilon = 0.2 # default
else:
assert isinstance(act_dim, int)
assert isinstance(policy_lr, float)
assert isinstance(value_lr, float)
assert isinstance(epsilon, float)
self.act_dim = act_dim
self.policy_lr = policy_lr
self.value_lr = value_lr
self.epsilon = epsilon
def _calc_logprob(self, actions, means, logvars): def _calc_logprob(self, actions, means, logvars):
""" Calculate log probabilities of actions, when given means and logvars """ Calculate log probabilities of actions, when given means and logvars
...@@ -111,49 +95,18 @@ class PPO(Algorithm): ...@@ -111,49 +95,18 @@ class PPO(Algorithm):
log_det_cov_new - log_det_cov_old) + tr_old_new - self.act_dim) log_det_cov_new - log_det_cov_old) + tr_old_new - self.act_dim)
return kl return kl
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='predict')
def define_predict(self, obs):
""" Use policy model of self.model to predict means and logvars of actions
"""
return self.predict(obs)
def predict(self, obs): def predict(self, obs):
""" Use the policy model of self.model to predict means and logvars of actions """ Use the policy model of self.model to predict means and logvars of actions
""" """
means, logvars = self.model.policy(obs) means, logvars = self.model.policy(obs)
return means return means
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='sample')
def define_sample(self, obs):
""" Use the policy model of self.model to sample actions
"""
return self.sample(obs)
def sample(self, obs): def sample(self, obs):
""" Use the policy model of self.model to sample actions """ Use the policy model of self.model to sample actions
""" """
sampled_act = self.model.policy_sample(obs) sampled_act = self.model.policy_sample(obs)
return sampled_act return sampled_act
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='policy_learn')
def define_policy_learn(self, obs, actions, advantages, beta=None):
""" Learn policy model with:
1. CLIP loss: Clipped Surrogate Objective
2. KLPEN loss: Adaptive KL Penalty Objective
See: https://arxiv.org/pdf/1707.02286.pdf
Args:
obs: Tensor, (batch_size, obs_dim)
actions: Tensor, (batch_size, act_dim)
advantages: Tensor (batch_size, )
beta: Tensor (1) or None
if None, use CLIP Loss; else, use KLPEN loss.
"""
return self.policy_learn(obs, actions, advantages, beta)
def policy_learn(self, obs, actions, advantages, beta=None): def policy_learn(self, obs, actions, advantages, beta=None):
""" Learn policy model with: """ Learn policy model with:
1. CLIP loss: Clipped Surrogate Objective 1. CLIP loss: Clipped Surrogate Objective
...@@ -196,27 +149,11 @@ class PPO(Algorithm): ...@@ -196,27 +149,11 @@ class PPO(Algorithm):
optimizer.minimize(loss) optimizer.minimize(loss)
return loss, kl return loss, kl
@deprecated(
deprecated_in='1.2',
removed_in='1.3',
replace_function='value_predict')
def define_value_predict(self, obs):
""" Use value model of self.model to predict value of obs
"""
return self.value_predict(obs)
def value_predict(self, obs): def value_predict(self, obs):
""" Use value model of self.model to predict value of obs """ Use value model of self.model to predict value of obs
""" """
return self.model.value(obs) return self.model.value(obs)
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='value_learn')
def define_value_learn(self, obs, val):
""" Learn value model with square error cost
"""
return self.value_learn(obs, val)
def value_learn(self, obs, val): def value_learn(self, obs, val):
""" Learn the value model with square error cost """ Learn the value model with square error cost
""" """
...@@ -227,12 +164,7 @@ class PPO(Algorithm): ...@@ -227,12 +164,7 @@ class PPO(Algorithm):
optimizer.minimize(loss) optimizer.minimize(loss)
return loss return loss
def sync_old_policy(self, gpu_id=None): def sync_old_policy(self):
""" Synchronize weights of self.model.policy_model to self.old_policy_model """ Synchronize weights of self.model.policy_model to self.old_policy_model
""" """
if gpu_id is not None:
warnings.warn(
"the `gpu_id` argument of `sync_old_policy` function in `parl.Algorithms.PPO` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
self.model.policy_model.sync_weights_to(self.old_policy_model) self.model.policy_model.sync_weights_to(self.old_policy_model)
此差异已折叠。
...@@ -27,7 +27,7 @@ __all__ = ['A2C'] ...@@ -27,7 +27,7 @@ __all__ = ['A2C']
class A2C(parl.Algorithm): class A2C(parl.Algorithm):
def __init__(self, model, config, hyperparas=None): def __init__(self, model, config):
assert isinstance(config['vf_loss_coeff'], (int, float)) assert isinstance(config['vf_loss_coeff'], (int, float))
self.model = model self.model = model
self.vf_loss_coeff = config['vf_loss_coeff'] self.vf_loss_coeff = config['vf_loss_coeff']
......
...@@ -17,7 +17,6 @@ warnings.simplefilter('default') ...@@ -17,7 +17,6 @@ warnings.simplefilter('default')
import paddle.fluid as fluid import paddle.fluid as fluid
from parl.core.fluid import layers from parl.core.fluid import layers
from parl.utils.deprecation import deprecated
from parl.core.agent_base import AgentBase from parl.core.agent_base import AgentBase
from parl.core.fluid.algorithm import Algorithm from parl.core.fluid.algorithm import Algorithm
from parl.utils import machine_info from parl.utils import machine_info
...@@ -46,7 +45,6 @@ class Agent(AgentBase): ...@@ -46,7 +45,6 @@ class Agent(AgentBase):
This class will initialize the neural network parameters automatically, and provides an executor for users to run the programs (self.fluid_executor). This class will initialize the neural network parameters automatically, and provides an executor for users to run the programs (self.fluid_executor).
Attributes: Attributes:
gpu_id (int): deprecated. specify which GPU to be used. -1 if to use the CPU.
fluid_executor (fluid.Executor): executor for running programs of the agent. fluid_executor (fluid.Executor): executor for running programs of the agent.
alg (parl.algorithm): algorithm of this agent. alg (parl.algorithm): algorithm of this agent.
...@@ -65,18 +63,12 @@ class Agent(AgentBase): ...@@ -65,18 +63,12 @@ class Agent(AgentBase):
""" """
def __init__(self, algorithm, gpu_id=None): def __init__(self, algorithm):
"""Build programs by calling the method ``self.build_program()`` and run initialization function of ``fluid.default_startup_program()``. """Build programs by calling the method ``self.build_program()`` and run initialization function of ``fluid.default_startup_program()``.
Args: Args:
algorithm (parl.Algorithm): an instance of `parl.Algorithm`. This algorithm is then passed to `self.alg`. algorithm (parl.Algorithm): an instance of `parl.Algorithm`. This algorithm is then passed to `self.alg`.
gpu_id (int): deprecated. specify which GPU to be used. -1 if to use the CPU.
""" """
if gpu_id is not None:
warnings.warn(
"the `gpu_id` argument of `__init__` function in `parl.Agent` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
assert isinstance(algorithm, Algorithm) assert isinstance(algorithm, Algorithm)
super(Agent, self).__init__(algorithm) super(Agent, self).__init__(algorithm)
...@@ -119,26 +111,6 @@ class Agent(AgentBase): ...@@ -119,26 +111,6 @@ class Agent(AgentBase):
""" """
raise NotImplementedError raise NotImplementedError
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='get_weights')
def get_params(self):
""" Returns a Python dictionary containing the whole parameters of self.alg.
Returns:
a Python List containing the parameters of self.alg.
"""
return self.algorithm.get_params()
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='set_weights')
def set_params(self, params):
"""Copy parameters from ``get_params()`` into this agent.
Args:
params(dict): a Python List containing the parameters of self.alg.
"""
self.algorithm.set_params(params)
def learn(self, *args, **kwargs): def learn(self, *args, **kwargs):
"""The training interface for ``Agent``. """The training interface for ``Agent``.
This function feeds the training data into the learn_program defined in ``build_program()``. This function feeds the training data into the learn_program defined in ``build_program()``.
......
...@@ -17,7 +17,6 @@ warnings.simplefilter('default') ...@@ -17,7 +17,6 @@ warnings.simplefilter('default')
from parl.core.algorithm_base import AlgorithmBase from parl.core.algorithm_base import AlgorithmBase
from parl.core.fluid.model import Model from parl.core.fluid.model import Model
from parl.utils.deprecation import deprecated
__all__ = ['Algorithm'] __all__ = ['Algorithm']
...@@ -57,47 +56,13 @@ class Algorithm(AlgorithmBase): ...@@ -57,47 +56,13 @@ class Algorithm(AlgorithmBase):
""" """
def __init__(self, model=None, hyperparas=None): def __init__(self, model=None):
""" """
Args: Args:
model(``parl.Model``): a neural network that represents a policy or a Q-value function. model(``parl.Model``): a neural network that represents a policy or a Q-value function.
hyperparas(dict): a dict storing the hyper-parameters relative to training.
""" """
if model is not None: assert isinstance(model, Model)
warnings.warn( self.model = model
"the `model` argument of `__init__` function in `parl.Algorithm` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
assert isinstance(model, Model)
self.model = model
if hyperparas is not None:
warnings.warn(
"the `hyperparas` argument of `__init__` function in `parl.Algorithm` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
self.hp = hyperparas
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='get_weights')
def get_params(self):
""" Get parameters of self.model.
Returns:
params(dict): a Python List containing the parameters of self.model.
"""
return self.model.get_params()
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='set_weights')
def set_params(self, params):
""" Set parameters from ``get_params`` to the model.
Args:
params(dict ): a Python List containing the parameters of self.model.
"""
self.model.set_params(params)
def learn(self, *args, **kwargs): def learn(self, *args, **kwargs):
""" Define the loss function and create an optimizer to minize the loss. """ Define the loss function and create an optimizer to minize the loss.
......
...@@ -17,7 +17,6 @@ import paddle.fluid as fluid ...@@ -17,7 +17,6 @@ import paddle.fluid as fluid
from parl.core.fluid.layers.layer_wrappers import LayerFunc from parl.core.fluid.layers.layer_wrappers import LayerFunc
from parl.core.fluid.plutils import * from parl.core.fluid.plutils import *
from parl.core.model_base import ModelBase from parl.core.model_base import ModelBase
from parl.utils.deprecation import deprecated
from parl.utils import machine_info from parl.utils import machine_info
__all__ = ['Model'] __all__ = ['Model']
...@@ -67,30 +66,6 @@ class Model(ModelBase): ...@@ -67,30 +66,6 @@ class Model(ModelBase):
""" """
@deprecated(
deprecated_in='1.2',
removed_in='1.3',
replace_function='sync_weights_to')
def sync_params_to(self,
target_net,
gpu_id=None,
decay=0.0,
share_vars_parallel_executor=None):
"""Synchronize parameters in the model to another model (target_net).
target_net_weights = decay * target_net_weights + (1 - decay) * source_net_weights
Args:
target_model (`parl.Model`): an instance of ``Model`` that has the same neural network architecture as the current model.
decay (float): the rate of decline in copying parameters. 0 if no parameters decay when synchronizing the parameters.
share_vars_parallel_executor (fluid.ParallelExecutor): Optional. If not None, will use fluid.ParallelExecutor
to run program instead of fluid.Executor
"""
self.sync_weights_to(
target_model=target_net,
decay=decay,
share_vars_parallel_executor=share_vars_parallel_executor)
def sync_weights_to(self, def sync_weights_to(self,
target_model, target_model,
decay=0.0, decay=0.0,
...@@ -181,21 +156,6 @@ class Model(ModelBase): ...@@ -181,21 +156,6 @@ class Model(ModelBase):
else: else:
self._cached_fluid_executor.run(fetch_list=[]) self._cached_fluid_executor.run(fetch_list=[])
@property
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='parameters')
def parameter_names(self):
"""Get names of all parameters in this ``Model``.
Only parameters created by ``parl.layers`` are included.
The order of parameter names is consistent among
different instances of the same `Model`.
Returns:
param_names(list): list of string containing parameter names of all parameters.
"""
return self.parameters()
def parameters(self): def parameters(self):
"""Get names of all parameters in this ``Model``. """Get names of all parameters in this ``Model``.
...@@ -223,26 +183,6 @@ class Model(ModelBase): ...@@ -223,26 +183,6 @@ class Model(ModelBase):
self._parameter_names = self._get_parameter_names(self) self._parameter_names = self._get_parameter_names(self)
return self._parameter_names return self._parameter_names
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='get_weights')
def get_params(self):
""" Return a Python list containing parameters of current model.
Returns:
parameters: a Python list containing parameters of the current model.
"""
return self.get_weights()
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='set_weights')
def set_params(self, params, gpu_id=None):
"""Set parameters in the model with params.
Args:
params (List): List of numpy array .
"""
self.set_weights(weights=params)
def get_weights(self): def get_weights(self):
"""Returns a Python list containing parameters of current model. """Returns a Python list containing parameters of current model.
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
warnings.simplefilter('default')
warnings.warn(
"import way `import parl.framework` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
from parl.core.fluid.model import *
from parl.core.fluid.algorithm import *
from parl.core.fluid.agent import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
warnings.simplefilter('default')
warnings.warn(
"module `parl.framework.agent_base.Agent` is deprecated since version 1.2 and will be removed in version 1.3, please use `parl.Agent` instead.",
DeprecationWarning,
stacklevel=2)
from parl.core.fluid.agent import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
warnings.simplefilter('default')
warnings.warn(
"module `parl.framework.algorithm_base.Algorithm` is deprecated since version 1.2 and will be removed in version 1.3, please use `parl.Algorithm` instead.",
DeprecationWarning,
stacklevel=2)
from parl.core.fluid.algorithm import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
warnings.simplefilter('default')
warnings.warn(
"module `parl.framework.model_base.Model` is deprecated since version 1.2 and will be removed in version 1.3, please use `parl.Model` instead.",
DeprecationWarning,
stacklevel=2)
from parl.core.fluid.model import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
warnings.simplefilter('default')
warnings.warn(
"module `parl.framework.policy_distribution` is deprecated since version 1.2 and will be removed in version 1.3, please use `parl.policy_distribution` instead.",
DeprecationWarning,
stacklevel=2)
from parl.core.fluid.policy_distribution import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
warnings.simplefilter('default')
warnings.warn(
"import way `import parl.layers` is deprecated since version 1.2 and will be removed in version 1.3, please use `from parl import layers` or `import parl; parl.layers` instead.",
DeprecationWarning,
stacklevel=2)
from parl.core.fluid.layers import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
print(
"import way `import parl.plutils` is deprecated since version 1.2 and will be removed in version 1.3, please use `from parl import plutils` or `import parl; parl.plutils` instead."
)
from parl.core.fluid.plutils.common import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
print(
"import way `import parl.plutils` is deprecated since version 1.2 and will be removed in version 1.3, please use `from parl import plutils` or `import parl; parl.plutils` instead."
)
from parl.core.fluid.plutils.common import *
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册