diff --git a/examples/TD3/.benchmark/merge.png b/examples/TD3/.benchmark/merge.png new file mode 100644 index 0000000000000000000000000000000000000000..5b3f82c2467bb275670321aef0887a0fd27d0a75 Binary files /dev/null and b/examples/TD3/.benchmark/merge.png differ diff --git a/examples/TD3/README.md b/examples/TD3/README.md new file mode 100644 index 0000000000000000000000000000000000000000..941ce570d115c5a95ed11c105b27d634440ce5a7 --- /dev/null +++ b/examples/TD3/README.md @@ -0,0 +1,33 @@ +## Reproduce TD3 with PARL +Based on PARL, the TD3 algorithm of deep reinforcement learning has been reproduced, reaching the same level of indicators as the paper in Mujoco benchmarks. + +Include following approaches: ++ Clipped Double Q-learning ++ Target Networks and Delayed Policy Update ++ Target Policy Smoothing Regularization + +> TD3 in +[Addressing Function Approximation Error in Actor-Critic Methods](https://arxiv.org/abs/1802.09477) + +### Mujoco games introduction +Please see [here](https://github.com/openai/mujoco-py) to know more about Mujoco games. + +### Benchmark result + +Performance + +## How to use +### Dependencies: ++ python3.5+ ++ [paddlepaddle>=1.5.1](https://github.com/PaddlePaddle/Paddle) ++ [parl](https://github.com/PaddlePaddle/PARL) ++ gym ++ mujoco-py>=1.50.1.0 + +### Start Training: +``` +# To train an agent for HalfCheetah-v2 game +python train.py + +# To train for different game and different loss type +# python train.py --env [ENV_NAME] diff --git a/examples/TD3/mujoco_agent.py b/examples/TD3/mujoco_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..e35ff14814519ca275ed81dbebfc2b04cf7057ae --- /dev/null +++ b/examples/TD3/mujoco_agent.py @@ -0,0 +1,89 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import parl +from parl import layers +from paddle import fluid + + +class MujocoAgent(parl.Agent): + def __init__(self, algorithm, obs_dim, act_dim): + assert isinstance(obs_dim, int) + assert isinstance(act_dim, int) + self.obs_dim = obs_dim + self.act_dim = act_dim + super(MujocoAgent, self).__init__(algorithm) + + # Attention: In the beginning, sync target model totally. + self.alg.sync_target(decay=0) + self.learn_it = 0 + self.policy_freq = self.alg.policy_freq + + def build_program(self): + self.pred_program = fluid.Program() + self.actor_learn_program = fluid.Program() + self.critic_learn_program = fluid.Program() + + with fluid.program_guard(self.pred_program): + obs = layers.data( + name='obs', shape=[self.obs_dim], dtype='float32') + self.pred_act = self.alg.predict(obs) + + with fluid.program_guard(self.actor_learn_program): + obs = layers.data( + name='obs', shape=[self.obs_dim], dtype='float32') + self.actor_cost = self.alg.actor_learn(obs) + + with fluid.program_guard(self.critic_learn_program): + obs = layers.data( + name='obs', shape=[self.obs_dim], dtype='float32') + act = layers.data( + name='act', shape=[self.act_dim], dtype='float32') + reward = layers.data(name='reward', shape=[], dtype='float32') + next_obs = layers.data( + name='next_obs', shape=[self.obs_dim], dtype='float32') + terminal = layers.data(name='terminal', shape=[], dtype='bool') + self.critic_cost = self.alg.critic_learn(obs, act, reward, + next_obs, terminal) + + def predict(self, obs): + obs = np.expand_dims(obs, axis=0) + act = self.fluid_executor.run( + self.pred_program, feed={'obs': obs}, + fetch_list=[self.pred_act])[0] + return act + + def learn(self, obs, act, reward, next_obs, terminal): + self.learn_it += 1 + feed = { + 'obs': obs, + 'act': act, + 'reward': reward, + 'next_obs': next_obs, + 'terminal': terminal + } + critic_cost = self.fluid_executor.run( + self.critic_learn_program, + feed=feed, + fetch_list=[self.critic_cost])[0] + + actor_cost = None + if self.learn_it % self.policy_freq == 0: + actor_cost = self.fluid_executor.run( + self.actor_learn_program, + feed={'obs': obs}, + fetch_list=[self.actor_cost])[0] + self.alg.sync_target() + return actor_cost, critic_cost diff --git a/examples/TD3/mujoco_model.py b/examples/TD3/mujoco_model.py new file mode 100644 index 0000000000000000000000000000000000000000..d39a23de1994a2d031ada580d58459d2b77b0f67 --- /dev/null +++ b/examples/TD3/mujoco_model.py @@ -0,0 +1,91 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as fluid +import parl +from parl import layers + + +class MujocoModel(parl.Model): + def __init__(self, act_dim, max_action): + self.actor_model = ActorModel(act_dim, max_action) + self.critic_model = CriticModel() + + def policy(self, obs): + return self.actor_model.policy(obs) + + def value(self, obs, act): + return self.critic_model.value(obs, act) + + def Q1(self, obs, act): + return self.critic_model.Q1(obs, act) + + def get_actor_params(self): + return self.actor_model.parameters() + + +class ActorModel(parl.Model): + def __init__(self, act_dim, max_action): + hid1_size = 400 + hid2_size = 300 + + self.fc1 = layers.fc(size=hid1_size, act='relu') + self.fc2 = layers.fc(size=hid2_size, act='relu') + self.fc3 = layers.fc(size=act_dim, act='tanh') + + self.max_action = max_action + + def policy(self, obs): + hid1 = self.fc1(obs) + hid2 = self.fc2(hid1) + means = self.fc3(hid2) + means = means * self.max_action + return means + + +class CriticModel(parl.Model): + def __init__(self): + hid1_size = 400 + hid2_size = 300 + + self.fc1 = layers.fc(size=hid1_size, act='relu') + self.fc2 = layers.fc(size=hid2_size, act='relu') + self.fc3 = layers.fc(size=1, act=None) + + self.fc4 = layers.fc(size=hid1_size, act='relu') + self.fc5 = layers.fc(size=hid2_size, act='relu') + self.fc6 = layers.fc(size=1, act=None) + + def value(self, obs, act): + hid1 = self.fc1(obs) + concat1 = layers.concat([hid1, act], axis=1) + Q1 = self.fc2(concat1) + Q1 = self.fc3(Q1) + Q1 = layers.squeeze(Q1, axes=[1]) + + hid2 = self.fc4(obs) + concat2 = layers.concat([hid2, act], axis=1) + Q2 = self.fc5(concat2) + Q2 = self.fc6(Q2) + Q2 = layers.squeeze(Q2, axes=[1]) + return Q1, Q2 + + def Q1(self, obs, act): + hid1 = self.fc1(obs) + concat1 = layers.concat([hid1, act], axis=1) + Q1 = self.fc2(concat1) + Q1 = self.fc3(Q1) + Q1 = layers.squeeze(Q1, axes=[1]) + + return Q1 diff --git a/examples/TD3/train.py b/examples/TD3/train.py new file mode 100644 index 0000000000000000000000000000000000000000..4cb74d9c01ab73dcb8cb20385b36262cb7c4aeba --- /dev/null +++ b/examples/TD3/train.py @@ -0,0 +1,150 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import gym +import numpy as np +import time +import parl +from mujoco_agent import MujocoAgent +from mujoco_model import MujocoModel +from parl.utils import logger, tensorboard, action_mapping, ReplayMemory + +MAX_EPISODES = 5000 +ACTOR_LR = 3e-4 +CRITIC_LR = 3e-4 +GAMMA = 0.99 +TAU = 0.005 +MEMORY_SIZE = int(1e6) +WARMUP_SIZE = 1e4 +BATCH_SIZE = 256 +ENV_SEED = 1 +EXPL_NOISE = 0.1 # Std of Gaussian exploration noise + + +def run_train_episode(env, agent, rpm): + obs = env.reset() + total_reward = 0 + steps = 0 + max_action = float(env.action_space.high[0]) + while True: + steps += 1 + batch_obs = np.expand_dims(obs, axis=0) + + if rpm.size() < WARMUP_SIZE: + action = env.action_space.sample() + else: + action = agent.predict(batch_obs.astype('float32')) + action = np.squeeze(action) + + # Add exploration noise, and clip to [-max_action, max_action] + action = np.clip( + np.random.normal(action, EXPL_NOISE * max_action), -max_action, + max_action) + + next_obs, reward, done, info = env.step(action) + + rpm.append(obs, action, reward, next_obs, done) + + if rpm.size() > WARMUP_SIZE: + batch_obs, batch_action, batch_reward, batch_next_obs, batch_terminal = rpm.sample_batch( + BATCH_SIZE) + agent.learn(batch_obs, batch_action, batch_reward, batch_next_obs, + batch_terminal) + + obs = next_obs + total_reward += reward + + if done: + break + return total_reward, steps + + +def run_evaluate_episode(env, agent): + obs = env.reset() + total_reward = 0 + while True: + batch_obs = np.expand_dims(obs, axis=0) + action = agent.predict(batch_obs.astype('float32')) + action = np.squeeze(action) + action = action_mapping(action, env.action_space.low[0], + env.action_space.high[0]) + + next_obs, reward, done, info = env.step(action) + + obs = next_obs + total_reward += reward + + if done: + break + return total_reward + + +def main(): + env = gym.make(args.env) + env.seed(ENV_SEED) + + obs_dim = env.observation_space.shape[0] + act_dim = env.action_space.shape[0] + max_action = float(env.action_space.high[0]) + + model = MujocoModel(act_dim, max_action) + algorithm = parl.algorithms.TD3( + model, + max_action=max_action, + gamma=GAMMA, + tau=TAU, + actor_lr=ACTOR_LR, + critic_lr=CRITIC_LR) + agent = MujocoAgent(algorithm, obs_dim, act_dim) + + rpm = ReplayMemory(MEMORY_SIZE, obs_dim, act_dim) + + test_flag = 0 + total_steps = 0 + while total_steps < args.train_total_steps: + train_reward, steps = run_train_episode(env, agent, rpm) + total_steps += steps + logger.info('Steps: {} Reward: {}'.format(total_steps, train_reward)) + tensorboard.add_scalar('train/episode_reward', train_reward, + total_steps) + + if total_steps // args.test_every_steps >= test_flag: + while total_steps // args.test_every_steps >= test_flag: + test_flag += 1 + evaluate_reward = run_evaluate_episode(env, agent) + logger.info('Steps {}, Evaluate reward: {}'.format( + total_steps, evaluate_reward)) + tensorboard.add_scalar('eval/episode_reward', evaluate_reward, + total_steps) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--env', help='Mujoco environment name', default='HalfCheetah-v2') + parser.add_argument( + '--train_total_steps', + type=int, + default=int(1e7), + help='maximum training steps') + parser.add_argument( + '--test_every_steps', + type=int, + default=int(1e4), + help='the step interval between two consecutive evaluations') + + args = parser.parse_args() + + main() diff --git a/parl/algorithms/fluid/__init__.py b/parl/algorithms/fluid/__init__.py index bcf8c17c938133005a0da6a298076f2d157614bc..468be6f249fe82a52760d7872517b493a6ed5f99 100644 --- a/parl/algorithms/fluid/__init__.py +++ b/parl/algorithms/fluid/__init__.py @@ -18,4 +18,5 @@ from parl.algorithms.fluid.dqn import * from parl.algorithms.fluid.ddqn import * from parl.algorithms.fluid.policy_gradient import * from parl.algorithms.fluid.ppo import * +from parl.algorithms.fluid.td3 import * from parl.algorithms.fluid.impala.impala import * diff --git a/parl/algorithms/fluid/td3.py b/parl/algorithms/fluid/td3.py new file mode 100644 index 0000000000000000000000000000000000000000..d2efde4a36c9c8774909a152197740f1f447c15b --- /dev/null +++ b/parl/algorithms/fluid/td3.py @@ -0,0 +1,94 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from parl.core.fluid import layers +from copy import deepcopy +from paddle import fluid +from parl.core.fluid.algorithm import Algorithm + +__all__ = ['TD3'] + + +class TD3(Algorithm): + def __init__( + self, + model, + max_action, + gamma=None, + tau=None, + actor_lr=None, + critic_lr=None, + policy_noise=0.2, # Noise added to target policy during critic update + noise_clip=0.5, # Range to clip target policy noise + policy_freq=2): # Frequency of delayed policy updates + assert isinstance(gamma, float) + assert isinstance(tau, float) + assert isinstance(actor_lr, float) + assert isinstance(critic_lr, float) + self.max_action = max_action + self.gamma = gamma + self.tau = tau + self.actor_lr = actor_lr + self.critic_lr = critic_lr + self.policy_noise = policy_noise + self.noise_clip = noise_clip + self.policy_freq = policy_freq + + self.model = model + self.target_model = deepcopy(model) + + def predict(self, obs): + """ use actor model of self.model to predict the action + """ + return self.model.policy(obs) + + def actor_learn(self, obs): + action = self.model.policy(obs) + Q = self.model.Q1(obs, action) + cost = layers.reduce_mean(-1.0 * Q) + optimizer = fluid.optimizer.AdamOptimizer(self.actor_lr) + optimizer.minimize(cost, parameter_list=self.model.get_actor_params()) + return cost + + def critic_learn(self, obs, action, reward, next_obs, terminal): + noise = layers.gaussian_random_batch_size_like( + action, shape=[-1, action.shape[1]]) + noise = layers.clip( + noise * self.policy_noise, + min=-self.noise_clip, + max=self.noise_clip) + next_action = self.target_model.policy(next_obs) + noise + next_action = layers.clip(next_action, -self.max_action, + self.max_action) + + next_Q1, next_Q2 = self.target_model.value(next_obs, next_action) + next_Q = layers.elementwise_min(next_Q1, next_Q2) + + terminal = layers.cast(terminal, dtype='float32') + target_Q = reward + (1.0 - terminal) * self.gamma * next_Q + target_Q.stop_gradient = True + + current_Q1, current_Q2 = self.model.value(obs, action) + cost = layers.square_error_cost(current_Q1, + target_Q) + layers.square_error_cost( + current_Q2, target_Q) + cost = layers.reduce_mean(cost) + optimizer = fluid.optimizer.AdamOptimizer(self.critic_lr) + optimizer.minimize(cost) + return cost + + def sync_target(self, decay=None): + if decay is None: + decay = 1.0 - self.tau + self.model.sync_weights_to(self.target_model, decay=decay)