Unverified commit 9216d941 authored by rical730, committed by GitHub

add maddpg example (#200)

* add maddpg example

* format with yapf

* fix coding style

* fix coding style

* unittest without import multiagent env

* update maddpg code

* update maddpg readme

* add copyright comments
Parent f35200fe
@@ -78,6 +78,7 @@ pip install parl
- [A2C](examples/A2C/)
- [TD3](examples/TD3/)
- [SAC](examples/SAC/)
- [MADDPG](examples/MADDPG/)
- [Winning Solution for NIPS2018: AI for Prosthetics Challenge](examples/NeurIPS2018-AI-for-Prosthetics-Challenge/)
- [Winning Solution for NIPS2019: Learn to Move Challenge](examples/NeurIPS2019-Learn-to-Move-Challenge/)
......
@@ -81,6 +81,7 @@ pip install parl
- [A2C](examples/A2C/)
- [TD3](examples/TD3/)
- [SAC](examples/SAC/)
- [MADDPG](examples/MADDPG/)
- [Winning Solution for NIPS2018: AI for Prosthetics Challenge](examples/NeurIPS2018-AI-for-Prosthetics-Challenge/)
- [Winning Solution for NIPS2019: Learn to Move Challenge](examples/NeurIPS2019-Learn-to-Move-Challenge/)
......
## Reproduce MADDPG with PARL
Based on PARL, we reproduce the MADDPG algorithm of multi-agent deep reinforcement learning.
+ Paper:
[Multi-Agent Actor-Critic for Mixed Cooperative-Competitive Environments](https://arxiv.org/abs/1706.02275)
### Multi-agent particle environment introduction
A simple multi-agent particle world based on gym. See [here](https://github.com/openai/multiagent-particle-envs) for installation instructions and more details about the environment.
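As a quick orientation, the sketch below simply exercises the `MAenv` wrapper used by this example (see `train.py` further down); it assumes `multiagent-particle-envs` is already installed:
```
from parl.env.multiagent_simple_env import MAenv

# create a scenario by name (the default scenario used in train.py)
env = MAenv('simple_speaker_listener')

print(env.n)            # number of agents
print(env.obs_shape_n)  # per-agent observation sizes
print(env.act_shape_n)  # per-agent action sizes

obs_n = env.reset()     # a list with one observation per agent
```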
### Benchmark results
Mean episode reward (averaged over every 1000 episodes) during training (25000 episodes in total).
<table>
<tr>
<td>
simple<br>
<img src=".benchmark/MADDPG_simple.png" width = "170" height = "170" alt="MADDPG_simple"/>
</td>
<td>
simple_adversary<br>
<img src=".benchmark/MADDPG_simple_adversary.png" width = "170" height = "170" alt="MADDPG_simple_adversary"/>
</td>
<td>
simple_push<br>
<img src=".benchmark/MADDPG_simple_push.png" width = "170" height = "170" alt="MADDPG_simple_push"/>
</td>
<td>
simple_reference<br>
<img src=".benchmark/MADDPG_simple_reference.png" width = "170" height = "170" alt="MADDPG_simple_reference"/>
</td>
</tr>
<tr>
<td>
simple_speaker_listener<br>
<img src=".benchmark/MADDPG_simple_speaker_listener.png" width = "170" height = "170" alt="MADDPG_simple_speaker_listener"/>
</td>
<td>
simple_spread<br>
<img src=".benchmark/MADDPG_simple_spread.png" width = "170" height = "170" alt="MADDPG_simple_spread"/>
</td>
<td>
simple_tag<br>
<img src=".benchmark/MADDPG_simple_tag.png" width = "170" height = "170" alt="MADDPG_simple_tag"/>
</td>
<td>
simple_world_comm<br>
<img src=".benchmark/MADDPG_simple_world_comm.png" width = "170" height = "170" alt="MADDPG_simple_world_comm"/>
</td>
</tr>
</table>
### Experiment results
Agents' behavior after 25000 training episodes.
<table>
<tr>
<td>
simple<br>
<img src=".benchmark/MADDPG_simple.gif" width = "170" height = "170" alt="MADDPG_simple"/>
</td>
<td>
simple_adversary<br>
<img src=".benchmark/MADDPG_simple_adversary.gif" width = "170" height = "170" alt="MADDPG_simple_adversary"/>
</td>
<td>
simple_push<br>
<img src=".benchmark/MADDPG_simple_push.gif" width = "170" height = "170" alt="MADDPG_simple_push"/>
</td>
<td>
simple_reference<br>
<img src=".benchmark/MADDPG_simple_reference.gif" width = "170" height = "170" alt="MADDPG_simple_reference"/>
</td>
</tr>
<tr>
<td>
simple_speaker_listener<br>
<img src=".benchmark/MADDPG_simple_speaker_listener.gif" width = "170" height = "170" alt="MADDPG_simple_speaker_listener"/>
</td>
<td>
simple_spread<br>
<img src=".benchmark/MADDPG_simple_spread.gif" width = "170" height = "170" alt="MADDPG_simple_spread"/>
</td>
<td>
simple_tag<br>
<img src=".benchmark/MADDPG_simple_tag.gif" width = "170" height = "170" alt="MADDPG_simple_tag"/>
</td>
<td>
simple_world_comm<br>
<img src=".benchmark/MADDPG_simple_world_comm.gif" width = "170" height = "170" alt="MADDPG_simple_world_comm"/>
</td>
</tr>
</table>
## How to use
### Dependencies:
+ python3.5+
+ [paddlepaddle>=1.5.1](https://github.com/PaddlePaddle/Paddle)
+ [parl](https://github.com/PaddlePaddle/PARL)
+ [multiagent-particle-envs](https://github.com/openai/multiagent-particle-envs)
+ gym
### Start Training:
```
# To train an agent for the simple_speaker_listener scenario
python train.py

# To train on another scenario (the model is saved automatically every 1000 episodes)
# python train.py --env [ENV_NAME]

# To visualize a trained model
# python train.py --env [ENV_NAME] --show --restore
```
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import parl
from parl import layers
from paddle import fluid
from parl.utils import ReplayMemory
class MAAgent(parl.Agent):
def __init__(self,
algorithm,
agent_index=None,
obs_dim_n=None,
act_dim_n=None,
batch_size=None,
speedup=False):
assert isinstance(agent_index, int)
assert isinstance(obs_dim_n, list)
assert isinstance(act_dim_n, list)
assert isinstance(batch_size, int)
assert isinstance(speedup, bool)
self.agent_index = agent_index
self.obs_dim_n = obs_dim_n
self.act_dim_n = act_dim_n
self.batch_size = batch_size
self.speedup = speedup
self.n = len(act_dim_n)
self.memory_size = int(1e6)
self.min_memory_size = batch_size * 25 # batch_size * args.max_episode_len
self.rpm = ReplayMemory(
max_size=self.memory_size,
obs_dim=self.obs_dim_n[agent_index],
act_dim=self.act_dim_n[agent_index])
self.global_train_step = 0
super(MAAgent, self).__init__(algorithm)
# Note: fully synchronize the target model at the start (decay=0).
self.alg.sync_target(decay=0)
def build_program(self):
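# Build four fluid programs: action prediction (pred_program), critic update (learn_program),
# target-critic evaluation (next_q_program) and target-actor prediction (next_a_program).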
self.pred_program = fluid.Program()
self.learn_program = fluid.Program()
self.next_q_program = fluid.Program()
self.next_a_program = fluid.Program()
with fluid.program_guard(self.pred_program):
obs = layers.data(
name='obs',
shape=[self.obs_dim_n[self.agent_index]],
dtype='float32')
self.pred_act = self.alg.predict(obs)
with fluid.program_guard(self.learn_program):
obs_n = [
layers.data(
name='obs' + str(i),
shape=[self.obs_dim_n[i]],
dtype='float32') for i in range(self.n)
]
act_n = [
layers.data(
name='act' + str(i),
shape=[self.act_dim_n[i]],
dtype='float32') for i in range(self.n)
]
target_q = layers.data(name='target_q', shape=[], dtype='float32')
self.critic_cost = self.alg.learn(obs_n, act_n, target_q)
with fluid.program_guard(self.next_q_program):
obs_n = [
layers.data(
name='obs' + str(i),
shape=[self.obs_dim_n[i]],
dtype='float32') for i in range(self.n)
]
act_n = [
layers.data(
name='act' + str(i),
shape=[self.act_dim_n[i]],
dtype='float32') for i in range(self.n)
]
self.next_Q = self.alg.Q_next(obs_n, act_n)
with fluid.program_guard(self.next_a_program):
obs = layers.data(
name='obs',
shape=[self.obs_dim_n[self.agent_index]],
dtype='float32')
self.next_action = self.alg.predict_next(obs)
if self.speedup:
self.pred_program = parl.compile(self.pred_program)
self.learn_program = parl.compile(self.learn_program,
self.critic_cost)
self.next_q_program = parl.compile(self.next_q_program)
self.next_a_program = parl.compile(self.next_a_program)
def predict(self, obs):
obs = np.expand_dims(obs, axis=0)
obs = obs.astype('float32')
act = self.fluid_executor.run(
self.pred_program, feed={'obs': obs},
fetch_list=[self.pred_act])[0]
return act[0]
def learn(self, agents):
self.global_train_step += 1
# only update parameters every 100 steps
if self.global_train_step % 100 != 0:
return 0.0
if self.rpm.size() <= self.min_memory_size:
return 0.0
batch_obs_n = []
batch_act_n = []
batch_obs_new_n = []
rpm_sample_index = self.rpm.make_index(self.batch_size)
for i in range(self.n):
batch_obs, batch_act, _, batch_obs_new, _ \
= agents[i].rpm.sample_batch_by_index(rpm_sample_index)
batch_obs_n.append(batch_obs)
batch_act_n.append(batch_act)
batch_obs_new_n.append(batch_obs_new)
_, _, batch_rew, _, batch_isOver \
= self.rpm.sample_batch_by_index(rpm_sample_index)
# compute target q
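# target_q = r + gamma * (1 - terminal) * Q_target(o'_1..o'_n, a'_1..a'_n),
# where each a'_i is produced by agent i's target policy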
target_q = 0.0
target_act_next_n = []
for i in range(self.n):
feed = {'obs': batch_obs_new_n[i]}
target_act_next = agents[i].fluid_executor.run(
agents[i].next_a_program,
feed=feed,
fetch_list=[agents[i].next_action])[0]
target_act_next_n.append(target_act_next)
feed_obs = {'obs' + str(i): batch_obs_new_n[i] for i in range(self.n)}
feed_act = {
'act' + str(i): target_act_next_n[i]
for i in range(self.n)
}
feed = feed_obs.copy()
feed.update(feed_act) # merge two dict
target_q_next = self.fluid_executor.run(
self.next_q_program, feed=feed, fetch_list=[self.next_Q])[0]
target_q += (
batch_rew + self.alg.gamma * (1.0 - batch_isOver) * target_q_next)
feed_obs = {'obs' + str(i): batch_obs_n[i] for i in range(self.n)}
feed_act = {'act' + str(i): batch_act_n[i] for i in range(self.n)}
target_q = target_q.astype('float32')
feed = feed_obs.copy()
feed.update(feed_act)
feed['target_q'] = target_q
critic_cost = self.fluid_executor.run(
self.learn_program, feed=feed, fetch_list=[self.critic_cost])[0]
self.alg.sync_target()
return critic_cost
def add_experience(self, obs, act, reward, next_obs, terminal):
self.rpm.append(obs, act, reward, next_obs, terminal)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import parl
from parl import layers
class MAModel(parl.Model):
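# One agent's model: an actor network for its own policy plus a centralized critic
# that takes all agents' observations and actions as input.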
def __init__(self, act_dim):
self.actor_model = ActorModel(act_dim)
self.critic_model = CriticModel()
def policy(self, obs):
return self.actor_model.policy(obs)
def value(self, obs, act):
return self.critic_model.value(obs, act)
def get_actor_params(self):
return self.actor_model.parameters()
def get_critic_params(self):
return self.critic_model.parameters()
class ActorModel(parl.Model):
def __init__(self, act_dim):
hid1_size = 64
hid2_size = 64
self.fc1 = layers.fc(
size=hid1_size,
act='relu',
param_attr=fluid.initializer.Normal(loc=0.0, scale=0.1))
self.fc2 = layers.fc(
size=hid2_size,
act='relu',
param_attr=fluid.initializer.Normal(loc=0.0, scale=0.1))
self.fc3 = layers.fc(
size=act_dim,
act=None,
param_attr=fluid.initializer.Normal(loc=0.0, scale=0.1))
def policy(self, obs):
hid1 = self.fc1(obs)
hid2 = self.fc2(hid1)
means = self.fc3(hid2)
return means
class CriticModel(parl.Model):
def __init__(self):
hid1_size = 64
hid2_size = 64
self.fc1 = layers.fc(
size=hid1_size,
act='relu',
param_attr=fluid.initializer.Normal(loc=0.0, scale=0.1))
self.fc2 = layers.fc(
size=hid2_size,
act='relu',
param_attr=fluid.initializer.Normal(loc=0.0, scale=0.1))
self.fc3 = layers.fc(
size=1,
act=None,
param_attr=fluid.initializer.Normal(loc=0.0, scale=0.1))
def value(self, obs_n, act_n):
inputs = layers.concat(obs_n + act_n, axis=1)
hid1 = self.fc1(inputs)
hid2 = self.fc2(hid1)
Q = self.fc3(hid2)
Q = layers.squeeze(Q, axes=[1])
return Q
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import argparse
import numpy as np
from simple_model import MAModel
from simple_agent import MAAgent
import parl
from parl.env.multiagent_simple_env import MAenv
from parl.utils import logger, tensorboard
def run_episode(env, agents):
obs_n = env.reset()
total_reward = 0
agents_reward = [0 for _ in range(env.n)]
steps = 0
while True:
steps += 1
action_n = [agent.predict(obs) for agent, obs in zip(agents, obs_n)]
next_obs_n, reward_n, done_n, _ = env.step(action_n)
done = all(done_n)
terminal = (steps >= args.max_step_per_episode)
# store experience
for i, agent in enumerate(agents):
agent.add_experience(obs_n[i], action_n[i], reward_n[i],
next_obs_n[i], done_n[i])
# compute reward of every agent
obs_n = next_obs_n
for i, reward in enumerate(reward_n):
total_reward += reward
agents_reward[i] += reward
# check the end of an episode
if done or terminal:
break
# show animation
if args.show:
time.sleep(0.1)
env.render()
# show model effect without training
if args.restore and args.show:
continue
# learn policy
for i, agent in enumerate(agents):
critic_loss = agent.learn(agents)
tensorboard.add_scalar('critic_loss_%d' % i, critic_loss,
agent.global_train_step)
return total_reward, agents_reward, steps
def train_agent():
env = MAenv(args.env)
logger.info('agent num: {}'.format(env.n))
logger.info('observation_space: {}'.format(env.observation_space))
logger.info('action_space: {}'.format(env.action_space))
logger.info('obs_shape_n: {}'.format(env.obs_shape_n))
logger.info('act_shape_n: {}'.format(env.act_shape_n))
for i in range(env.n):
logger.info('agent {} obs_low:{} obs_high:{}'.format(
i, env.observation_space[i].low, env.observation_space[i].high))
logger.info('agent {} act_n:{}'.format(i, env.act_shape_n[i]))
if ('low' in dir(env.action_space[i])):
logger.info('agent {} act_low:{} act_high:{} act_shape:{}'.format(
i, env.action_space[i].low, env.action_space[i].high,
env.action_space[i].shape))
logger.info('num_discrete_space:{}'.format(
env.action_space[i].num_discrete_space))
from gym import spaces
from multiagent.multi_discrete import MultiDiscrete
for space in env.action_space:
assert (isinstance(space, spaces.Discrete)
or isinstance(space, MultiDiscrete))
agents = []
for i in range(env.n):
model = MAModel(env.act_shape_n[i])
algorithm = parl.algorithms.MADDPG(
model,
agent_index=i,
act_space=env.action_space,
gamma=args.gamma,
tau=args.tau,
lr=args.lr)
agent = MAAgent(
algorithm,
agent_index=i,
obs_dim_n=env.obs_shape_n,
act_dim_n=env.act_shape_n,
batch_size=args.batch_size,
speedup=(not args.restore))
agents.append(agent)
total_steps = 0
total_episodes = 0
episode_rewards = [] # sum of rewards for all agents
agent_rewards = [[] for _ in range(env.n)] # individual agent reward
final_ep_rewards = [] # sum of rewards for training curve
final_ep_ag_rewards = [] # agent rewards for training curve
if args.restore:
# restore model
for i in range(len(agents)):
model_file = args.model_dir + '/agent_' + str(i) + '.ckpt'
if not os.path.exists(model_file):
logger.info('model file {} does not exist'.format(model_file))
raise Exception
agents[i].restore(model_file)
t_start = time.time()
logger.info('Starting...')
while total_episodes <= args.max_episodes:
# run an episode
ep_reward, ep_agent_rewards, steps = run_episode(env, agents)
if args.show:
print('episode {}, reward {}, steps {}'.format(
total_episodes, ep_reward, steps))
# Record reward
total_steps += steps
total_episodes += 1
episode_rewards.append(ep_reward)
for i in range(env.n):
agent_rewards[i].append(ep_agent_rewards[i])
# Keep track of final episode reward
if total_episodes % args.stat_rate == 0:
mean_episode_reward = np.mean(episode_rewards[-args.stat_rate:])
final_ep_rewards.append(mean_episode_reward)
for rew in agent_rewards:
final_ep_ag_rewards.append(np.mean(rew[-args.stat_rate:]))
use_time = round(time.time() - t_start, 3)
logger.info(
'Steps: {}, Episodes: {}, Mean episode reward: {}, Time: {}'.
format(total_steps, total_episodes, mean_episode_reward,
use_time))
t_start = time.time()
tensorboard.add_scalar('mean_episode_reward/episode',
mean_episode_reward, total_episodes)
tensorboard.add_scalar('mean_episode_reward/steps',
mean_episode_reward, total_steps)
tensorboard.add_scalar('use_time/1000episode', use_time,
total_episodes)
# save model
if not args.restore:
os.makedirs(os.path.dirname(args.model_dir), exist_ok=True)
for i in range(len(agents)):
model_name = '/agent_' + str(i) + '.ckpt'
agents[i].save(args.model_dir + model_name)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# Environment
parser.add_argument(
'--env',
type=str,
default='simple_speaker_listener',
help='scenario of MultiAgentEnv')
parser.add_argument(
'--max_step_per_episode',
type=int,
default=25,
help='maximum steps per episode')
parser.add_argument(
'--max_episodes',
type=int,
default=25000,
help='stop condition: number of training episodes')
parser.add_argument(
'--stat_rate',
type=int,
default=1000,
help='interval (in episodes) for saving the model and averaging rewards')
# Core training parameters
parser.add_argument(
'--lr',
type=float,
default=1e-3,
help='learning rate for Adam optimizer')
parser.add_argument(
'--gamma', type=float, default=0.95, help='discount factor')
parser.add_argument(
'--batch_size',
type=int,
default=1024,
help='number of transitions sampled per training batch')
parser.add_argument('--tau', type=float, default=0.01, help='soft update coefficient')
# auto save model, optional restore model
parser.add_argument(
'--show', action='store_true', default=False, help='render the environment')
parser.add_argument(
'--restore',
action='store_true',
default=False,
help='restore a saved model from model_dir')
parser.add_argument(
'--model_dir',
type=str,
default='./model',
help='directory for saving model')
args = parser.parse_args()
train_agent()
@@ -14,6 +14,7 @@
from parl.algorithms.fluid.a3c import *
from parl.algorithms.fluid.ddpg import *
from parl.algorithms.fluid.maddpg import *
from parl.algorithms.fluid.dqn import *
from parl.algorithms.fluid.ddqn import *
from parl.algorithms.fluid.policy_gradient import *
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
warnings.simplefilter('default')
from parl.core.fluid import layers
from copy import deepcopy
from paddle import fluid
from parl.core.fluid.algorithm import Algorithm
__all__ = ['MADDPG']
from gym import spaces
from parl.core.fluid.policy_distribution import SoftCategoricalDistribution
from parl.core.fluid.policy_distribution import SoftMultiCategoricalDistribution
def SoftPDistribution(logits, act_space):
if (isinstance(act_space, spaces.Discrete)):
return SoftCategoricalDistribution(logits)
# is instance of multiagent.multi_discrete.MultiDiscrete
elif (hasattr(act_space, 'num_discrete_space')):
return SoftMultiCategoricalDistribution(logits, act_space.low,
act_space.high)
else:
raise NotImplementedError
class MADDPG(Algorithm):
def __init__(self,
model,
agent_index=None,
act_space=None,
gamma=None,
tau=None,
lr=None):
""" MADDPG algorithm
Args:
model (parl.Model): forward network of actor and critic.
The function get_actor_params() of model should be implemented.
agent_index (int): index of this agent in the multi-agent env
act_space (list): action spaces of all agents (gym spaces)
gamma (float): discount factor for reward computation.
tau (float): decay coefficient when updating the weights of self.target_model with self.model
lr (float): learning rate
"""
assert isinstance(agent_index, int)
assert isinstance(act_space, list)
assert isinstance(gamma, float)
assert isinstance(tau, float)
assert isinstance(lr, float)
self.agent_index = agent_index
self.act_space = act_space
self.gamma = gamma
self.tau = tau
self.lr = lr
self.model = model
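# keep a copy of the model as the target network; sync_target() moves it towards self.model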
self.target_model = deepcopy(model)
def predict(self, obs):
""" input:
obs: observation, shape([B] + shape of obs_n[agent_index])
output:
act: action, shape([B] + shape of act_n[agent_index])
"""
this_policy = self.model.policy(obs)
this_action = SoftPDistribution(
logits=this_policy,
act_space=self.act_space[self.agent_index]).sample()
return this_action
def predict_next(self, obs):
""" input: observation, shape([B] + shape of obs_n[agent_index])
output: action, shape([B] + shape of act_n[agent_index])
"""
next_policy = self.target_model.policy(obs)
next_action = SoftPDistribution(
logits=next_policy,
act_space=self.act_space[self.agent_index]).sample()
return next_action
def Q(self, obs_n, act_n):
""" input:
obs_n: all agents' observation, shape([B] + shape of obs_n)
output:
act_n: all agents' action, shape([B] + shape of act_n)
"""
return self.model.value(obs_n, act_n)
def Q_next(self, obs_n, act_n):
""" input:
obs_n: all agents' observation, shape([B] + shape of obs_n)
output:
act_n: all agents' action, shape([B] + shape of act_n)
"""
return self.target_model.value(obs_n, act_n)
def learn(self, obs_n, act_n, target_q):
""" update actor and critic model with MADDPG algorithm
"""
actor_cost = self._actor_learn(obs_n, act_n)
critic_cost = self._critic_learn(obs_n, act_n, target_q)
return critic_cost
def _actor_learn(self, obs_n, act_n):
i = self.agent_index
this_policy = self.model.policy(obs_n[i])
sample_this_action = SoftPDistribution(
logits=this_policy,
act_space=self.act_space[self.agent_index]).sample()
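# evaluate the centralized critic with this agent's freshly sampled action,
# keeping the other agents' actions as given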
action_input_n = act_n + []
action_input_n[i] = sample_this_action
eval_q = self.Q(obs_n, action_input_n)
act_cost = layers.reduce_mean(-1.0 * eval_q)
act_reg = layers.reduce_mean(layers.square(this_policy))
cost = act_cost + act_reg * 1e-3
fluid.clip.set_gradient_clip(
clip=fluid.clip.GradientClipByNorm(clip_norm=0.5),
param_list=self.model.get_actor_params())
optimizer = fluid.optimizer.AdamOptimizer(self.lr)
optimizer.minimize(cost, parameter_list=self.model.get_actor_params())
return cost
def _critic_learn(self, obs_n, act_n, target_q):
pred_q = self.Q(obs_n, act_n)
cost = layers.reduce_mean(layers.square_error_cost(pred_q, target_q))
fluid.clip.set_gradient_clip(
clip=fluid.clip.GradientClipByNorm(clip_norm=0.5),
param_list=self.model.get_critic_params())
optimizer = fluid.optimizer.AdamOptimizer(self.lr)
optimizer.minimize(cost, parameter_list=self.model.get_critic_params())
return cost
def sync_target(self, decay=None):
if decay is None:
decay = 1.0 - self.tau
self.model.sync_weights_to(self.target_model, decay=decay)
@@ -14,7 +14,10 @@
from parl.core.fluid import layers
__all__ = ['PolicyDistribution', 'CategoricalDistribution']
__all__ = [
'PolicyDistribution', 'CategoricalDistribution',
'SoftCategoricalDistribution', 'SoftMultiCategoricalDistribution'
]
class PolicyDistribution(object):
@@ -79,7 +82,6 @@ class CategoricalDistribution(PolicyDistribution):
Returns:
actions_log_prob: A float32 tensor with shape [BATCH_SIZE]
"""
assert len(actions.shape) == 1
logits = self.logits - layers.reduce_max(self.logits, dim=1)
@@ -122,3 +124,88 @@ class CategoricalDistribution(PolicyDistribution):
(logits - layers.log(z) - other_logits + layers.log(other_z)),
dim=1)
return kl
class SoftCategoricalDistribution(CategoricalDistribution):
"""Categorical distribution with noise for discrete action spaces"""
def __init__(self, logits):
"""
Args:
logits: A float32 tensor with shape [BATCH_SIZE, NUM_ACTIONS] of unnormalized policy logits
"""
self.logits = logits
super(SoftCategoricalDistribution, self).__init__(logits)
def sample(self):
"""
Returns:
sample_action: A float32 tensor with shape [BATCH_SIZE, NUM_ACTIONS] of soft sampled actions,
produced by adding noise to the logits before the softmax.
"""
eps = 1e-4
logits_shape = layers.cast(layers.shape(self.logits), dtype='int64')
uniform = layers.uniform_random(logits_shape, min=eps, max=1.0 - eps)
soft_uniform = layers.log(-1.0 * layers.log(uniform))
return layers.softmax(self.logits - soft_uniform, axis=-1)
class SoftMultiCategoricalDistribution(PolicyDistribution):
"""Categorical distribution with noise for MultiDiscrete action spaces."""
def __init__(self, logits, low, high):
"""
Args:
logits: A float32 tensor with shape [BATCH_SIZE, sum of sub-action dimensions] of unnormalized policy logits
low: lower bounds of sample action
high: Upper bounds of action
"""
self.logits = logits
self.low = low
self.high = high
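# split the flat logits into one segment per discrete sub-action and model each
# segment with a SoftCategoricalDistribution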
self.categoricals = list(
map(
SoftCategoricalDistribution,
layers.split(
input=logits,
num_or_sections=list(high - low + 1),
dim=len(logits.shape) - 1)))
def sample(self):
"""
Returns:
sample_action: A float32 tensor with shape [BATCH_SIZE, NUM_ACTIONS] of soft sampled actions,
produced by adding noise to the logits before the softmax.
"""
cate_list = []
for i in range(len(self.categoricals)):
cate_list.append(self.low[i] + self.categoricals[i].sample())
return layers.concat(cate_list, axis=-1)
def layers_add_n(self, input_list):
"""
Adds all input tensors element-wise; a replacement for tf.add_n.
"""
assert len(input_list) >= 1
res = input_list[0]
for i in range(1, len(input_list)):
res = layers.elementwise_add(res, input_list[i])
return res
def entropy(self):
"""
Returns:
entropy: A float32 tensor with shape [BATCH_SIZE] of entropy of self policy distribution.
"""
return self.layers_add_n([p.entropy() for p in self.categoricals])
def kl(self, other):
"""
Args:
other: an object of SoftMultiCategoricalDistribution
Returns:
kl: A float32 tensor with shape [BATCH_SIZE]
"""
return self.layers_add_n(
[p.kl(q) for p, q in zip(self.categoricals, other.categoricals)])
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from gym import spaces
from multiagent.multi_discrete import MultiDiscrete
from multiagent.environment import MultiAgentEnv
import multiagent.scenarios as scenarios
class MAenv(MultiAgentEnv):
""" multiagent environment warppers for maddpg
"""
def __init__(self, scenario_name):
# load scenario from script
scenario = scenarios.load(scenario_name + ".py").Scenario()
# create world
world = scenario.make_world()
# initialize the multi-agent environment
super().__init__(world, scenario.reset_world, scenario.reward,
scenario.observation)
self.obs_shape_n = [
self.get_shape(self.observation_space[i]) for i in range(self.n)
]
self.act_shape_n = [
self.get_shape(self.action_space[i]) for i in range(self.n)
]
def get_shape(self, input_space):
"""
Args:
input_space: environment space
Returns:
space shape
"""
if (isinstance(input_space, spaces.Box)):
if (len(input_space.shape) == 1):
return input_space.shape[0]
else:
return input_space.shape
elif (isinstance(input_space, spaces.Discrete)):
return input_space.n
elif (isinstance(input_space, MultiDiscrete)):
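# MultiDiscrete: total number of discrete choices summed over all sub-action spaces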
return sum(input_space.high - input_space.low + 1)
else:
print('[Error] space {} is not Box, Discrete or MultiDiscrete'.format(input_space))
raise NotImplementedError
@@ -34,10 +34,8 @@ class ReplayMemory(object):
self._curr_pos = 0
def sample_batch(self, batch_size):
# index mapping to avoid sampling examples near the current write position
batch_idx = np.random.randint(
self._curr_size - 300 - 1, size=batch_size)
batch_idx = (self._curr_pos + 300 + batch_idx) % self._curr_size
obs = self.obs[batch_idx]
reward = self.reward[batch_idx]
@@ -46,6 +44,19 @@ class ReplayMemory(object):
terminal = self.terminal[batch_idx]
return obs, action, reward, next_obs, terminal
def make_index(self, batch_size):
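# draw random indices once so that several replay memories (one per agent) can
# sample aligned batches via sample_batch_by_index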
batch_idx = np.random.randint(
self._curr_size - 300 - 1, size=batch_size)
return batch_idx
def sample_batch_by_index(self, batch_idx):
obs = self.obs[batch_idx]
reward = self.reward[batch_idx]
action = self.action[batch_idx]
next_obs = self.next_obs[batch_idx]
terminal = self.terminal[batch_idx]
return obs, action, reward, next_obs, terminal
def append(self, obs, act, reward, next_obs, terminal):
if self._curr_size < self.max_size:
self._curr_size += 1
......