diff --git a/README.cn.md b/README.cn.md index 09f1df56a90bcc36dd0971038dfd15de501034ec..dd154a4e634374d5765727e86607e5c178719056 100644 --- a/README.cn.md +++ b/README.cn.md @@ -78,6 +78,7 @@ pip install parl - [A2C](examples/A2C/) - [TD3](examples/TD3/) - [SAC](examples/SAC/) +- [MADDPG](examples/MADDPG/) - [冠军解决方案:NIPS2018强化学习假肢挑战赛](examples/NeurIPS2018-AI-for-Prosthetics-Challenge/) - [冠军解决方案:NIPS2019强化学习仿生人控制赛事](examples/NeurIPS2019-Learn-to-Move-Challenge/) diff --git a/README.md b/README.md index 5245c349951b28a1a6c74e0a15cfc89d22edfaf3..7bec38c8a253b4cc53fb4cc0ac0dee3221181b05 100644 --- a/README.md +++ b/README.md @@ -81,6 +81,7 @@ pip install parl - [A2C](examples/A2C/) - [TD3](examples/TD3/) - [SAC](examples/SAC/) +- [MADDPG](examples/MADDPG/) - [Winning Solution for NIPS2018: AI for Prosthetics Challenge](examples/NeurIPS2018-AI-for-Prosthetics-Challenge/) - [Winning Solution for NIPS2019: Learn to Move Challenge](examples/NeurIPS2019-Learn-to-Move-Challenge/) diff --git a/examples/MADDPG/.benchmark/MADDPG_simple.gif b/examples/MADDPG/.benchmark/MADDPG_simple.gif new file mode 100644 index 0000000000000000000000000000000000000000..7c2b95debb594975e402e9e1ce3e8eb710aa3ee3 Binary files /dev/null and b/examples/MADDPG/.benchmark/MADDPG_simple.gif differ diff --git a/examples/MADDPG/.benchmark/MADDPG_simple.png b/examples/MADDPG/.benchmark/MADDPG_simple.png new file mode 100644 index 0000000000000000000000000000000000000000..fafa5c17ae1aca610c17fe1bcd90d3af5d6ef7ac Binary files /dev/null and b/examples/MADDPG/.benchmark/MADDPG_simple.png differ diff --git a/examples/MADDPG/.benchmark/MADDPG_simple_adversary.gif b/examples/MADDPG/.benchmark/MADDPG_simple_adversary.gif new file mode 100644 index 0000000000000000000000000000000000000000..f5038ac1078954550cc1fd2b38fb4d733aa2dc96 Binary files /dev/null and b/examples/MADDPG/.benchmark/MADDPG_simple_adversary.gif differ diff --git a/examples/MADDPG/.benchmark/MADDPG_simple_adversary.png b/examples/MADDPG/.benchmark/MADDPG_simple_adversary.png new file mode 100644 index 0000000000000000000000000000000000000000..3e3e2d2c54ce42b12cbccf8e884f5b5354a6bd27 Binary files /dev/null and b/examples/MADDPG/.benchmark/MADDPG_simple_adversary.png differ diff --git a/examples/MADDPG/.benchmark/MADDPG_simple_crypto.png b/examples/MADDPG/.benchmark/MADDPG_simple_crypto.png new file mode 100644 index 0000000000000000000000000000000000000000..f13acbc996cd1e5931f988fb4ebae8a772295426 Binary files /dev/null and b/examples/MADDPG/.benchmark/MADDPG_simple_crypto.png differ diff --git a/examples/MADDPG/.benchmark/MADDPG_simple_push.gif b/examples/MADDPG/.benchmark/MADDPG_simple_push.gif new file mode 100644 index 0000000000000000000000000000000000000000..faad8196614c31f31b5aded2ab379357a5c3a5c3 Binary files /dev/null and b/examples/MADDPG/.benchmark/MADDPG_simple_push.gif differ diff --git a/examples/MADDPG/.benchmark/MADDPG_simple_push.png b/examples/MADDPG/.benchmark/MADDPG_simple_push.png new file mode 100644 index 0000000000000000000000000000000000000000..a368f63f846f47cd2520e1af65db10ba1c53dc5c Binary files /dev/null and b/examples/MADDPG/.benchmark/MADDPG_simple_push.png differ diff --git a/examples/MADDPG/.benchmark/MADDPG_simple_reference.gif b/examples/MADDPG/.benchmark/MADDPG_simple_reference.gif new file mode 100644 index 0000000000000000000000000000000000000000..e7cf37f9aa664e8347da9d8cb4c376e3f2798607 Binary files /dev/null and b/examples/MADDPG/.benchmark/MADDPG_simple_reference.gif differ diff --git 
a/examples/MADDPG/.benchmark/MADDPG_simple_reference.png b/examples/MADDPG/.benchmark/MADDPG_simple_reference.png new file mode 100644 index 0000000000000000000000000000000000000000..ea94f14efb04227b3571c983505e480778b5a463 Binary files /dev/null and b/examples/MADDPG/.benchmark/MADDPG_simple_reference.png differ diff --git a/examples/MADDPG/.benchmark/MADDPG_simple_speaker_listener.gif b/examples/MADDPG/.benchmark/MADDPG_simple_speaker_listener.gif new file mode 100644 index 0000000000000000000000000000000000000000..6ef772bb966c247eff1d30da70dd6e969b49005a Binary files /dev/null and b/examples/MADDPG/.benchmark/MADDPG_simple_speaker_listener.gif differ diff --git a/examples/MADDPG/.benchmark/MADDPG_simple_speaker_listener.png b/examples/MADDPG/.benchmark/MADDPG_simple_speaker_listener.png new file mode 100644 index 0000000000000000000000000000000000000000..ad4e0fc4d0bb446f26a445ce58e4f7d47def4eb5 Binary files /dev/null and b/examples/MADDPG/.benchmark/MADDPG_simple_speaker_listener.png differ diff --git a/examples/MADDPG/.benchmark/MADDPG_simple_spread.gif b/examples/MADDPG/.benchmark/MADDPG_simple_spread.gif new file mode 100644 index 0000000000000000000000000000000000000000..013f95d1d18e8dd90c7adb526aeeed475841cb9c Binary files /dev/null and b/examples/MADDPG/.benchmark/MADDPG_simple_spread.gif differ diff --git a/examples/MADDPG/.benchmark/MADDPG_simple_spread.png b/examples/MADDPG/.benchmark/MADDPG_simple_spread.png new file mode 100644 index 0000000000000000000000000000000000000000..4ae93288ff4dd5fda1e35006b50102dda5695250 Binary files /dev/null and b/examples/MADDPG/.benchmark/MADDPG_simple_spread.png differ diff --git a/examples/MADDPG/.benchmark/MADDPG_simple_tag.gif b/examples/MADDPG/.benchmark/MADDPG_simple_tag.gif new file mode 100644 index 0000000000000000000000000000000000000000..703ca69f5e529f70be6fc45fef62871f39bff771 Binary files /dev/null and b/examples/MADDPG/.benchmark/MADDPG_simple_tag.gif differ diff --git a/examples/MADDPG/.benchmark/MADDPG_simple_tag.png b/examples/MADDPG/.benchmark/MADDPG_simple_tag.png new file mode 100644 index 0000000000000000000000000000000000000000..4d21b27d350bc6ebd4d427e5fa460eb8863e95d9 Binary files /dev/null and b/examples/MADDPG/.benchmark/MADDPG_simple_tag.png differ diff --git a/examples/MADDPG/.benchmark/MADDPG_simple_world_comm.gif b/examples/MADDPG/.benchmark/MADDPG_simple_world_comm.gif new file mode 100644 index 0000000000000000000000000000000000000000..1d6c4df5e76d4bf5dd43e501bea01eeccdaadd7e Binary files /dev/null and b/examples/MADDPG/.benchmark/MADDPG_simple_world_comm.gif differ diff --git a/examples/MADDPG/.benchmark/MADDPG_simple_world_comm.png b/examples/MADDPG/.benchmark/MADDPG_simple_world_comm.png new file mode 100644 index 0000000000000000000000000000000000000000..78f77fd7bc2f8bf4986aa20c241d8274c6fb9cde Binary files /dev/null and b/examples/MADDPG/.benchmark/MADDPG_simple_world_comm.png differ diff --git a/examples/MADDPG/README.md b/examples/MADDPG/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d37b553de416fe9176119f654508774d8bb1278c --- /dev/null +++ b/examples/MADDPG/README.md @@ -0,0 +1,112 @@ +## Reproduce MADDPG with PARL +Based on PARL, the MADDPG algorithm of deep reinforcement learning has been reproduced. + ++ paper: +[ Multi-Agent Actor-Critic for Mixed Cooperative-Competitive Environments](https://arxiv.org/abs/1706.02275) + +### Multi-agent particle environment introduction +A simple multi-agent particle world based on gym. 
Please see [here](https://github.com/openai/multiagent-particle-envs) to install and know more about the environment. + +### Benchmark result +Mean episode reward (every 1000 episodes) in training process (totally 25000 episodes). + + + + + + + + + + + + + + +
+simple
+MADDPG_simple +
+simple_adversary
+MADDPG_simple_adversary +
+simple_push
+MADDPG_simple_push +
+simple_reference
+MADDPG_simple_reference +
+simple_speaker_listener
+MADDPG_simple_speaker_listener +
+simple_spread
+MADDPG_simple_spread +
+simple_tag
+MADDPG_simple_tag +
+simple_world_comm
+MADDPG_simple_world_comm +
+ +### Experiments result +Display after 25000 episodes. + + + + + + + + + + + + + + +
+simple
+MADDPG_simple +
+simple_adversary
+MADDPG_simple_adversary +
+simple_push
+MADDPG_simple_push +
+simple_reference
+MADDPG_simple_reference +
+simple_speaker_listener
+MADDPG_simple_speaker_listener +
+simple_spread
+MADDPG_simple_spread +
+simple_tag
+MADDPG_simple_tag +
+simple_world_comm
+MADDPG_simple_world_comm +
+ + +## How to use +### Dependencies: ++ python3.5+ ++ [paddlepaddle>=1.5.1](https://github.com/PaddlePaddle/Paddle) ++ [parl](https://github.com/PaddlePaddle/PARL) ++ [multiagent-particle-envs](https://github.com/openai/multiagent-particle-envs) ++ gym + +### Start Training: +``` +# To train an agent for the simple_speaker_listener scenario +python train.py + +# To train on other scenarios (the model is automatically saved every 1000 episodes) +# python train.py --env [ENV_NAME] + +# To show the animation after training +# python train.py --env [ENV_NAME] --show --restore +``` diff --git a/examples/MADDPG/simple_agent.py b/examples/MADDPG/simple_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..fe0e47dd5b085643a3c9ae250b6826a783c71771 --- /dev/null +++ b/examples/MADDPG/simple_agent.py @@ -0,0 +1,180 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import parl +from parl import layers +from paddle import fluid +from parl.utils import ReplayMemory + + +class MAAgent(parl.Agent): + def __init__(self, + algorithm, + agent_index=None, + obs_dim_n=None, + act_dim_n=None, + batch_size=None, + speedup=False): + assert isinstance(agent_index, int) + assert isinstance(obs_dim_n, list) + assert isinstance(act_dim_n, list) + assert isinstance(batch_size, int) + assert isinstance(speedup, bool) + self.agent_index = agent_index + self.obs_dim_n = obs_dim_n + self.act_dim_n = act_dim_n + self.batch_size = batch_size + self.speedup = speedup + self.n = len(act_dim_n) + + self.memory_size = int(1e6) + self.min_memory_size = batch_size * 25 # batch_size * args.max_episode_len + self.rpm = ReplayMemory( + max_size=self.memory_size, + obs_dim=self.obs_dim_n[agent_index], + act_dim=self.act_dim_n[agent_index]) + self.global_train_step = 0 + + super(MAAgent, self).__init__(algorithm) + + # Note: fully synchronize the target model at the beginning.
+ self.alg.sync_target(decay=0) + + def build_program(self): + self.pred_program = fluid.Program() + self.learn_program = fluid.Program() + self.next_q_program = fluid.Program() + self.next_a_program = fluid.Program() + + with fluid.program_guard(self.pred_program): + obs = layers.data( + name='obs', + shape=[self.obs_dim_n[self.agent_index]], + dtype='float32') + self.pred_act = self.alg.predict(obs) + + with fluid.program_guard(self.learn_program): + obs_n = [ + layers.data( + name='obs' + str(i), + shape=[self.obs_dim_n[i]], + dtype='float32') for i in range(self.n) + ] + act_n = [ + layers.data( + name='act' + str(i), + shape=[self.act_dim_n[i]], + dtype='float32') for i in range(self.n) + ] + target_q = layers.data(name='target_q', shape=[], dtype='float32') + self.critic_cost = self.alg.learn(obs_n, act_n, target_q) + + with fluid.program_guard(self.next_q_program): + obs_n = [ + layers.data( + name='obs' + str(i), + shape=[self.obs_dim_n[i]], + dtype='float32') for i in range(self.n) + ] + act_n = [ + layers.data( + name='act' + str(i), + shape=[self.act_dim_n[i]], + dtype='float32') for i in range(self.n) + ] + self.next_Q = self.alg.Q_next(obs_n, act_n) + + with fluid.program_guard(self.next_a_program): + obs = layers.data( + name='obs', + shape=[self.obs_dim_n[self.agent_index]], + dtype='float32') + self.next_action = self.alg.predict_next(obs) + + if self.speedup: + self.pred_program = parl.compile(self.pred_program) + self.learn_program = parl.compile(self.learn_program, + self.critic_cost) + self.next_q_program = parl.compile(self.next_q_program) + self.next_a_program = parl.compile(self.next_a_program) + + def predict(self, obs): + obs = np.expand_dims(obs, axis=0) + obs = obs.astype('float32') + act = self.fluid_executor.run( + self.pred_program, feed={'obs': obs}, + fetch_list=[self.pred_act])[0] + return act[0] + + def learn(self, agents): + self.global_train_step += 1 + + # only update parameter every 100 steps + if self.global_train_step % 100 != 0: + return 0.0 + + if self.rpm.size() <= self.min_memory_size: + return 0.0 + + batch_obs_n = [] + batch_act_n = [] + batch_obs_new_n = [] + + rpm_sample_index = self.rpm.make_index(self.batch_size) + for i in range(self.n): + batch_obs, batch_act, _, batch_obs_new, _ \ + = agents[i].rpm.sample_batch_by_index(rpm_sample_index) + batch_obs_n.append(batch_obs) + batch_act_n.append(batch_act) + batch_obs_new_n.append(batch_obs_new) + _, _, batch_rew, _, batch_isOver \ + = self.rpm.sample_batch_by_index(rpm_sample_index) + + # compute target q + target_q = 0.0 + target_act_next_n = [] + for i in range(self.n): + feed = {'obs': batch_obs_new_n[i]} + target_act_next = agents[i].fluid_executor.run( + agents[i].next_a_program, + feed=feed, + fetch_list=[agents[i].next_action])[0] + target_act_next_n.append(target_act_next) + feed_obs = {'obs' + str(i): batch_obs_new_n[i] for i in range(self.n)} + feed_act = { + 'act' + str(i): target_act_next_n[i] + for i in range(self.n) + } + feed = feed_obs.copy() + feed.update(feed_act) # merge two dict + target_q_next = self.fluid_executor.run( + self.next_q_program, feed=feed, fetch_list=[self.next_Q])[0] + target_q += ( + batch_rew + self.alg.gamma * (1.0 - batch_isOver) * target_q_next) + + feed_obs = {'obs' + str(i): batch_obs_n[i] for i in range(self.n)} + feed_act = {'act' + str(i): batch_act_n[i] for i in range(self.n)} + target_q = target_q.astype('float32') + feed = feed_obs.copy() + feed.update(feed_act) + feed['target_q'] = target_q + critic_cost = self.fluid_executor.run( + 
self.learn_program, feed=feed, fetch_list=[self.critic_cost])[0] + + self.alg.sync_target() + return critic_cost + + def add_experience(self, obs, act, reward, next_obs, terminal): + self.rpm.append(obs, act, reward, next_obs, terminal) diff --git a/examples/MADDPG/simple_model.py b/examples/MADDPG/simple_model.py new file mode 100644 index 0000000000000000000000000000000000000000..00e130dded41486bc541a58bdc5967e4b30774fb --- /dev/null +++ b/examples/MADDPG/simple_model.py @@ -0,0 +1,88 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as fluid +import parl +from parl import layers + + +class MAModel(parl.Model): + def __init__(self, act_dim): + self.actor_model = ActorModel(act_dim) + self.critic_model = CriticModel() + + def policy(self, obs): + return self.actor_model.policy(obs) + + def value(self, obs, act): + return self.critic_model.value(obs, act) + + def get_actor_params(self): + return self.actor_model.parameters() + + def get_critic_params(self): + return self.critic_model.parameters() + + +class ActorModel(parl.Model): + def __init__(self, act_dim): + hid1_size = 64 + hid2_size = 64 + + self.fc1 = layers.fc( + size=hid1_size, + act='relu', + param_attr=fluid.initializer.Normal(loc=0.0, scale=0.1)) + self.fc2 = layers.fc( + size=hid2_size, + act='relu', + param_attr=fluid.initializer.Normal(loc=0.0, scale=0.1)) + self.fc3 = layers.fc( + size=act_dim, + act=None, + param_attr=fluid.initializer.Normal(loc=0.0, scale=0.1)) + + def policy(self, obs): + hid1 = self.fc1(obs) + hid2 = self.fc2(hid1) + means = self.fc3(hid2) + means = means + return means + + +class CriticModel(parl.Model): + def __init__(self): + hid1_size = 64 + hid2_size = 64 + + self.fc1 = layers.fc( + size=hid1_size, + act='relu', + param_attr=fluid.initializer.Normal(loc=0.0, scale=0.1)) + self.fc2 = layers.fc( + size=hid2_size, + act='relu', + param_attr=fluid.initializer.Normal(loc=0.0, scale=0.1)) + self.fc3 = layers.fc( + size=1, + act=None, + param_attr=fluid.initializer.Normal(loc=0.0, scale=0.1)) + + def value(self, obs_n, act_n): + inputs = layers.concat(obs_n + act_n, axis=1) + hid1 = self.fc1(inputs) + hid2 = self.fc2(hid1) + Q = self.fc3(hid2) + Q = layers.squeeze(Q, axes=[1]) + return Q diff --git a/examples/MADDPG/train.py b/examples/MADDPG/train.py new file mode 100644 index 0000000000000000000000000000000000000000..d0e20dcdb4fd35638432fb1666b76b30c2a388d8 --- /dev/null +++ b/examples/MADDPG/train.py @@ -0,0 +1,226 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time +import argparse +import numpy as np +from simple_model import MAModel +from simple_agent import MAAgent +import parl +from parl.env.multiagent_simple_env import MAenv +from parl.utils import logger, tensorboard + + +def run_episode(env, agents): + obs_n = env.reset() + total_reward = 0 + agents_reward = [0 for _ in range(env.n)] + steps = 0 + while True: + steps += 1 + action_n = [agent.predict(obs) for agent, obs in zip(agents, obs_n)] + next_obs_n, reward_n, done_n, _ = env.step(action_n) + done = all(done_n) + terminal = (steps >= args.max_step_per_episode) + + # store experience + for i, agent in enumerate(agents): + agent.add_experience(obs_n[i], action_n[i], reward_n[i], + next_obs_n[i], done_n[i]) + + # compute reward of every agent + obs_n = next_obs_n + for i, reward in enumerate(reward_n): + total_reward += reward + agents_reward[i] += reward + + # check the end of an episode + if done or terminal: + break + + # show animation + if args.show: + time.sleep(0.1) + env.render() + + # show model effect without training + if args.restore and args.show: + continue + + # learn policy + for i, agent in enumerate(agents): + critic_loss = agent.learn(agents) + tensorboard.add_scalar('critic_loss_%d' % i, critic_loss, + agent.global_train_step) + + return total_reward, agents_reward, steps + + +def train_agent(): + env = MAenv(args.env) + logger.info('agent num: {}'.format(env.n)) + logger.info('observation_space: {}'.format(env.observation_space)) + logger.info('action_space: {}'.format(env.action_space)) + logger.info('obs_shape_n: {}'.format(env.obs_shape_n)) + logger.info('act_shape_n: {}'.format(env.act_shape_n)) + for i in range(env.n): + logger.info('agent {} obs_low:{} obs_high:{}'.format( + i, env.observation_space[i].low, env.observation_space[i].high)) + logger.info('agent {} act_n:{}'.format(i, env.act_shape_n[i])) + if ('low' in dir(env.action_space[i])): + logger.info('agent {} act_low:{} act_high:{} act_shape:{}'.format( + i, env.action_space[i].low, env.action_space[i].high, + env.action_space[i].shape)) + logger.info('num_discrete_space:{}'.format( + env.action_space[i].num_discrete_space)) + + from gym import spaces + from multiagent.multi_discrete import MultiDiscrete + for space in env.action_space: + assert (isinstance(space, spaces.Discrete) + or isinstance(space, MultiDiscrete)) + + agents = [] + for i in range(env.n): + model = MAModel(env.act_shape_n[i]) + algorithm = parl.algorithms.MADDPG( + model, + agent_index=i, + act_space=env.action_space, + gamma=args.gamma, + tau=args.tau, + lr=args.lr) + agent = MAAgent( + algorithm, + agent_index=i, + obs_dim_n=env.obs_shape_n, + act_dim_n=env.act_shape_n, + batch_size=args.batch_size, + speedup=(not args.restore)) + agents.append(agent) + total_steps = 0 + total_episodes = 0 + + episode_rewards = [] # sum of rewards for all agents + agent_rewards = [[] for _ in range(env.n)] # individual agent reward + final_ep_rewards = [] # sum of rewards for training curve + final_ep_ag_rewards = [] # agent rewards for training curve + + if args.restore: + # restore modle + for i in 
range(len(agents)): + model_file = args.model_dir + '/agent_' + str(i) + '.ckpt' + if not os.path.exists(model_file): + logger.info('model file {} does not exist'.format(model_file)) + raise Exception + agents[i].restore(model_file) + + t_start = time.time() + logger.info('Starting...') + while total_episodes <= args.max_episodes: + # run an episode + ep_reward, ep_agent_rewards, steps = run_episode(env, agents) + if args.show: + print('episode {}, reward {}, steps {}'.format( + total_episodes, ep_reward, steps)) + + # Record reward + total_steps += steps + total_episodes += 1 + episode_rewards.append(ep_reward) + for i in range(env.n): + agent_rewards[i].append(ep_agent_rewards[i]) + + # Keep track of final episode reward + if total_episodes % args.stat_rate == 0: + mean_episode_reward = np.mean(episode_rewards[-args.stat_rate:]) + final_ep_rewards.append(mean_episode_reward) + for rew in agent_rewards: + final_ep_ag_rewards.append(np.mean(rew[-args.stat_rate:])) + use_time = round(time.time() - t_start, 3) + logger.info( + 'Steps: {}, Episodes: {}, Mean episode reward: {}, Time: {}'. + format(total_steps, total_episodes, mean_episode_reward, + use_time)) + t_start = time.time() + tensorboard.add_scalar('mean_episode_reward/episode', + mean_episode_reward, total_episodes) + tensorboard.add_scalar('mean_episode_reward/steps', + mean_episode_reward, total_steps) + tensorboard.add_scalar('use_time/1000episode', use_time, + total_episodes) + + # save model + if not args.restore: + os.makedirs(os.path.dirname(args.model_dir), exist_ok=True) + for i in range(len(agents)): + model_name = '/agent_' + str(i) + '.ckpt' + agents[i].save(args.model_dir + model_name) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + # Environment + parser.add_argument( + '--env', + type=str, + default='simple_speaker_listener', + help='scenario of MultiAgentEnv') + parser.add_argument( + '--max_step_per_episode', + type=int, + default=25, + help='maximum steps per episode') + parser.add_argument( + '--max_episodes', + type=int, + default=25000, + help='stop condition: number of episodes') + parser.add_argument( + '--stat_rate', + type=int, + default=1000, + help='interval (in episodes) for saving the model and logging the mean reward') + # Core training parameters + parser.add_argument( + '--lr', + type=float, + default=1e-3, + help='learning rate for Adam optimizer') + parser.add_argument( + '--gamma', type=float, default=0.95, help='discount factor') + parser.add_argument( + '--batch_size', + type=int, + default=1024, + help='number of transitions sampled from the replay memory per optimization step') + parser.add_argument( + '--tau', type=float, default=0.01, help='soft update coefficient of the target networks') + # auto save model, optional restore model + parser.add_argument( + '--show', action='store_true', default=False, help='display or not') + parser.add_argument( + '--restore', + action='store_true', + default=False, + help='restore a saved model or not; model_dir must exist') + parser.add_argument( + '--model_dir', + type=str, + default='./model', + help='directory for saving model') + + args = parser.parse_args() + + train_agent() diff --git a/parl/algorithms/fluid/__init__.py b/parl/algorithms/fluid/__init__.py index 3f005ac623ed5cf7fc95b2f9242c962dbfbf3302..dac58ab6bf7815e95ecc21269b67116d567fa576 100644 --- a/parl/algorithms/fluid/__init__.py +++ b/parl/algorithms/fluid/__init__.py @@ -14,6 +14,7 @@ from parl.algorithms.fluid.a3c import * from parl.algorithms.fluid.ddpg import * +from parl.algorithms.fluid.maddpg import * from parl.algorithms.fluid.dqn import * from
parl.algorithms.fluid.ddqn import * from parl.algorithms.fluid.policy_gradient import * diff --git a/parl/algorithms/fluid/maddpg.py b/parl/algorithms/fluid/maddpg.py new file mode 100644 index 0000000000000000000000000000000000000000..ccb166090c4d727195a775f2ac1cf956cce13922 --- /dev/null +++ b/parl/algorithms/fluid/maddpg.py @@ -0,0 +1,159 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +warnings.simplefilter('default') + +from parl.core.fluid import layers +from copy import deepcopy +from paddle import fluid +from parl.core.fluid.algorithm import Algorithm + +__all__ = ['MADDPG'] + +from gym import spaces +from parl.core.fluid.policy_distribution import SoftCategoricalDistribution +from parl.core.fluid.policy_distribution import SoftMultiCategoricalDistribution + + +def SoftPDistribution(logits, act_space): + if (isinstance(act_space, spaces.Discrete)): + return SoftCategoricalDistribution(logits) + # is instance of multiagent.multi_discrete.MultiDiscrete + elif (hasattr(act_space, 'num_discrete_space')): + return SoftMultiCategoricalDistribution(logits, act_space.low, + act_space.high) + else: + raise NotImplementedError + + +class MADDPG(Algorithm): + def __init__(self, + model, + agent_index=None, + act_space=None, + gamma=None, + tau=None, + lr=None): + """ MADDPG algorithm + + Args: + model (parl.Model): forward network of actor and critic. + The function get_actor_params() of model should be implemented. + agent_index: index of agent, in multiagent env + act_space: action_space, gym space + gamma (float): discounted factor for reward computation. 
+ tau (float): decay coefficient when updating the weights of self.target_model with self.model + lr (float): learning rate + """ + + assert isinstance(agent_index, int) + assert isinstance(act_space, list) + assert isinstance(gamma, float) + assert isinstance(tau, float) + assert isinstance(lr, float) + self.agent_index = agent_index + self.act_space = act_space + self.gamma = gamma + self.tau = tau + self.lr = lr + + self.model = model + self.target_model = deepcopy(model) + + def predict(self, obs): + """ input: + obs: observation, shape([B] + shape of obs_n[agent_index]) + output: + act: action, shape([B] + shape of act_n[agent_index]) + """ + this_policy = self.model.policy(obs) + this_action = SoftPDistribution( + logits=this_policy, + act_space=self.act_space[self.agent_index]).sample() + return this_action + + def predict_next(self, obs): + """ input: observation, shape([B] + shape of obs_n[agent_index]) + output: action, shape([B] + shape of act_n[agent_index]) + """ + next_policy = self.target_model.policy(obs) + next_action = SoftPDistribution( + logits=next_policy, + act_space=self.act_space[self.agent_index]).sample() + return next_action + + def Q(self, obs_n, act_n): + """ input: + obs_n: all agents' observation, shape([B] + shape of obs_n) + output: + act_n: all agents' action, shape([B] + shape of act_n) + """ + return self.model.value(obs_n, act_n) + + def Q_next(self, obs_n, act_n): + """ input: + obs_n: all agents' observation, shape([B] + shape of obs_n) + output: + act_n: all agents' action, shape([B] + shape of act_n) + """ + return self.target_model.value(obs_n, act_n) + + def learn(self, obs_n, act_n, target_q): + """ update actor and critic model with MADDPG algorithm + """ + actor_cost = self._actor_learn(obs_n, act_n) + critic_cost = self._critic_learn(obs_n, act_n, target_q) + return critic_cost + + def _actor_learn(self, obs_n, act_n): + i = self.agent_index + this_policy = self.model.policy(obs_n[i]) + sample_this_action = SoftPDistribution( + logits=this_policy, + act_space=self.act_space[self.agent_index]).sample() + + action_input_n = act_n + [] + action_input_n[i] = sample_this_action + eval_q = self.Q(obs_n, action_input_n) + act_cost = layers.reduce_mean(-1.0 * eval_q) + + act_reg = layers.reduce_mean(layers.square(this_policy)) + + cost = act_cost + act_reg * 1e-3 + + fluid.clip.set_gradient_clip( + clip=fluid.clip.GradientClipByNorm(clip_norm=0.5), + param_list=self.model.get_actor_params()) + + optimizer = fluid.optimizer.AdamOptimizer(self.lr) + optimizer.minimize(cost, parameter_list=self.model.get_actor_params()) + return cost + + def _critic_learn(self, obs_n, act_n, target_q): + pred_q = self.Q(obs_n, act_n) + cost = layers.reduce_mean(layers.square_error_cost(pred_q, target_q)) + + fluid.clip.set_gradient_clip( + clip=fluid.clip.GradientClipByNorm(clip_norm=0.5), + param_list=self.model.get_critic_params()) + + optimizer = fluid.optimizer.AdamOptimizer(self.lr) + optimizer.minimize(cost, parameter_list=self.model.get_critic_params()) + return cost + + def sync_target(self, decay=None): + if decay is None: + decay = 1.0 - self.tau + self.model.sync_weights_to(self.target_model, decay=decay) diff --git a/parl/core/fluid/policy_distribution.py b/parl/core/fluid/policy_distribution.py index 4d876062a3e19ab0101a83b3d325de6200932e67..73bc93b78ba25aa2f0c197fba868e2269afe4fa1 100644 --- a/parl/core/fluid/policy_distribution.py +++ b/parl/core/fluid/policy_distribution.py @@ -14,7 +14,10 @@ from parl.core.fluid import layers -__all__ = 
['PolicyDistribution', 'CategoricalDistribution'] +__all__ = [ + 'PolicyDistribution', 'CategoricalDistribution', + 'SoftCategoricalDistribution', 'SoftMultiCategoricalDistribution' +] class PolicyDistribution(object): @@ -79,7 +82,6 @@ class CategoricalDistribution(PolicyDistribution): Returns: actions_log_prob: A float32 tensor with shape [BATCH_SIZE] """ - assert len(actions.shape) == 1 logits = self.logits - layers.reduce_max(self.logits, dim=1) @@ -122,3 +124,88 @@ class CategoricalDistribution(PolicyDistribution): (logits - layers.log(z) - other_logits + layers.log(other_z)), dim=1) return kl + + +class SoftCategoricalDistribution(CategoricalDistribution): + """Categorical distribution with noise for discrete action spaces""" + + def __init__(self, logits): + """ + Args: + logits: A float32 tensor with shape [BATCH_SIZE, NUM_ACTIONS] of unnormalized policy logits + """ + self.logits = logits + super(SoftCategoricalDistribution, self).__init__(logits) + + def sample(self): + """ + Returns: + sample_action: An int64 tensor with shape [BATCH_SIZE, NUM_ACTIOINS] of sample action, + with noise to keep the target close to the original action. + """ + eps = 1e-4 + logits_shape = layers.cast(layers.shape(self.logits), dtype='int64') + uniform = layers.uniform_random(logits_shape, min=eps, max=1.0 - eps) + soft_uniform = layers.log(-1.0 * layers.log(uniform)) + return layers.softmax(self.logits - soft_uniform, axis=-1) + + +class SoftMultiCategoricalDistribution(PolicyDistribution): + """Categorical distribution with noise for MultiDiscrete action spaces.""" + + def __init__(self, logits, low, high): + """ + Args: + logits: A float32 tensor with shape [BATCH_SIZE, LEN_MultiDiscrete, NUM_ACTIONS] of unnormalized policy logits + low: lower bounds of sample action + high: Upper bounds of action + """ + self.logits = logits + self.low = low + self.high = high + self.categoricals = list( + map( + SoftCategoricalDistribution, + layers.split( + input=logits, + num_or_sections=list(high - low + 1), + dim=len(logits.shape) - 1))) + + def sample(self): + """ + Returns: + sample_action: An int64 tensor with shape [BATCH_SIZE, NUM_ACTIOINS] of sample action, + with noise to keep the target close to the original action. + """ + cate_list = [] + for i in range(len(self.categoricals)): + cate_list.append(self.low[i] + self.categoricals[i].sample()) + return layers.concat(cate_list, axis=-1) + + def layers_add_n(self, input_list): + """ + Adds all input tensors element-wise, can replace tf.add_n + """ + assert len(input_list) >= 1 + res = input_list[0] + for i in range(1, len(input_list)): + res = layers.elementwise_add(res, input_list[i]) + return res + + def entropy(self): + """ + Returns: + entropy: A float32 tensor with shape [BATCH_SIZE] of entropy of self policy distribution. + """ + return self.layers_add_n([p.entropy() for p in self.categoricals]) + + def kl(self, other): + """ + Args: + other: object of SoftCategoricalDistribution + + Returns: + kl: A float32 tensor with shape [BATCH_SIZE] + """ + return self.layers_add_n( + [p.kl(q) for p, q in zip(self.categoricals, other.categoricals)]) diff --git a/parl/env/multiagent_simple_env.py b/parl/env/multiagent_simple_env.py new file mode 100644 index 0000000000000000000000000000000000000000..222c4f7a861d5e1516619783fd94f8a1866a64c2 --- /dev/null +++ b/parl/env/multiagent_simple_env.py @@ -0,0 +1,60 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from gym import spaces +from multiagent.multi_discrete import MultiDiscrete +from multiagent.environment import MultiAgentEnv +import multiagent.scenarios as scenarios + + +class MAenv(MultiAgentEnv): + """ multi-agent environment wrapper for MADDPG + """ + + def __init__(self, scenario_name): + # load scenario from script + scenario = scenarios.load(scenario_name + ".py").Scenario() + # create world + world = scenario.make_world() + # initialize the multi-agent environment + super().__init__(world, scenario.reset_world, scenario.reward, + scenario.observation) + self.obs_shape_n = [ + self.get_shape(self.observation_space[i]) for i in range(self.n) + ] + self.act_shape_n = [ + self.get_shape(self.action_space[i]) for i in range(self.n) + ] + + def get_shape(self, input_space): + """ + Args: + input_space: environment space + + Returns: + space shape + """ + if (isinstance(input_space, spaces.Box)): + if (len(input_space.shape) == 1): + return input_space.shape[0] + else: + return input_space.shape + elif (isinstance(input_space, spaces.Discrete)): + return input_space.n + elif (isinstance(input_space, MultiDiscrete)): + return sum(input_space.high - input_space.low + 1) + else: + print('[Error] shape is {}, not Box or Discrete or MultiDiscrete'. + format(input_space.shape)) + raise NotImplementedError diff --git a/parl/utils/replay_memory.py b/parl/utils/replay_memory.py index 051c6fb30f94b933bdafcfad16d284a1a1610432..a1a3b4a8ed2cb327b7a4ec34770ea09ccf917ad8 100755 --- a/parl/utils/replay_memory.py +++ b/parl/utils/replay_memory.py @@ -34,10 +34,8 @@ class ReplayMemory(object): self._curr_pos = 0 def sample_batch(self, batch_size): - # index mapping to avoid sampling saving example batch_idx = np.random.randint( self._curr_size - 300 - 1, size=batch_size) - batch_idx = (self._curr_pos + 300 + batch_idx) % self._curr_size obs = self.obs[batch_idx] reward = self.reward[batch_idx] @@ -46,6 +44,19 @@ class ReplayMemory(object): terminal = self.terminal[batch_idx] return obs, action, reward, next_obs, terminal + def make_index(self, batch_size): + batch_idx = np.random.randint( + self._curr_size - 300 - 1, size=batch_size) + return batch_idx + + def sample_batch_by_index(self, batch_idx): + obs = self.obs[batch_idx] + reward = self.reward[batch_idx] + action = self.action[batch_idx] + next_obs = self.next_obs[batch_idx] + terminal = self.terminal[batch_idx] + return obs, action, reward, next_obs, terminal + def append(self, obs, act, reward, next_obs, terminal): if self._curr_size < self.max_size: self._curr_size += 1
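A minimal usage sketch of how the pieces introduced above fit together — the `MAenv` wrapper, the per-agent `MAModel`/`MADDPG`/`MAAgent` stack, and the shared-index replay sampling — following the flow of `examples/MADDPG/train.py` in this diff. It assumes it is run from `examples/MADDPG/` with `multiagent-particle-envs`, `paddlepaddle`, and `parl` installed; the hyperparameter values are simply the `train.py` defaults.

```python
import parl
from parl.env.multiagent_simple_env import MAenv
from simple_model import MAModel    # examples/MADDPG/simple_model.py
from simple_agent import MAAgent    # examples/MADDPG/simple_agent.py

# build the multi-agent particle environment
env = MAenv('simple_speaker_listener')

# one model/algorithm/agent triple per agent in the environment
agents = []
for i in range(env.n):
    model = MAModel(env.act_shape_n[i])
    algorithm = parl.algorithms.MADDPG(
        model,
        agent_index=i,
        act_space=env.action_space,
        gamma=0.95,
        tau=0.01,
        lr=1e-3)
    agents.append(
        MAAgent(
            algorithm,
            agent_index=i,
            obs_dim_n=env.obs_shape_n,
            act_dim_n=env.act_shape_n,
            batch_size=1024))

# one environment step: act, store, learn
obs_n = env.reset()
action_n = [agent.predict(obs) for agent, obs in zip(agents, obs_n)]
next_obs_n, reward_n, done_n, _ = env.step(action_n)
for i, agent in enumerate(agents):
    agent.add_experience(obs_n[i], action_n[i], reward_n[i],
                         next_obs_n[i], done_n[i])
# each agent draws a batch with a shared index from every agent's replay
# memory (ReplayMemory.make_index / sample_batch_by_index) to build the
# centralized critic input; learn() returns 0.0 until the buffers warm up
critic_losses = [agent.learn(agents) for agent in agents]
```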