diff --git a/.teamcity/requirements.txt b/.teamcity/requirements.txt index 0b2d8e3ebe2cfd15cb5b2bc63cda5c9add38dd1a..354e3632e02ce8e678df2024a6d16657281c1a0e 100644 --- a/.teamcity/requirements.txt +++ b/.teamcity/requirements.txt @@ -1,5 +1,5 @@ # requirements for unittest -paddlepaddle-gpu==1.5.1.post97 +paddlepaddle-gpu==1.6.1.post97 gym details parameterized diff --git a/README.cn.md b/README.cn.md index 8a6c9fb4fe423ed5f12bd58a264a66f776ca30f1..09f1df56a90bcc36dd0971038dfd15de501034ec 100644 --- a/README.cn.md +++ b/README.cn.md @@ -76,8 +76,10 @@ pip install parl - [PPO](examples/PPO/) - [IMPALA](examples/IMPALA/) - [A2C](examples/A2C/) -- [GA3C](examples/GA3C/) +- [TD3](examples/TD3/) +- [SAC](examples/SAC/) - [冠军解决方案:NIPS2018强化学习假肢挑战赛](examples/NeurIPS2018-AI-for-Prosthetics-Challenge/) +- [冠军解决方案:NIPS2019强化学习仿生人控制赛事](examples/NeurIPS2019-Learn-to-Move-Challenge/) NeurlIPS2018 Half-Cheetah Breakout
diff --git a/README.md b/README.md index 29c87fa1fb228446e6145aea40da32b1c34efd6b..5245c349951b28a1a6c74e0a15cfc89d22edfaf3 100644 --- a/README.md +++ b/README.md @@ -79,8 +79,10 @@ pip install parl - [PPO](examples/PPO/) - [IMPALA](examples/IMPALA/) - [A2C](examples/A2C/) -- [GA3C](examples/GA3C/) +- [TD3](examples/TD3/) +- [SAC](examples/SAC/) - [Winning Solution for NIPS2018: AI for Prosthetics Challenge](examples/NeurIPS2018-AI-for-Prosthetics-Challenge/) +- [Winning Solution for NIPS2019: Learn to Move Challenge](examples/NeurIPS2019-Learn-to-Move-Challenge/) NeurlIPS2018 Half-Cheetah Breakout
diff --git a/examples/SAC/.benchmark/merge.png b/examples/SAC/.benchmark/merge.png new file mode 100644 index 0000000000000000000000000000000000000000..95a74a5fd047b11fecf4d176ef3a0d688eb73850 Binary files /dev/null and b/examples/SAC/.benchmark/merge.png differ diff --git a/examples/SAC/README.md b/examples/SAC/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c9a02209d663556900842317aa5f1ab987e14af3 --- /dev/null +++ b/examples/SAC/README.md @@ -0,0 +1,32 @@ +## Reproduce SAC with PARL +Based on PARL, the SAC algorithm of deep reinforcement learning has been reproduced, reaching the same level of indicators as the paper in Mujoco benchmarks. + +Include following approaches: ++ DDPG Style with Stochastic Policy ++ Maximum Entropy + +> SAC in +[Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor](https://arxiv.org/abs/1801.01290) + +### Mujoco games introduction +Please see [here](https://github.com/openai/mujoco-py) to know more about Mujoco games. + +### Benchmark result + +Performance + +## How to use +### Dependencies: ++ python3.5+ ++ [paddlepaddle>=1.5.1](https://github.com/PaddlePaddle/Paddle) ++ [parl](https://github.com/PaddlePaddle/PARL) ++ gym ++ mujoco-py>=1.50.1.0 + +### Start Training: +``` +# To train an agent for HalfCheetah-v2 game +python train.py + +# To train for different games +# python train.py --env [ENV_NAME] diff --git a/examples/SAC/mujoco_agent.py b/examples/SAC/mujoco_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..f4c1ccec4bc68e47b51bb41c63b120ec79d3ffd9 --- /dev/null +++ b/examples/SAC/mujoco_agent.py @@ -0,0 +1,87 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import parl +from parl import layers +from paddle import fluid + + +class MujocoAgent(parl.Agent): + def __init__(self, algorithm, obs_dim, act_dim): + assert isinstance(obs_dim, int) + assert isinstance(act_dim, int) + self.obs_dim = obs_dim + self.act_dim = act_dim + super(MujocoAgent, self).__init__(algorithm) + + # Attention: In the beginning, sync target model totally. + self.alg.sync_target(decay=0) + + def build_program(self): + self.pred_program = fluid.Program() + self.sample_program = fluid.Program() + self.learn_program = fluid.Program() + + with fluid.program_guard(self.pred_program): + obs = layers.data( + name='obs', shape=[self.obs_dim], dtype='float32') + self.pred_act = self.alg.predict(obs) + + with fluid.program_guard(self.sample_program): + obs = layers.data( + name='obs', shape=[self.obs_dim], dtype='float32') + self.sample_act, _ = self.alg.sample(obs) + + with fluid.program_guard(self.learn_program): + obs = layers.data( + name='obs', shape=[self.obs_dim], dtype='float32') + act = layers.data( + name='act', shape=[self.act_dim], dtype='float32') + reward = layers.data(name='reward', shape=[], dtype='float32') + next_obs = layers.data( + name='next_obs', shape=[self.obs_dim], dtype='float32') + terminal = layers.data(name='terminal', shape=[], dtype='bool') + self.critic_cost, self.actor_cost = self.alg.learn( + obs, act, reward, next_obs, terminal) + + def predict(self, obs): + obs = np.expand_dims(obs, axis=0) + act = self.fluid_executor.run( + self.pred_program, feed={'obs': obs}, + fetch_list=[self.pred_act])[0] + return act + + def sample(self, obs): + obs = np.expand_dims(obs, axis=0) + act = self.fluid_executor.run( + self.sample_program, + feed={'obs': obs}, + fetch_list=[self.sample_act])[0] + return act + + def learn(self, obs, act, reward, next_obs, terminal): + feed = { + 'obs': obs, + 'act': act, + 'reward': reward, + 'next_obs': next_obs, + 'terminal': terminal + } + [critic_cost, actor_cost] = self.fluid_executor.run( + self.learn_program, + feed=feed, + fetch_list=[self.critic_cost, self.actor_cost]) + self.alg.sync_target() + return critic_cost[0], actor_cost[0] diff --git a/examples/SAC/mujoco_model.py b/examples/SAC/mujoco_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e4d9ee39df3e3ce2e2718ab436638836a49fe7b2 --- /dev/null +++ b/examples/SAC/mujoco_model.py @@ -0,0 +1,69 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as fluid +import parl +from parl import layers + +LOG_SIG_MAX = 2.0 +LOG_SIG_MIN = -20.0 + + +class ActorModel(parl.Model): + def __init__(self, act_dim): + hid1_size = 400 + hid2_size = 300 + + self.fc1 = layers.fc(size=hid1_size, act='relu') + self.fc2 = layers.fc(size=hid2_size, act='relu') + self.mean_linear = layers.fc(size=act_dim) + self.log_std_linear = layers.fc(size=act_dim) + + def policy(self, obs): + hid1 = self.fc1(obs) + hid2 = self.fc2(hid1) + means = self.mean_linear(hid2) + log_std = self.log_std_linear(hid2) + log_std = layers.clip(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX) + + return means, log_std + + +class CriticModel(parl.Model): + def __init__(self): + hid1_size = 400 + hid2_size = 300 + + self.fc1 = layers.fc(size=hid1_size, act='relu') + self.fc2 = layers.fc(size=hid2_size, act='relu') + self.fc3 = layers.fc(size=1, act=None) + + self.fc4 = layers.fc(size=hid1_size, act='relu') + self.fc5 = layers.fc(size=hid2_size, act='relu') + self.fc6 = layers.fc(size=1, act=None) + + def value(self, obs, act): + hid1 = self.fc1(obs) + concat1 = layers.concat([hid1, act], axis=1) + Q1 = self.fc2(concat1) + Q1 = self.fc3(Q1) + Q1 = layers.squeeze(Q1, axes=[1]) + + hid2 = self.fc4(obs) + concat2 = layers.concat([hid2, act], axis=1) + Q2 = self.fc5(concat2) + Q2 = self.fc6(Q2) + Q2 = layers.squeeze(Q2, axes=[1]) + + return Q1, Q2 diff --git a/examples/SAC/train.py b/examples/SAC/train.py new file mode 100644 index 0000000000000000000000000000000000000000..a88260245880a39738f931573dd0b183487722df --- /dev/null +++ b/examples/SAC/train.py @@ -0,0 +1,150 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Refer to https://github.com/pranz24/pytorch-soft-actor-critic + +import argparse +import gym +import numpy as np +import time +import parl +from mujoco_agent import MujocoAgent +from mujoco_model import ActorModel, CriticModel +from parl.utils import logger, tensorboard, action_mapping, ReplayMemory + +ACTOR_LR = 1e-3 +CRITIC_LR = 1e-3 +GAMMA = 0.99 +TAU = 0.005 +MEMORY_SIZE = int(1e6) +WARMUP_SIZE = 1e4 +BATCH_SIZE = 256 +ENV_SEED = 1 + + +def run_train_episode(env, agent, rpm): + obs = env.reset() + total_reward = 0 + steps = 0 + while True: + steps += 1 + batch_obs = np.expand_dims(obs, axis=0) + + if rpm.size() < WARMUP_SIZE: + action = env.action_space.sample() + else: + action = agent.sample(batch_obs.astype('float32')) + action = np.squeeze(action) + + next_obs, reward, done, info = env.step(action) + + rpm.append(obs, action, reward, next_obs, done) + + if rpm.size() > WARMUP_SIZE: + batch_obs, batch_action, batch_reward, batch_next_obs, batch_terminal = rpm.sample_batch( + BATCH_SIZE) + agent.learn(batch_obs, batch_action, batch_reward, batch_next_obs, + batch_terminal) + + obs = next_obs + total_reward += reward + + if done: + break + return total_reward, steps + + +def run_evaluate_episode(env, agent): + obs = env.reset() + total_reward = 0 + while True: + batch_obs = np.expand_dims(obs, axis=0) + action = agent.predict(batch_obs.astype('float32')) + action = np.squeeze(action) + + next_obs, reward, done, info = env.step(action) + + obs = next_obs + total_reward += reward + + if done: + break + return total_reward + + +def main(): + env = gym.make(args.env) + env.seed(ENV_SEED) + + obs_dim = env.observation_space.shape[0] + act_dim = env.action_space.shape[0] + max_action = float(env.action_space.high[0]) + + actor = ActorModel(act_dim) + critic = CriticModel() + algorithm = parl.algorithms.SAC( + actor, + critic, + max_action=max_action, + gamma=GAMMA, + tau=TAU, + actor_lr=ACTOR_LR, + critic_lr=CRITIC_LR) + agent = MujocoAgent(algorithm, obs_dim, act_dim) + + rpm = ReplayMemory(MEMORY_SIZE, obs_dim, act_dim) + + test_flag = 0 + total_steps = 0 + while total_steps < args.train_total_steps: + train_reward, steps = run_train_episode(env, agent, rpm) + total_steps += steps + logger.info('Steps: {} Reward: {}'.format(total_steps, train_reward)) + tensorboard.add_scalar('train/episode_reward', train_reward, + total_steps) + + if total_steps // args.test_every_steps >= test_flag: + while total_steps // args.test_every_steps >= test_flag: + test_flag += 1 + evaluate_reward = run_evaluate_episode(env, agent) + logger.info('Steps {}, Evaluate reward: {}'.format( + total_steps, evaluate_reward)) + tensorboard.add_scalar('eval/episode_reward', evaluate_reward, + total_steps) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--env', help='Mujoco environment name', default='HalfCheetah-v2') + parser.add_argument( + '--train_total_steps', + type=int, + default=int(1e6), + help='maximum training steps') + parser.add_argument( + '--test_every_steps', + type=int, + default=int(1e4), + help='the step interval between two consecutive evaluations') + parser.add_argument( + '--alpha', + type=float, + default=0.2, + help='Temperature parameter α determines the relative importance of the \ + entropy term against the reward (default: 0.2)') + + args = parser.parse_args() + + main() diff --git a/parl/algorithms/fluid/__init__.py b/parl/algorithms/fluid/__init__.py index 468be6f249fe82a52760d7872517b493a6ed5f99..3f005ac623ed5cf7fc95b2f9242c962dbfbf3302 100644 --- a/parl/algorithms/fluid/__init__.py +++ b/parl/algorithms/fluid/__init__.py @@ -19,4 +19,5 @@ from parl.algorithms.fluid.ddqn import * from parl.algorithms.fluid.policy_gradient import * from parl.algorithms.fluid.ppo import * from parl.algorithms.fluid.td3 import * +from parl.algorithms.fluid.sac import * from parl.algorithms.fluid.impala.impala import * diff --git a/parl/algorithms/fluid/sac.py b/parl/algorithms/fluid/sac.py new file mode 100644 index 0000000000000000000000000000000000000000..cec92c98568905af7bce64252e9f3ff0531da039 --- /dev/null +++ b/parl/algorithms/fluid/sac.py @@ -0,0 +1,127 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from parl.core.fluid import layers +from copy import deepcopy +import numpy as np +from paddle import fluid +from paddle.fluid.layers import Normal +from parl.core.fluid.algorithm import Algorithm + +epsilon = 1e-6 + +__all__ = ['SAC'] + + +class SAC(Algorithm): + def __init__(self, + actor, + critic, + max_action, + alpha=0.2, + gamma=None, + tau=None, + actor_lr=None, + critic_lr=None): + """ SAC algorithm + + Args: + actor (parl.Model): forward network of actor. + critic (patl.Model): forward network of the critic. + max_action (float): the largest value that an action can be, env.action_space.high[0] + alpha (float): Temperature parameter determines the relative importance of the entropy against the reward + gamma (float): discounted factor for reward computation. + tau (float): decay coefficient when updating the weights of self.target_model with self.model + actor_lr (float): learning rate of the actor model + critic_lr (float): learning rate of the critic model + """ + assert isinstance(gamma, float) + assert isinstance(tau, float) + assert isinstance(actor_lr, float) + assert isinstance(critic_lr, float) + assert isinstance(alpha, float) + self.max_action = max_action + self.gamma = gamma + self.tau = tau + self.actor_lr = actor_lr + self.critic_lr = critic_lr + + self.alpha = alpha + + self.actor = actor + self.critic = critic + self.target_critic = deepcopy(critic) + + def predict(self, obs): + """ use actor model of self.policy to predict the action + """ + mean, _ = self.actor.policy(obs) + mean = layers.tanh(mean) * self.max_action + return mean + + def sample(self, obs): + mean, log_std = self.actor.policy(obs) + std = layers.exp(log_std) + normal = Normal(mean, std) + x_t = normal.sample([1])[0] + y_t = layers.tanh(x_t) + action = y_t * self.max_action + log_prob = normal.log_prob(x_t) + log_prob -= layers.log(self.max_action * (1 - layers.pow(y_t, 2)) + + epsilon) + log_prob = layers.reduce_sum(log_prob, dim=1, keep_dim=True) + log_prob = layers.squeeze(log_prob, axes=[1]) + return action, log_prob + + def learn(self, obs, action, reward, next_obs, terminal): + actor_cost = self.actor_learn(obs) + critic_cost = self.critic_learn(obs, action, reward, next_obs, + terminal) + return critic_cost, actor_cost + + def actor_learn(self, obs): + action, log_pi = self.sample(obs) + qf1_pi, qf2_pi = self.critic.value(obs, action) + min_qf_pi = layers.elementwise_min(qf1_pi, qf2_pi) + cost = log_pi * self.alpha - min_qf_pi + cost = layers.reduce_mean(cost) + optimizer = fluid.optimizer.AdamOptimizer(self.actor_lr) + optimizer.minimize(cost, parameter_list=self.actor.parameters()) + + return cost + + def critic_learn(self, obs, action, reward, next_obs, terminal): + next_state_action, next_state_log_pi = self.sample(next_obs) + qf1_next_target, qf2_next_target = self.target_critic.value( + next_obs, next_state_action) + min_qf_next_target = layers.elementwise_min( + qf1_next_target, qf2_next_target) - next_state_log_pi * self.alpha + + terminal = layers.cast(terminal, dtype='float32') + target_Q = reward + (1.0 - terminal) * self.gamma * min_qf_next_target + target_Q.stop_gradient = True + + current_Q1, current_Q2 = self.critic.value(obs, action) + cost = layers.square_error_cost(current_Q1, + target_Q) + layers.square_error_cost( + current_Q2, target_Q) + cost = layers.reduce_mean(cost) + optimizer = fluid.optimizer.AdamOptimizer(self.critic_lr) + optimizer.minimize(cost) + return cost + + def sync_target(self, decay=None): + if decay is None: + decay = 1.0 - self.tau + self.critic.sync_weights_to(self.target_critic, decay=decay) diff --git a/parl/remote/scripts.py b/parl/remote/scripts.py index fd76419684bce3c8fad63a8099ea86dca8c7f88b..71677d692878eef63f65b0ff1054cb6233b0d7a5 100644 --- a/parl/remote/scripts.py +++ b/parl/remote/scripts.py @@ -219,11 +219,14 @@ def start_worker(address, cpu_num): @click.command("stop", help="Exit the cluster.") def stop(): - command = ("pkill -f remote/start.py") + command = ( + "ps aux | grep remote/start.py | awk '{print $2}' | xargs kill -9") subprocess.call([command], shell=True) - command = ("pkill -f remote/job.py") + command = ( + "ps aux | grep remote/job.py | awk '{print $2}' | xargs kill -9") subprocess.call([command], shell=True) - command = ("pkill -f remote/monitor.py") + command = ( + "ps aux | grep remote/monitor.py | awk '{print $2}' | xargs kill -9") subprocess.call([command], shell=True) diff --git a/parl/remote/tests/reset_job_test_alone.py b/parl/remote/tests/reset_job_test_alone.py index 7ca5969658548dfc977a42ce0f5350f90b7a4ea5..81cc2fe77a102521c0dc0633d215821a2a5d991c 100644 --- a/parl/remote/tests/reset_job_test_alone.py +++ b/parl/remote/tests/reset_job_test_alone.py @@ -70,7 +70,8 @@ class TestJobAlone(unittest.TestCase): time.sleep(1) self.assertEqual(master.cpu_num, 4) print("We are going to kill all the jobs.") - command = ("pkill -f remote/job.py") + command = ( + "ps aux | grep remote/job.py | awk '{print $2}' | xargs kill -9") subprocess.call([command], shell=True) parl.connect('localhost:1334') actor = Actor()