diff --git a/README.cn.md b/README.cn.md
index 09f1df56a90bcc36dd0971038dfd15de501034ec..dd154a4e634374d5765727e86607e5c178719056 100644
--- a/README.cn.md
+++ b/README.cn.md
@@ -78,6 +78,7 @@ pip install parl
- [A2C](examples/A2C/)
- [TD3](examples/TD3/)
- [SAC](examples/SAC/)
+- [MADDPG](examples/MADDPG/)
- [冠军解决方案:NIPS2018强化学习假肢挑战赛](examples/NeurIPS2018-AI-for-Prosthetics-Challenge/)
- [冠军解决方案:NIPS2019强化学习仿生人控制赛事](examples/NeurIPS2019-Learn-to-Move-Challenge/)
diff --git a/README.md b/README.md
index 5245c349951b28a1a6c74e0a15cfc89d22edfaf3..7bec38c8a253b4cc53fb4cc0ac0dee3221181b05 100644
--- a/README.md
+++ b/README.md
@@ -81,6 +81,7 @@ pip install parl
- [A2C](examples/A2C/)
- [TD3](examples/TD3/)
- [SAC](examples/SAC/)
+- [MADDPG](examples/MADDPG/)
- [Winning Solution for NIPS2018: AI for Prosthetics Challenge](examples/NeurIPS2018-AI-for-Prosthetics-Challenge/)
- [Winning Solution for NIPS2019: Learn to Move Challenge](examples/NeurIPS2019-Learn-to-Move-Challenge/)
diff --git a/examples/MADDPG/.benchmark/MADDPG_simple.gif b/examples/MADDPG/.benchmark/MADDPG_simple.gif
new file mode 100644
index 0000000000000000000000000000000000000000..7c2b95debb594975e402e9e1ce3e8eb710aa3ee3
Binary files /dev/null and b/examples/MADDPG/.benchmark/MADDPG_simple.gif differ
diff --git a/examples/MADDPG/.benchmark/MADDPG_simple.png b/examples/MADDPG/.benchmark/MADDPG_simple.png
new file mode 100644
index 0000000000000000000000000000000000000000..fafa5c17ae1aca610c17fe1bcd90d3af5d6ef7ac
Binary files /dev/null and b/examples/MADDPG/.benchmark/MADDPG_simple.png differ
diff --git a/examples/MADDPG/.benchmark/MADDPG_simple_adversary.gif b/examples/MADDPG/.benchmark/MADDPG_simple_adversary.gif
new file mode 100644
index 0000000000000000000000000000000000000000..f5038ac1078954550cc1fd2b38fb4d733aa2dc96
Binary files /dev/null and b/examples/MADDPG/.benchmark/MADDPG_simple_adversary.gif differ
diff --git a/examples/MADDPG/.benchmark/MADDPG_simple_adversary.png b/examples/MADDPG/.benchmark/MADDPG_simple_adversary.png
new file mode 100644
index 0000000000000000000000000000000000000000..3e3e2d2c54ce42b12cbccf8e884f5b5354a6bd27
Binary files /dev/null and b/examples/MADDPG/.benchmark/MADDPG_simple_adversary.png differ
diff --git a/examples/MADDPG/.benchmark/MADDPG_simple_crypto.png b/examples/MADDPG/.benchmark/MADDPG_simple_crypto.png
new file mode 100644
index 0000000000000000000000000000000000000000..f13acbc996cd1e5931f988fb4ebae8a772295426
Binary files /dev/null and b/examples/MADDPG/.benchmark/MADDPG_simple_crypto.png differ
diff --git a/examples/MADDPG/.benchmark/MADDPG_simple_push.gif b/examples/MADDPG/.benchmark/MADDPG_simple_push.gif
new file mode 100644
index 0000000000000000000000000000000000000000..faad8196614c31f31b5aded2ab379357a5c3a5c3
Binary files /dev/null and b/examples/MADDPG/.benchmark/MADDPG_simple_push.gif differ
diff --git a/examples/MADDPG/.benchmark/MADDPG_simple_push.png b/examples/MADDPG/.benchmark/MADDPG_simple_push.png
new file mode 100644
index 0000000000000000000000000000000000000000..a368f63f846f47cd2520e1af65db10ba1c53dc5c
Binary files /dev/null and b/examples/MADDPG/.benchmark/MADDPG_simple_push.png differ
diff --git a/examples/MADDPG/.benchmark/MADDPG_simple_reference.gif b/examples/MADDPG/.benchmark/MADDPG_simple_reference.gif
new file mode 100644
index 0000000000000000000000000000000000000000..e7cf37f9aa664e8347da9d8cb4c376e3f2798607
Binary files /dev/null and b/examples/MADDPG/.benchmark/MADDPG_simple_reference.gif differ
diff --git a/examples/MADDPG/.benchmark/MADDPG_simple_reference.png b/examples/MADDPG/.benchmark/MADDPG_simple_reference.png
new file mode 100644
index 0000000000000000000000000000000000000000..ea94f14efb04227b3571c983505e480778b5a463
Binary files /dev/null and b/examples/MADDPG/.benchmark/MADDPG_simple_reference.png differ
diff --git a/examples/MADDPG/.benchmark/MADDPG_simple_speaker_listener.gif b/examples/MADDPG/.benchmark/MADDPG_simple_speaker_listener.gif
new file mode 100644
index 0000000000000000000000000000000000000000..6ef772bb966c247eff1d30da70dd6e969b49005a
Binary files /dev/null and b/examples/MADDPG/.benchmark/MADDPG_simple_speaker_listener.gif differ
diff --git a/examples/MADDPG/.benchmark/MADDPG_simple_speaker_listener.png b/examples/MADDPG/.benchmark/MADDPG_simple_speaker_listener.png
new file mode 100644
index 0000000000000000000000000000000000000000..ad4e0fc4d0bb446f26a445ce58e4f7d47def4eb5
Binary files /dev/null and b/examples/MADDPG/.benchmark/MADDPG_simple_speaker_listener.png differ
diff --git a/examples/MADDPG/.benchmark/MADDPG_simple_spread.gif b/examples/MADDPG/.benchmark/MADDPG_simple_spread.gif
new file mode 100644
index 0000000000000000000000000000000000000000..013f95d1d18e8dd90c7adb526aeeed475841cb9c
Binary files /dev/null and b/examples/MADDPG/.benchmark/MADDPG_simple_spread.gif differ
diff --git a/examples/MADDPG/.benchmark/MADDPG_simple_spread.png b/examples/MADDPG/.benchmark/MADDPG_simple_spread.png
new file mode 100644
index 0000000000000000000000000000000000000000..4ae93288ff4dd5fda1e35006b50102dda5695250
Binary files /dev/null and b/examples/MADDPG/.benchmark/MADDPG_simple_spread.png differ
diff --git a/examples/MADDPG/.benchmark/MADDPG_simple_tag.gif b/examples/MADDPG/.benchmark/MADDPG_simple_tag.gif
new file mode 100644
index 0000000000000000000000000000000000000000..703ca69f5e529f70be6fc45fef62871f39bff771
Binary files /dev/null and b/examples/MADDPG/.benchmark/MADDPG_simple_tag.gif differ
diff --git a/examples/MADDPG/.benchmark/MADDPG_simple_tag.png b/examples/MADDPG/.benchmark/MADDPG_simple_tag.png
new file mode 100644
index 0000000000000000000000000000000000000000..4d21b27d350bc6ebd4d427e5fa460eb8863e95d9
Binary files /dev/null and b/examples/MADDPG/.benchmark/MADDPG_simple_tag.png differ
diff --git a/examples/MADDPG/.benchmark/MADDPG_simple_world_comm.gif b/examples/MADDPG/.benchmark/MADDPG_simple_world_comm.gif
new file mode 100644
index 0000000000000000000000000000000000000000..1d6c4df5e76d4bf5dd43e501bea01eeccdaadd7e
Binary files /dev/null and b/examples/MADDPG/.benchmark/MADDPG_simple_world_comm.gif differ
diff --git a/examples/MADDPG/.benchmark/MADDPG_simple_world_comm.png b/examples/MADDPG/.benchmark/MADDPG_simple_world_comm.png
new file mode 100644
index 0000000000000000000000000000000000000000..78f77fd7bc2f8bf4986aa20c241d8274c6fb9cde
Binary files /dev/null and b/examples/MADDPG/.benchmark/MADDPG_simple_world_comm.png differ
diff --git a/examples/MADDPG/README.md b/examples/MADDPG/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..d37b553de416fe9176119f654508774d8bb1278c
--- /dev/null
+++ b/examples/MADDPG/README.md
@@ -0,0 +1,112 @@
+## Reproduce MADDPG with PARL
+Based on PARL, we have reproduced the MADDPG algorithm for multi-agent deep reinforcement learning.
+
++ paper:
+[Multi-Agent Actor-Critic for Mixed Cooperative-Competitive Environments](https://arxiv.org/abs/1706.02275)
+
+### Multi-agent particle environment introduction
+A simple multi-agent particle world based on gym. See [here](https://github.com/openai/multiagent-particle-envs) for installation instructions and more details about the environment.
+
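+The snippet below is a minimal sketch of how this example builds a scenario through the `MAenv` wrapper added in `parl/env/multiagent_simple_env.py` (the same wrapper used by `train.py`); it only prints the per-agent shapes the wrapper exposes.
+
+```
+from parl.env.multiagent_simple_env import MAenv
+
+# create the simple_speaker_listener scenario from multiagent-particle-envs
+env = MAenv('simple_speaker_listener')
+print(env.n)            # number of agents in the scenario
+print(env.obs_shape_n)  # flattened observation dimension of each agent
+print(env.act_shape_n)  # flattened action dimension of each agent
+```
+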
+### Benchmark results
+Mean episode reward (averaged over every 1000 episodes) during training (25000 episodes in total).
+
+| simple | simple_adversary | simple_push | simple_reference |
+| :---: | :---: | :---: | :---: |
+| ![simple](.benchmark/MADDPG_simple.png) | ![simple_adversary](.benchmark/MADDPG_simple_adversary.png) | ![simple_push](.benchmark/MADDPG_simple_push.png) | ![simple_reference](.benchmark/MADDPG_simple_reference.png) |
+| simple_speaker_listener | simple_spread | simple_tag | simple_world_comm |
+| ![simple_speaker_listener](.benchmark/MADDPG_simple_speaker_listener.png) | ![simple_spread](.benchmark/MADDPG_simple_spread.png) | ![simple_tag](.benchmark/MADDPG_simple_tag.png) | ![simple_world_comm](.benchmark/MADDPG_simple_world_comm.png) |
+
+### Experiment results
+Behavior of the trained agents after 25000 episodes.
+
+| simple | simple_adversary | simple_push | simple_reference |
+| :---: | :---: | :---: | :---: |
+| ![simple](.benchmark/MADDPG_simple.gif) | ![simple_adversary](.benchmark/MADDPG_simple_adversary.gif) | ![simple_push](.benchmark/MADDPG_simple_push.gif) | ![simple_reference](.benchmark/MADDPG_simple_reference.gif) |
+| simple_speaker_listener | simple_spread | simple_tag | simple_world_comm |
+| ![simple_speaker_listener](.benchmark/MADDPG_simple_speaker_listener.gif) | ![simple_spread](.benchmark/MADDPG_simple_spread.gif) | ![simple_tag](.benchmark/MADDPG_simple_tag.gif) | ![simple_world_comm](.benchmark/MADDPG_simple_world_comm.gif) |
+
+## How to use
+### Dependencies:
++ python3.5+
++ [paddlepaddle>=1.5.1](https://github.com/PaddlePaddle/Paddle)
++ [parl](https://github.com/PaddlePaddle/PARL)
++ [multiagent-particle-envs](https://github.com/openai/multiagent-particle-envs)
++ gym
+
+### Start Training:
+```
+# To train an agent for the simple_speaker_listener scenario
+python train.py
+
+# To train on another scenario (the model is saved automatically every 1000 episodes)
+# python train.py --env [ENV_NAME]
+
+# To visualize the learned behavior after training
+# python train.py --env [ENV_NAME] --show --restore
+```
diff --git a/examples/MADDPG/simple_agent.py b/examples/MADDPG/simple_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe0e47dd5b085643a3c9ae250b6826a783c71771
--- /dev/null
+++ b/examples/MADDPG/simple_agent.py
@@ -0,0 +1,180 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import parl
+from parl import layers
+from paddle import fluid
+from parl.utils import ReplayMemory
+
+
+class MAAgent(parl.Agent):
+ def __init__(self,
+ algorithm,
+ agent_index=None,
+ obs_dim_n=None,
+ act_dim_n=None,
+ batch_size=None,
+ speedup=False):
+ assert isinstance(agent_index, int)
+ assert isinstance(obs_dim_n, list)
+ assert isinstance(act_dim_n, list)
+ assert isinstance(batch_size, int)
+ assert isinstance(speedup, bool)
+ self.agent_index = agent_index
+ self.obs_dim_n = obs_dim_n
+ self.act_dim_n = act_dim_n
+ self.batch_size = batch_size
+ self.speedup = speedup
+ self.n = len(act_dim_n)
+
+ self.memory_size = int(1e6)
+ self.min_memory_size = batch_size * 25 # batch_size * args.max_episode_len
+ self.rpm = ReplayMemory(
+ max_size=self.memory_size,
+ obs_dim=self.obs_dim_n[agent_index],
+ act_dim=self.act_dim_n[agent_index])
+ self.global_train_step = 0
+
+ super(MAAgent, self).__init__(algorithm)
+
+ # Note: fully synchronize the target model at the beginning of training.
+ self.alg.sync_target(decay=0)
+
+ def build_program(self):
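+ # Four fluid programs: action prediction, actor/critic learning,
+ # target-critic Q evaluation, and target-actor next-action prediction.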
+ self.pred_program = fluid.Program()
+ self.learn_program = fluid.Program()
+ self.next_q_program = fluid.Program()
+ self.next_a_program = fluid.Program()
+
+ with fluid.program_guard(self.pred_program):
+ obs = layers.data(
+ name='obs',
+ shape=[self.obs_dim_n[self.agent_index]],
+ dtype='float32')
+ self.pred_act = self.alg.predict(obs)
+
+ with fluid.program_guard(self.learn_program):
+ obs_n = [
+ layers.data(
+ name='obs' + str(i),
+ shape=[self.obs_dim_n[i]],
+ dtype='float32') for i in range(self.n)
+ ]
+ act_n = [
+ layers.data(
+ name='act' + str(i),
+ shape=[self.act_dim_n[i]],
+ dtype='float32') for i in range(self.n)
+ ]
+ target_q = layers.data(name='target_q', shape=[], dtype='float32')
+ self.critic_cost = self.alg.learn(obs_n, act_n, target_q)
+
+ with fluid.program_guard(self.next_q_program):
+ obs_n = [
+ layers.data(
+ name='obs' + str(i),
+ shape=[self.obs_dim_n[i]],
+ dtype='float32') for i in range(self.n)
+ ]
+ act_n = [
+ layers.data(
+ name='act' + str(i),
+ shape=[self.act_dim_n[i]],
+ dtype='float32') for i in range(self.n)
+ ]
+ self.next_Q = self.alg.Q_next(obs_n, act_n)
+
+ with fluid.program_guard(self.next_a_program):
+ obs = layers.data(
+ name='obs',
+ shape=[self.obs_dim_n[self.agent_index]],
+ dtype='float32')
+ self.next_action = self.alg.predict_next(obs)
+
+ if self.speedup:
+ self.pred_program = parl.compile(self.pred_program)
+ self.learn_program = parl.compile(self.learn_program,
+ self.critic_cost)
+ self.next_q_program = parl.compile(self.next_q_program)
+ self.next_a_program = parl.compile(self.next_a_program)
+
+ def predict(self, obs):
+ obs = np.expand_dims(obs, axis=0)
+ obs = obs.astype('float32')
+ act = self.fluid_executor.run(
+ self.pred_program, feed={'obs': obs},
+ fetch_list=[self.pred_act])[0]
+ return act[0]
+
+ def learn(self, agents):
+ self.global_train_step += 1
+
+ # only update parameters every 100 global steps
+ if self.global_train_step % 100 != 0:
+ return 0.0
+
+ if self.rpm.size() <= self.min_memory_size:
+ return 0.0
+
+ batch_obs_n = []
+ batch_act_n = []
+ batch_obs_new_n = []
+
+ rpm_sample_index = self.rpm.make_index(self.batch_size)
+ for i in range(self.n):
+ batch_obs, batch_act, _, batch_obs_new, _ \
+ = agents[i].rpm.sample_batch_by_index(rpm_sample_index)
+ batch_obs_n.append(batch_obs)
+ batch_act_n.append(batch_act)
+ batch_obs_new_n.append(batch_obs_new)
+ _, _, batch_rew, _, batch_isOver \
+ = self.rpm.sample_batch_by_index(rpm_sample_index)
+
+ # compute target q
+ target_q = 0.0
+ target_act_next_n = []
+ for i in range(self.n):
+ feed = {'obs': batch_obs_new_n[i]}
+ target_act_next = agents[i].fluid_executor.run(
+ agents[i].next_a_program,
+ feed=feed,
+ fetch_list=[agents[i].next_action])[0]
+ target_act_next_n.append(target_act_next)
+ feed_obs = {'obs' + str(i): batch_obs_new_n[i] for i in range(self.n)}
+ feed_act = {
+ 'act' + str(i): target_act_next_n[i]
+ for i in range(self.n)
+ }
+ feed = feed_obs.copy()
+ feed.update(feed_act) # merge the two feed dicts
+ target_q_next = self.fluid_executor.run(
+ self.next_q_program, feed=feed, fetch_list=[self.next_Q])[0]
+ target_q += (
+ batch_rew + self.alg.gamma * (1.0 - batch_isOver) * target_q_next)
+
+ feed_obs = {'obs' + str(i): batch_obs_n[i] for i in range(self.n)}
+ feed_act = {'act' + str(i): batch_act_n[i] for i in range(self.n)}
+ target_q = target_q.astype('float32')
+ feed = feed_obs.copy()
+ feed.update(feed_act)
+ feed['target_q'] = target_q
+ critic_cost = self.fluid_executor.run(
+ self.learn_program, feed=feed, fetch_list=[self.critic_cost])[0]
+
+ self.alg.sync_target()
+ return critic_cost
+
+ def add_experience(self, obs, act, reward, next_obs, terminal):
+ self.rpm.append(obs, act, reward, next_obs, terminal)
diff --git a/examples/MADDPG/simple_model.py b/examples/MADDPG/simple_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..00e130dded41486bc541a58bdc5967e4b30774fb
--- /dev/null
+++ b/examples/MADDPG/simple_model.py
@@ -0,0 +1,88 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle.fluid as fluid
+import parl
+from parl import layers
+
+
+class MAModel(parl.Model):
+ def __init__(self, act_dim):
+ self.actor_model = ActorModel(act_dim)
+ self.critic_model = CriticModel()
+
+ def policy(self, obs):
+ return self.actor_model.policy(obs)
+
+ def value(self, obs, act):
+ return self.critic_model.value(obs, act)
+
+ def get_actor_params(self):
+ return self.actor_model.parameters()
+
+ def get_critic_params(self):
+ return self.critic_model.parameters()
+
+
+class ActorModel(parl.Model):
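+ """Per-agent actor: a two-hidden-layer MLP mapping this agent's observation to action logits."""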
+ def __init__(self, act_dim):
+ hid1_size = 64
+ hid2_size = 64
+
+ self.fc1 = layers.fc(
+ size=hid1_size,
+ act='relu',
+ param_attr=fluid.initializer.Normal(loc=0.0, scale=0.1))
+ self.fc2 = layers.fc(
+ size=hid2_size,
+ act='relu',
+ param_attr=fluid.initializer.Normal(loc=0.0, scale=0.1))
+ self.fc3 = layers.fc(
+ size=act_dim,
+ act=None,
+ param_attr=fluid.initializer.Normal(loc=0.0, scale=0.1))
+
+ def policy(self, obs):
+ hid1 = self.fc1(obs)
+ hid2 = self.fc2(hid1)
+ means = self.fc3(hid2)
+ return means
+
+
+class CriticModel(parl.Model):
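+ """Centralized critic: a two-hidden-layer MLP over the concatenation of all agents' observations and actions."""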
+ def __init__(self):
+ hid1_size = 64
+ hid2_size = 64
+
+ self.fc1 = layers.fc(
+ size=hid1_size,
+ act='relu',
+ param_attr=fluid.initializer.Normal(loc=0.0, scale=0.1))
+ self.fc2 = layers.fc(
+ size=hid2_size,
+ act='relu',
+ param_attr=fluid.initializer.Normal(loc=0.0, scale=0.1))
+ self.fc3 = layers.fc(
+ size=1,
+ act=None,
+ param_attr=fluid.initializer.Normal(loc=0.0, scale=0.1))
+
+ def value(self, obs_n, act_n):
+ inputs = layers.concat(obs_n + act_n, axis=1)
+ hid1 = self.fc1(inputs)
+ hid2 = self.fc2(hid1)
+ Q = self.fc3(hid2)
+ Q = layers.squeeze(Q, axes=[1])
+ return Q
diff --git a/examples/MADDPG/train.py b/examples/MADDPG/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0e20dcdb4fd35638432fb1666b76b30c2a388d8
--- /dev/null
+++ b/examples/MADDPG/train.py
@@ -0,0 +1,226 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import time
+import argparse
+import numpy as np
+from simple_model import MAModel
+from simple_agent import MAAgent
+import parl
+from parl.env.multiagent_simple_env import MAenv
+from parl.utils import logger, tensorboard
+
+
+def run_episode(env, agents):
+ obs_n = env.reset()
+ total_reward = 0
+ agents_reward = [0 for _ in range(env.n)]
+ steps = 0
+ while True:
+ steps += 1
+ action_n = [agent.predict(obs) for agent, obs in zip(agents, obs_n)]
+ next_obs_n, reward_n, done_n, _ = env.step(action_n)
+ done = all(done_n)
+ terminal = (steps >= args.max_step_per_episode)
+
+ # store experience
+ for i, agent in enumerate(agents):
+ agent.add_experience(obs_n[i], action_n[i], reward_n[i],
+ next_obs_n[i], done_n[i])
+
+ # compute reward of every agent
+ obs_n = next_obs_n
+ for i, reward in enumerate(reward_n):
+ total_reward += reward
+ agents_reward[i] += reward
+
+ # check the end of an episode
+ if done or terminal:
+ break
+
+ # show animation
+ if args.show:
+ time.sleep(0.1)
+ env.render()
+
+ # when restoring a model just for display, skip training
+ if args.restore and args.show:
+ continue
+
+ # learn policy
+ for i, agent in enumerate(agents):
+ critic_loss = agent.learn(agents)
+ tensorboard.add_scalar('critic_loss_%d' % i, critic_loss,
+ agent.global_train_step)
+
+ return total_reward, agents_reward, steps
+
+
+def train_agent():
+ env = MAenv(args.env)
+ logger.info('agent num: {}'.format(env.n))
+ logger.info('observation_space: {}'.format(env.observation_space))
+ logger.info('action_space: {}'.format(env.action_space))
+ logger.info('obs_shape_n: {}'.format(env.obs_shape_n))
+ logger.info('act_shape_n: {}'.format(env.act_shape_n))
+ for i in range(env.n):
+ logger.info('agent {} obs_low:{} obs_high:{}'.format(
+ i, env.observation_space[i].low, env.observation_space[i].high))
+ logger.info('agent {} act_n:{}'.format(i, env.act_shape_n[i]))
+ if ('low' in dir(env.action_space[i])):
+ logger.info('agent {} act_low:{} act_high:{} act_shape:{}'.format(
+ i, env.action_space[i].low, env.action_space[i].high,
+ env.action_space[i].shape))
+ logger.info('num_discrete_space:{}'.format(
+ env.action_space[i].num_discrete_space))
+
+ from gym import spaces
+ from multiagent.multi_discrete import MultiDiscrete
+ for space in env.action_space:
+ assert (isinstance(space, spaces.Discrete)
+ or isinstance(space, MultiDiscrete))
+
+ agents = []
+ for i in range(env.n):
+ model = MAModel(env.act_shape_n[i])
+ algorithm = parl.algorithms.MADDPG(
+ model,
+ agent_index=i,
+ act_space=env.action_space,
+ gamma=args.gamma,
+ tau=args.tau,
+ lr=args.lr)
+ agent = MAAgent(
+ algorithm,
+ agent_index=i,
+ obs_dim_n=env.obs_shape_n,
+ act_dim_n=env.act_shape_n,
+ batch_size=args.batch_size,
+ speedup=(not args.restore))
+ agents.append(agent)
+ total_steps = 0
+ total_episodes = 0
+
+ episode_rewards = [] # sum of rewards for all agents
+ agent_rewards = [[] for _ in range(env.n)] # individual agent reward
+ final_ep_rewards = [] # sum of rewards for training curve
+ final_ep_ag_rewards = [] # agent rewards for training curve
+
+ if args.restore:
+ # restore model
+ for i in range(len(agents)):
+ model_file = args.model_dir + '/agent_' + str(i) + '.ckpt'
+ if not os.path.exists(model_file):
+ logger.info('model file {} does not exist'.format(model_file))
+ raise Exception
+ agents[i].restore(model_file)
+
+ t_start = time.time()
+ logger.info('Starting...')
+ while total_episodes <= args.max_episodes:
+ # run an episode
+ ep_reward, ep_agent_rewards, steps = run_episode(env, agents)
+ if args.show:
+ print('episode {}, reward {}, steps {}'.format(
+ total_episodes, ep_reward, steps))
+
+ # Record reward
+ total_steps += steps
+ total_episodes += 1
+ episode_rewards.append(ep_reward)
+ for i in range(env.n):
+ agent_rewards[i].append(ep_agent_rewards[i])
+
+ # Keep track of final episode reward
+ if total_episodes % args.stat_rate == 0:
+ mean_episode_reward = np.mean(episode_rewards[-args.stat_rate:])
+ final_ep_rewards.append(mean_episode_reward)
+ for rew in agent_rewards:
+ final_ep_ag_rewards.append(np.mean(rew[-args.stat_rate:]))
+ use_time = round(time.time() - t_start, 3)
+ logger.info(
+ 'Steps: {}, Episodes: {}, Mean episode reward: {}, Time: {}'.
+ format(total_steps, total_episodes, mean_episode_reward,
+ use_time))
+ t_start = time.time()
+ tensorboard.add_scalar('mean_episode_reward/episode',
+ mean_episode_reward, total_episodes)
+ tensorboard.add_scalar('mean_episode_reward/steps',
+ mean_episode_reward, total_steps)
+ tensorboard.add_scalar('use_time/1000episode', use_time,
+ total_episodes)
+
+ # save model
+ if not args.restore:
+ os.makedirs(os.path.dirname(args.model_dir), exist_ok=True)
+ for i in range(len(agents)):
+ model_name = '/agent_' + str(i) + '.ckpt'
+ agents[i].save(args.model_dir + model_name)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ # Environment
+ parser.add_argument(
+ '--env',
+ type=str,
+ default='simple_speaker_listener',
+ help='scenario of MultiAgentEnv')
+ parser.add_argument(
+ '--max_step_per_episode',
+ type=int,
+ default=25,
+ help='maximum step per episode')
+ parser.add_argument(
+ '--max_episodes',
+ type=int,
+ default=25000,
+ help='stop condition: number of episodes')
+ parser.add_argument(
+ '--stat_rate',
+ type=int,
+ default=1000,
+ help='interval (in episodes) for saving the model and reporting mean rewards')
+ # Core training parameters
+ parser.add_argument(
+ '--lr',
+ type=float,
+ default=1e-3,
+ help='learning rate for Adam optimizer')
+ parser.add_argument(
+ '--gamma', type=float, default=0.95, help='discount factor')
+ parser.add_argument(
+ '--batch_size',
+ type=int,
+ default=1024,
+ help='batch size (number of transitions sampled per update)')
+ parser.add_argument(
+ '--tau', type=float, default=0.01, help='soft update coefficient for the target networks')
+ # auto save model, optional restore model
+ parser.add_argument(
+ '--show', action='store_true', default=False, help='display or not')
+ parser.add_argument(
+ '--restore',
+ action='store_true',
+ default=False,
+ help='restore the model from model_dir before running')
+ parser.add_argument(
+ '--model_dir',
+ type=str,
+ default='./model',
+ help='directory for saving model')
+
+ args = parser.parse_args()
+
+ train_agent()
diff --git a/parl/algorithms/fluid/__init__.py b/parl/algorithms/fluid/__init__.py
index 3f005ac623ed5cf7fc95b2f9242c962dbfbf3302..dac58ab6bf7815e95ecc21269b67116d567fa576 100644
--- a/parl/algorithms/fluid/__init__.py
+++ b/parl/algorithms/fluid/__init__.py
@@ -14,6 +14,7 @@
from parl.algorithms.fluid.a3c import *
from parl.algorithms.fluid.ddpg import *
+from parl.algorithms.fluid.maddpg import *
from parl.algorithms.fluid.dqn import *
from parl.algorithms.fluid.ddqn import *
from parl.algorithms.fluid.policy_gradient import *
diff --git a/parl/algorithms/fluid/maddpg.py b/parl/algorithms/fluid/maddpg.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccb166090c4d727195a775f2ac1cf956cce13922
--- /dev/null
+++ b/parl/algorithms/fluid/maddpg.py
@@ -0,0 +1,159 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+warnings.simplefilter('default')
+
+from parl.core.fluid import layers
+from copy import deepcopy
+from paddle import fluid
+from parl.core.fluid.algorithm import Algorithm
+
+__all__ = ['MADDPG']
+
+from gym import spaces
+from parl.core.fluid.policy_distribution import SoftCategoricalDistribution
+from parl.core.fluid.policy_distribution import SoftMultiCategoricalDistribution
+
+
+def SoftPDistribution(logits, act_space):
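+ """Select a soft (differentiable) action distribution that matches the given gym action space."""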
+ if (isinstance(act_space, spaces.Discrete)):
+ return SoftCategoricalDistribution(logits)
+ # is instance of multiagent.multi_discrete.MultiDiscrete
+ elif (hasattr(act_space, 'num_discrete_space')):
+ return SoftMultiCategoricalDistribution(logits, act_space.low,
+ act_space.high)
+ else:
+ raise NotImplementedError
+
+
+class MADDPG(Algorithm):
+ def __init__(self,
+ model,
+ agent_index=None,
+ act_space=None,
+ gamma=None,
+ tau=None,
+ lr=None):
+ """ MADDPG algorithm
+
+ Args:
+ model (parl.Model): forward network of actor and critic.
+ The function get_actor_params() of model should be implemented.
+ agent_index (int): index of this agent in the multi-agent environment
+ act_space (list): action spaces (gym spaces) of all agents
+ gamma (float): discount factor for reward computation.
+ tau (float): decay coefficient when updating the weights of self.target_model with self.model
+ lr (float): learning rate
+ """
+
+ assert isinstance(agent_index, int)
+ assert isinstance(act_space, list)
+ assert isinstance(gamma, float)
+ assert isinstance(tau, float)
+ assert isinstance(lr, float)
+ self.agent_index = agent_index
+ self.act_space = act_space
+ self.gamma = gamma
+ self.tau = tau
+ self.lr = lr
+
+ self.model = model
+ self.target_model = deepcopy(model)
+
+ def predict(self, obs):
+ """ input:
+ obs: observation, shape([B] + shape of obs_n[agent_index])
+ output:
+ act: action, shape([B] + shape of act_n[agent_index])
+ """
+ this_policy = self.model.policy(obs)
+ this_action = SoftPDistribution(
+ logits=this_policy,
+ act_space=self.act_space[self.agent_index]).sample()
+ return this_action
+
+ def predict_next(self, obs):
+ """ input: observation, shape([B] + shape of obs_n[agent_index])
+ output: action, shape([B] + shape of act_n[agent_index])
+ """
+ next_policy = self.target_model.policy(obs)
+ next_action = SoftPDistribution(
+ logits=next_policy,
+ act_space=self.act_space[self.agent_index]).sample()
+ return next_action
+
+ def Q(self, obs_n, act_n):
+ """ input:
+ obs_n: all agents' observations, shape([B] + shape of obs_n)
+ act_n: all agents' actions, shape([B] + shape of act_n)
+ output:
+ Q: Q value of this agent's critic, shape([B])
+ """
+ return self.model.value(obs_n, act_n)
+
+ def Q_next(self, obs_n, act_n):
+ """ input:
+ obs_n: all agents' observations, shape([B] + shape of obs_n)
+ act_n: all agents' actions, shape([B] + shape of act_n)
+ output:
+ Q: target Q value from this agent's target critic, shape([B])
+ """
+ return self.target_model.value(obs_n, act_n)
+
+ def learn(self, obs_n, act_n, target_q):
+ """ update actor and critic model with MADDPG algorithm
+ """
+ actor_cost = self._actor_learn(obs_n, act_n)
+ critic_cost = self._critic_learn(obs_n, act_n, target_q)
+ return critic_cost
+
+ def _actor_learn(self, obs_n, act_n):
+ i = self.agent_index
+ this_policy = self.model.policy(obs_n[i])
+ sample_this_action = SoftPDistribution(
+ logits=this_policy,
+ act_space=self.act_space[self.agent_index]).sample()
+
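+ # replace this agent's action with a freshly sampled one so the policy gradient
+ # flows through the current actor; other agents' actions come from the batch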
+ action_input_n = list(act_n)
+ action_input_n[i] = sample_this_action
+ eval_q = self.Q(obs_n, action_input_n)
+ act_cost = layers.reduce_mean(-1.0 * eval_q)
+
+ act_reg = layers.reduce_mean(layers.square(this_policy))
+
+ cost = act_cost + act_reg * 1e-3
+
+ fluid.clip.set_gradient_clip(
+ clip=fluid.clip.GradientClipByNorm(clip_norm=0.5),
+ param_list=self.model.get_actor_params())
+
+ optimizer = fluid.optimizer.AdamOptimizer(self.lr)
+ optimizer.minimize(cost, parameter_list=self.model.get_actor_params())
+ return cost
+
+ def _critic_learn(self, obs_n, act_n, target_q):
+ pred_q = self.Q(obs_n, act_n)
+ cost = layers.reduce_mean(layers.square_error_cost(pred_q, target_q))
+
+ fluid.clip.set_gradient_clip(
+ clip=fluid.clip.GradientClipByNorm(clip_norm=0.5),
+ param_list=self.model.get_critic_params())
+
+ optimizer = fluid.optimizer.AdamOptimizer(self.lr)
+ optimizer.minimize(cost, parameter_list=self.model.get_critic_params())
+ return cost
+
+ def sync_target(self, decay=None):
+ if decay is None:
+ decay = 1.0 - self.tau
+ self.model.sync_weights_to(self.target_model, decay=decay)
diff --git a/parl/core/fluid/policy_distribution.py b/parl/core/fluid/policy_distribution.py
index 4d876062a3e19ab0101a83b3d325de6200932e67..73bc93b78ba25aa2f0c197fba868e2269afe4fa1 100644
--- a/parl/core/fluid/policy_distribution.py
+++ b/parl/core/fluid/policy_distribution.py
@@ -14,7 +14,10 @@
from parl.core.fluid import layers
-__all__ = ['PolicyDistribution', 'CategoricalDistribution']
+__all__ = [
+ 'PolicyDistribution', 'CategoricalDistribution',
+ 'SoftCategoricalDistribution', 'SoftMultiCategoricalDistribution'
+]
class PolicyDistribution(object):
@@ -79,7 +82,6 @@ class CategoricalDistribution(PolicyDistribution):
Returns:
actions_log_prob: A float32 tensor with shape [BATCH_SIZE]
"""
-
assert len(actions.shape) == 1
logits = self.logits - layers.reduce_max(self.logits, dim=1)
@@ -122,3 +124,88 @@ class CategoricalDistribution(PolicyDistribution):
(logits - layers.log(z) - other_logits + layers.log(other_z)),
dim=1)
return kl
+
+
+class SoftCategoricalDistribution(CategoricalDistribution):
+ """Categorical distribution with noise for discrete action spaces"""
+
+ def __init__(self, logits):
+ """
+ Args:
+ logits: A float32 tensor with shape [BATCH_SIZE, NUM_ACTIONS] of unnormalized policy logits
+ """
+ self.logits = logits
+ super(SoftCategoricalDistribution, self).__init__(logits)
+
+ def sample(self):
+ """
+ Returns:
+ sample_action: A float32 tensor with shape [BATCH_SIZE, NUM_ACTIONS] of sampled (soft one-hot) actions,
+ perturbed with noise so that the sample stays close to the original action distribution.
+ """
+ eps = 1e-4
+ logits_shape = layers.cast(layers.shape(self.logits), dtype='int64')
+ uniform = layers.uniform_random(logits_shape, min=eps, max=1.0 - eps)
+ soft_uniform = layers.log(-1.0 * layers.log(uniform))
+ return layers.softmax(self.logits - soft_uniform, axis=-1)
+
+
+class SoftMultiCategoricalDistribution(PolicyDistribution):
+ """Categorical distribution with noise for MultiDiscrete action spaces."""
+
+ def __init__(self, logits, low, high):
+ """
+ Args:
+ logits: A float32 tensor with shape [BATCH_SIZE, TOTAL_NUM_ACTIONS] of unnormalized policy logits,
+ where TOTAL_NUM_ACTIONS is the sum of the action counts over all MultiDiscrete dimensions
+ low: lower bounds of the MultiDiscrete action space
+ high: upper bounds of the MultiDiscrete action space
+ """
+ self.logits = logits
+ self.low = low
+ self.high = high
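+ # split the flat logits into one soft categorical distribution per MultiDiscrete dimension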
+ self.categoricals = list(
+ map(
+ SoftCategoricalDistribution,
+ layers.split(
+ input=logits,
+ num_or_sections=list(high - low + 1),
+ dim=len(logits.shape) - 1)))
+
+ def sample(self):
+ """
+ Returns:
+ sample_action: A float32 tensor with shape [BATCH_SIZE, TOTAL_NUM_ACTIONS] of sampled (soft) actions,
+ perturbed with noise so that the sample stays close to the original action distribution.
+ """
+ cate_list = []
+ for i in range(len(self.categoricals)):
+ cate_list.append(self.low[i] + self.categoricals[i].sample())
+ return layers.concat(cate_list, axis=-1)
+
+ def layers_add_n(self, input_list):
+ """
+ Adds all input tensors element-wise, can replace tf.add_n
+ """
+ assert len(input_list) >= 1
+ res = input_list[0]
+ for i in range(1, len(input_list)):
+ res = layers.elementwise_add(res, input_list[i])
+ return res
+
+ def entropy(self):
+ """
+ Returns:
+ entropy: A float32 tensor with shape [BATCH_SIZE] of entropy of self policy distribution.
+ """
+ return self.layers_add_n([p.entropy() for p in self.categoricals])
+
+ def kl(self, other):
+ """
+ Args:
+ other: object of SoftMultiCategoricalDistribution
+
+ Returns:
+ kl: A float32 tensor with shape [BATCH_SIZE]
+ """
+ return self.layers_add_n(
+ [p.kl(q) for p, q in zip(self.categoricals, other.categoricals)])
diff --git a/parl/env/multiagent_simple_env.py b/parl/env/multiagent_simple_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..222c4f7a861d5e1516619783fd94f8a1866a64c2
--- /dev/null
+++ b/parl/env/multiagent_simple_env.py
@@ -0,0 +1,60 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from gym import spaces
+from multiagent.multi_discrete import MultiDiscrete
+from multiagent.environment import MultiAgentEnv
+import multiagent.scenarios as scenarios
+
+
+class MAenv(MultiAgentEnv):
+ """ multiagent environment warppers for maddpg
+ """
+
+ def __init__(self, scenario_name):
+ # load scenario from script
+ scenario = scenarios.load(scenario_name + ".py").Scenario()
+ # create world
+ world = scenario.make_world()
+ # initialize the multi-agent environment
+ super().__init__(world, scenario.reset_world, scenario.reward,
+ scenario.observation)
+ self.obs_shape_n = [
+ self.get_shape(self.observation_space[i]) for i in range(self.n)
+ ]
+ self.act_shape_n = [
+ self.get_shape(self.action_space[i]) for i in range(self.n)
+ ]
+
+ def get_shape(self, input_space):
+ """
+ Args:
+ input_space: environment space
+
+ Returns:
+ space shape
+ """
+ if (isinstance(input_space, spaces.Box)):
+ if (len(input_space.shape) == 1):
+ return input_space.shape[0]
+ else:
+ return input_space.shape
+ elif (isinstance(input_space, spaces.Discrete)):
+ return input_space.n
+ elif (isinstance(input_space, MultiDiscrete)):
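+ # total number of discrete choices summed over all sub-spaces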
+ return sum(input_space.high - input_space.low + 1)
+ else:
+ print('[Error] unsupported space type: {}, expected Box, Discrete or MultiDiscrete'.
+ format(input_space))
+ raise NotImplementedError
diff --git a/parl/utils/replay_memory.py b/parl/utils/replay_memory.py
index 051c6fb30f94b933bdafcfad16d284a1a1610432..a1a3b4a8ed2cb327b7a4ec34770ea09ccf917ad8 100755
--- a/parl/utils/replay_memory.py
+++ b/parl/utils/replay_memory.py
@@ -34,10 +34,8 @@ class ReplayMemory(object):
self._curr_pos = 0
def sample_batch(self, batch_size):
- # index mapping to avoid sampling saving example
batch_idx = np.random.randint(
self._curr_size - 300 - 1, size=batch_size)
- batch_idx = (self._curr_pos + 300 + batch_idx) % self._curr_size
obs = self.obs[batch_idx]
reward = self.reward[batch_idx]
@@ -46,6 +44,19 @@ class ReplayMemory(object):
terminal = self.terminal[batch_idx]
return obs, action, reward, next_obs, terminal
+ def make_index(self, batch_size):
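+ """Randomly sample a batch of indices that can be shared across multiple replay memories."""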
+ batch_idx = np.random.randint(
+ self._curr_size - 300 - 1, size=batch_size)
+ return batch_idx
+
+ def sample_batch_by_index(self, batch_idx):
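+ """Fetch (obs, action, reward, next_obs, terminal) at the given indices."""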
+ obs = self.obs[batch_idx]
+ reward = self.reward[batch_idx]
+ action = self.action[batch_idx]
+ next_obs = self.next_obs[batch_idx]
+ terminal = self.terminal[batch_idx]
+ return obs, action, reward, next_obs, terminal
+
def append(self, obs, act, reward, next_obs, terminal):
if self._curr_size < self.max_size:
self._curr_size += 1