Unverified commit ce152a62, authored by rical730, committed by GitHub

update lr interface and support training on single gpu (#415)

* update lr interface and support training on single gpu

* yapf

* update warning message

* update warning message
Parent b966fa78
@@ -17,6 +17,7 @@ import parl
 from parl import layers
 from paddle import fluid
 from parl.utils import ReplayMemory
+from parl.utils import machine_info, get_gpu_count


 class MAAgent(parl.Agent):
@@ -47,6 +48,10 @@ class MAAgent(parl.Agent):
             act_dim=self.act_dim_n[agent_index])

         self.global_train_step = 0
+
+        if machine_info.is_gpu_available():
+            assert get_gpu_count() == 1, 'Only support training in single GPU,\
+                Please set environment variable: `export CUDA_VISIBLE_DEVICES=[GPU_ID_TO_USE]` .'
         super(MAAgent, self).__init__(algorithm)

         # Attention: In the beginning, sync target model totally.
...
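The guard added to MAAgent.__init__ makes the agent fail fast when more than one GPU is visible to the process. A minimal sketch of how a launch script can satisfy the check by pinning a single device before PaddlePaddle initializes (the parl.utils helpers are the ones imported in the diff above; the GPU id is a placeholder):

import os

# Restrict visibility to a single GPU *before* paddle/parl pick a device,
# so that get_gpu_count() reports exactly 1 inside MAAgent.__init__.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # placeholder GPU id

from parl.utils import machine_info, get_gpu_count

if machine_info.is_gpu_available():
    # Mirrors the assertion added in MAAgent: only single-GPU training is supported.
    assert get_gpu_count() == 1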
@@ -101,7 +101,8 @@ def train_agent():
             act_space=env.action_space,
             gamma=args.gamma,
             tau=args.tau,
-            lr=args.lr)
+            critic_lr=args.critic_lr,
+            actor_lr=args.actor_lr)
         agent = MAAgent(
             algorithm,
             agent_index=i,
@@ -195,10 +196,15 @@ if __name__ == '__main__':
         help='statistical interval of save model or count reward')
     # Core training parameters
     parser.add_argument(
-        '--lr',
+        '--critic_lr',
         type=float,
         default=1e-3,
-        help='learning rate for Adam optimizer')
+        help='learning rate for the critic model')
+    parser.add_argument(
+        '--actor_lr',
+        type=float,
+        default=1e-3,
+        help='learning rate of the actor model')
     parser.add_argument(
         '--gamma', type=float, default=0.95, help='discount factor')
     parser.add_argument(
...
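With the flag split, runs that previously passed a single --lr now set the two rates separately via --critic_lr and --actor_lr, and train_agent() forwards them to the algorithm. A brief sketch of the corresponding constructor call, assuming MADDPG is importable from parl.algorithms as in the PARL examples, that `model` and `env` were built earlier as in the training script, and using placeholder hyperparameter values:

from parl.algorithms import MADDPG

algorithm = MADDPG(
    model,                       # parl.Model holding actor and critic parameters
    agent_index=0,
    act_space=env.action_space,  # list of per-agent gym spaces
    gamma=0.95,
    tau=0.01,
    critic_lr=1e-3,              # was a single lr=... before this commit
    actor_lr=1e-3)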
@@ -53,7 +53,9 @@ class MADDPG(Algorithm):
                  act_space=None,
                  gamma=None,
                  tau=None,
-                 lr=None):
+                 lr=None,
+                 actor_lr=None,
+                 critic_lr=None):
         """ MADDPG algorithm

         Args:
@@ -63,19 +65,38 @@ class MADDPG(Algorithm):
             act_space: action_space, gym space
             gamma (float): discounted factor for reward computation.
             tau (float): decay coefficient when updating the weights of self.target_model with self.model
-            lr (float): learning rate
+            lr (float): learning rate, lr will be assigned to both critic_lr and actor_lr
+            critic_lr (float): learning rate of the critic model
+            actor_lr (float): learning rate of the actor model
         """
         assert isinstance(agent_index, int)
         assert isinstance(act_space, list)
         assert isinstance(gamma, float)
         assert isinstance(tau, float)
-        assert isinstance(lr, float)
+        # compatible upgrade of lr
+        if lr is None:
+            assert isinstance(actor_lr, float)
+            assert isinstance(critic_lr, float)
+        else:
+            assert isinstance(lr, float)
+            assert actor_lr is None, 'no need to set `actor_lr` if `lr` is not None'
+            assert critic_lr is None, 'no need to set `critic_lr` if `lr` is not None'
+            critic_lr = lr
+            actor_lr = lr
+            warnings.warn(
+                "the `lr` argument of `__init__` function in `parl.Algorithms.MADDPG` is deprecated \
+                since version 1.4 and will be removed in version 2.0. \
+                Recommend to use `actor_lr` and `critic_lr`. ",
+                DeprecationWarning,
+                stacklevel=2)
+
         self.agent_index = agent_index
         self.act_space = act_space
         self.gamma = gamma
         self.tau = tau
         self.lr = lr
+        self.actor_lr = actor_lr
+        self.critic_lr = critic_lr
         self.model = model
         self.target_model = deepcopy(model)
@@ -145,7 +166,7 @@ class MADDPG(Algorithm):
             clip=fluid.clip.GradientClipByNorm(clip_norm=0.5),
             param_list=self.model.get_actor_params())

-        optimizer = fluid.optimizer.AdamOptimizer(self.lr)
+        optimizer = fluid.optimizer.AdamOptimizer(self.actor_lr)
         optimizer.minimize(cost, parameter_list=self.model.get_actor_params())
         return cost
@@ -157,7 +178,7 @@ class MADDPG(Algorithm):
             clip=fluid.clip.GradientClipByNorm(clip_norm=0.5),
             param_list=self.model.get_critic_params())

-        optimizer = fluid.optimizer.AdamOptimizer(self.lr)
+        optimizer = fluid.optimizer.AdamOptimizer(self.critic_lr)
         optimizer.minimize(cost, parameter_list=self.model.get_critic_params())
         return cost
...
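The compatibility branch above is a standard argument-deprecation pattern: if the old combined `lr` is given, copy it into the new per-network rates and emit a DeprecationWarning; otherwise require the new arguments. A self-contained sketch of the same idea (the function below is illustrative only, not part of PARL):

import warnings


def resolve_lr(lr=None, actor_lr=None, critic_lr=None):
    """Map the deprecated combined `lr` onto separate actor/critic rates."""
    if lr is None:
        # New-style call: both specific rates must be provided.
        assert isinstance(actor_lr, float) and isinstance(critic_lr, float)
    else:
        # Old-style call: `lr` wins and the specific rates must stay unset.
        assert actor_lr is None and critic_lr is None
        actor_lr = critic_lr = lr
        warnings.warn(
            '`lr` is deprecated; pass `actor_lr` and `critic_lr` instead.',
            DeprecationWarning,
            stacklevel=2)
    return actor_lr, critic_lr


print(resolve_lr(lr=1e-3))                        # warns, returns (0.001, 0.001)
print(resolve_lr(actor_lr=1e-4, critic_lr=1e-3))  # silent, returns (0.0001, 0.001)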