diff --git a/README.md b/README.md
index d714a2bb506b2f013dadf6e98b0a9c0d806d0d1f..5285eddcd43b2dc2ddcd746e7b36724015ff592a 100644
--- a/README.md
+++ b/README.md
@@ -75,6 +75,6 @@ pip install --upgrade git+https://github.com/PaddlePaddle/PARL.git
 # Examples
 - [QuickStart](examples/QuickStart/)
 - [DQN](examples/DQN/)
-- DDPG
+- [DDPG](examples/DDPG/)
 - PPO
 - [Winning Solution for NIPS2018: AI for Prosthetics Challenge](examples/NeurIPS2018-AI-for-Prosthetics-Challenge/)
diff --git a/examples/DDPG/README.md b/examples/DDPG/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..61d32403c4f54712478b462f084a4e75ded606e7
--- /dev/null
+++ b/examples/DDPG/README.md
@@ -0,0 +1,27 @@
+## Reproduce DDPG with PARL
+Based on PARL, we reproduce the DDPG algorithm and reach the same level of performance as reported in the paper on classic MuJoCo benchmarks.
+
++ DDPG in
+[Continuous control with deep reinforcement learning](https://arxiv.org/abs/1509.02971)
+
+### MuJoCo environments introduction
+Please see [here](https://github.com/openai/mujoco-py) to learn more about MuJoCo environments.
+
+
+## How to use
+### Dependencies:
++ python2.7 or python3.5+
++ [PARL](https://github.com/PaddlePaddle/PARL)
++ [paddlepaddle>=1.0.0](https://github.com/PaddlePaddle/Paddle)
++ gym
++ tqdm
++ mujoco-py>=1.50.1.0
+
+### Start Training:
+```
+# To train an agent for the HalfCheetah-v2 environment
+python train.py
+
+# To train on a different environment
+# python train.py --env [ENV_NAME]
+```
diff --git a/examples/DDPG/mujoco_agent.py b/examples/DDPG/mujoco_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d9070848deff0e3a57aadff0082424dde86b941
--- /dev/null
+++ b/examples/DDPG/mujoco_agent.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import parl.layers as layers
+from paddle import fluid
+from parl.framework.agent_base import Agent
+
+
+class MujocoAgent(Agent):
+    def __init__(self, algorithm, obs_dim, act_dim):
+        self.obs_dim = obs_dim
+        self.act_dim = act_dim
+        super(MujocoAgent, self).__init__(algorithm)
+
+        # Attention: fully sync the target model at the very beginning.
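+        # decay=0 performs a hard copy of the learned parameters into the
+        # target model; later sync_target() calls in learn() fall back to
+        # the algorithm's soft-update rate (decay defaults to 1 - tau).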
+ self.alg.sync_target(gpu_id=self.gpu_id, decay=0) + + def build_program(self): + self.pred_program = fluid.Program() + self.learn_program = fluid.Program() + + with fluid.program_guard(self.pred_program): + obs = layers.data( + name='obs', shape=[self.obs_dim], dtype='float32') + self.pred_act = self.alg.define_predict(obs) + + with fluid.program_guard(self.learn_program): + obs = layers.data( + name='obs', shape=[self.obs_dim], dtype='float32') + act = layers.data( + name='act', shape=[self.act_dim], dtype='float32') + reward = layers.data(name='reward', shape=[], dtype='float32') + next_obs = layers.data( + name='next_obs', shape=[self.obs_dim], dtype='float32') + terminal = layers.data(name='terminal', shape=[], dtype='bool') + _, self.critic_cost = self.alg.define_learn( + obs, act, reward, next_obs, terminal) + + def predict(self, obs): + obs = np.expand_dims(obs, axis=0) + act = self.fluid_executor.run( + self.pred_program, feed={'obs': obs}, + fetch_list=[self.pred_act])[0] + return act + + def learn(self, obs, act, reward, next_obs, terminal): + feed = { + 'obs': obs, + 'act': act, + 'reward': reward, + 'next_obs': next_obs, + 'terminal': terminal + } + critic_cost = self.fluid_executor.run( + self.learn_program, feed=feed, fetch_list=[self.critic_cost])[0] + self.alg.sync_target(gpu_id=self.gpu_id) + return critic_cost diff --git a/examples/DDPG/mujoco_model.py b/examples/DDPG/mujoco_model.py new file mode 100644 index 0000000000000000000000000000000000000000..991842d6ebb1d7e6150f61cadf35b55c931e38f0 --- /dev/null +++ b/examples/DDPG/mujoco_model.py @@ -0,0 +1,68 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
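+
+# Actor and critic networks for continuous-control MuJoCo tasks. The hidden
+# layer sizes (400 and 300) follow the low-dimensional setup of the DDPG
+# paper; treat them as a reasonable default rather than a tuned choice.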
+ +import paddle.fluid as fluid +import parl.layers as layers +from parl.framework.model_base import Model + + +class MujocoModel(Model): + def __init__(self, act_dim, act_bound): + self.actor_model = ActorModel(act_dim, act_bound) + self.critic_model = CriticModel() + + def policy(self, obs): + return self.actor_model.policy(obs) + + def value(self, obs, act): + return self.critic_model.value(obs, act) + + def get_actor_params(self): + return self.actor_model.parameter_names + + +class ActorModel(Model): + def __init__(self, act_dim, act_bound): + self.act_bound = act_bound + hid1_size = 400 + hid2_size = 300 + + self.fc1 = layers.fc(size=hid1_size, act='relu') + self.fc2 = layers.fc(size=hid2_size, act='relu') + self.fc3 = layers.fc(size=act_dim, act='tanh') + + def policy(self, obs): + hid1 = self.fc1(obs) + hid2 = self.fc2(hid1) + means = self.fc3(hid2) + means = means * self.act_bound + return means + + +class CriticModel(Model): + def __init__(self): + hid1_size = 400 + hid2_size = 300 + + self.fc1 = layers.fc(size=hid1_size, act='relu') + self.fc2 = layers.fc(size=hid2_size, act='relu') + self.fc3 = layers.fc(size=1, act=None) + + def value(self, obs, act): + hid1 = self.fc1(obs) + concat = layers.concat([hid1, act], axis=1) + hid2 = self.fc2(concat) + Q = self.fc3(hid2) + Q = layers.squeeze(Q, axes=[1]) + return Q diff --git a/examples/DDPG/replay_memory.py b/examples/DDPG/replay_memory.py new file mode 100644 index 0000000000000000000000000000000000000000..1a50d9f7adbe40590d6db54dea2893c2ffeddff7 --- /dev/null +++ b/examples/DDPG/replay_memory.py @@ -0,0 +1,49 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
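+
+# A fixed-size experience replay buffer backed by pre-allocated numpy arrays.
+# append() writes at a cursor that wraps around once max_size is reached
+# (overwriting the oldest transitions), and sample_batch() draws indices
+# uniformly at random, with replacement, from the transitions stored so far.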
+ +import numpy as np + + +class ReplayMemory(object): + def __init__(self, max_size, obs_dim, act_dim): + self.max_size = max_size + self.obs_memory = np.zeros((max_size, obs_dim), dtype='float32') + self.act_memory = np.zeros((max_size, act_dim), dtype='float32') + self.reward_memory = np.zeros((max_size, ), dtype='float32') + self.next_obs_memory = np.zeros((max_size, obs_dim), dtype='float32') + self.terminal_memory = np.zeros((max_size, ), dtype='bool') + self._curr_size = 0 + self._curr_pos = 0 + + def sample_batch(self, batch_size): + batch_idx = np.random.choice(self._curr_size, size=batch_size) + obs = self.obs_memory[batch_idx, :] + act = self.act_memory[batch_idx, :] + reward = self.reward_memory[batch_idx] + next_obs = self.next_obs_memory[batch_idx, :] + terminal = self.terminal_memory[batch_idx] + return obs, act, reward, next_obs, terminal + + def append(self, obs, act, reward, next_obs, terminal): + if self._curr_size < self.max_size: + self._curr_size += 1 + self.obs_memory[self._curr_pos] = obs + self.act_memory[self._curr_pos] = act + self.reward_memory[self._curr_pos] = reward + self.next_obs_memory[self._curr_pos] = next_obs + self.terminal_memory[self._curr_pos] = terminal + self._curr_pos = (self._curr_pos + 1) % self.max_size + + def size(self): + return self._curr_size diff --git a/examples/DDPG/train.py b/examples/DDPG/train.py new file mode 100644 index 0000000000000000000000000000000000000000..4220e05acd13ceeae6eb47a2cc6fb3ffad49749e --- /dev/null +++ b/examples/DDPG/train.py @@ -0,0 +1,121 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
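+
+# Training entry point for DDPG on MuJoCo environments. During training,
+# Gaussian exploration noise is added to the predicted action and clipped to
+# the action bound, rewards are scaled by REWARD_SCALE before being stored,
+# learning starts once the replay memory holds MIN_LEARN_SIZE transitions,
+# and a noise-free evaluation episode runs every TEST_EVERY_EPISODES episodes.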
+ +import argparse +import gym +import numpy as np +import time +from mujoco_agent import MujocoAgent +from mujoco_model import MujocoModel +from parl.algorithms import DDPG +from parl.utils import logger +from replay_memory import ReplayMemory + +MAX_EPISODES = 5000 +TEST_EVERY_EPISODES = 50 +MAX_STEPS_EACH_EPISODE = 1000 +ACTOR_LR = 1e-4 +CRITIC_LR = 1e-3 +GAMMA = 0.99 +TAU = 0.001 +MEMORY_SIZE = int(1e6) +MIN_LEARN_SIZE = 1e4 +BATCH_SIZE = 128 +REWARD_SCALE = 0.1 +ENV_SEED = 1 + + +def run_train_episode(env, agent, rpm, act_bound): + obs = env.reset() + total_reward = 0 + for j in range(MAX_STEPS_EACH_EPISODE): + batch_obs = np.expand_dims(obs, axis=0) + action = agent.predict(batch_obs.astype('float32')) + action = np.squeeze(action) + + # Add exploration noise + action = np.clip( + np.random.normal(action, act_bound), -act_bound, act_bound) + + next_obs, reward, done, info = env.step(action) + + rpm.append(obs, action, REWARD_SCALE * reward, next_obs, done) + + if rpm.size() > MIN_LEARN_SIZE: + batch_obs, batch_action, batch_reward, batch_next_obs, batch_terminal = rpm.sample_batch( + BATCH_SIZE) + agent.learn(batch_obs, batch_action, batch_reward, batch_next_obs, + batch_terminal) + + obs = next_obs + total_reward += reward + + if done: + break + return total_reward + + +def run_evaluate_episode(env, agent): + obs = env.reset() + total_reward = 0 + for j in range(MAX_STEPS_EACH_EPISODE): + batch_obs = np.expand_dims(obs, axis=0) + action = agent.predict(batch_obs.astype('float32')) + action = np.squeeze(action) + + next_obs, reward, done, info = env.step(action) + + obs = next_obs + total_reward += reward + + if done: + break + return total_reward + + +def main(): + env = gym.make(args.env) + env.seed(ENV_SEED) + + obs_dim = env.observation_space.shape[0] + act_dim = env.action_space.shape[0] + act_bound = env.action_space.high[0] + + model = MujocoModel(act_dim, act_bound) + algorithm = DDPG( + model, + hyperparas={ + 'gamma': GAMMA, + 'tau': TAU, + 'actor_lr': ACTOR_LR, + 'critic_lr': CRITIC_LR + }) + agent = MujocoAgent(algorithm, obs_dim, act_dim) + + rpm = ReplayMemory(MEMORY_SIZE, obs_dim, act_dim) + + for i in range(MAX_EPISODES): + train_reward = run_train_episode(env, agent, rpm, act_bound) + logger.info('Episode: {} Reward: {}'.format(i, train_reward)) + if (i + 1) % TEST_EVERY_EPISODES == 0: + evaluate_reward = run_evaluate_episode(env, agent) + logger.info('Evaluate Reward: {}'.format(evaluate_reward)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--env', help='Mujoco environment name', default='HalfCheetah-v2') + args = parser.parse_args() + main() diff --git a/examples/DQN/atari_agent.py b/examples/DQN/atari_agent.py index 2219c62fb2f60777db2aeacd9b3319a3547170a2..ed76ae839534e113c44b8a58505eeaf3d439bf28 100644 --- a/examples/DQN/atari_agent.py +++ b/examples/DQN/atari_agent.py @@ -32,7 +32,7 @@ class AtariAgent(Agent): def build_program(self): self.pred_program = fluid.Program() - self.train_program = fluid.Program() + self.learn_program = fluid.Program() with fluid.program_guard(self.pred_program): obs = layers.data( @@ -41,7 +41,7 @@ class AtariAgent(Agent): dtype='float32') self.value = self.alg.define_predict(obs) - with fluid.program_guard(self.train_program): + with fluid.program_guard(self.learn_program): obs = layers.data( name='obs', shape=[CONTEXT_LEN, IMAGE_SIZE[0], IMAGE_SIZE[1]], @@ -99,5 +99,5 @@ class AtariAgent(Agent): 'terminal': terminal } cost = self.fluid_executor.run( - self.train_program, feed=feed, 
fetch_list=[self.cost])[0] + self.learn_program, feed=feed, fetch_list=[self.cost])[0] return cost diff --git a/examples/DQN/atari_model.py b/examples/DQN/atari_model.py index 5426ecef176c74dd3a342285c22c1dfbc5477f82..7999db8ab4691be3113db0f7c7e389e863d8d2f9 100644 --- a/examples/DQN/atari_model.py +++ b/examples/DQN/atari_model.py @@ -18,9 +18,7 @@ from parl.framework.model_base import Model class AtariModel(Model): - def __init__(self, img_height, img_width, act_dim): - self.img_height = img_height - self.img_width = img_width + def __init__(self, act_dim): self.act_dim = act_dim self.conv1 = layers.conv2d( diff --git a/examples/DQN/expreplay.py b/examples/DQN/replay_memory.py similarity index 98% rename from examples/DQN/expreplay.py rename to examples/DQN/replay_memory.py index 6e7e2ba25bed33ab538ca6d7d383f588af651975..ea8c6565155ddacae568e901566f9b390ee3a8b8 100644 --- a/examples/DQN/expreplay.py +++ b/examples/DQN/replay_memory.py @@ -88,6 +88,9 @@ class ReplayMemory(object): def __len__(self): return self._curr_size + def size(self): + return self._curr_size + def _assign(self, pos, exp): self.state[pos] = exp.state self.reward[pos] = exp.reward diff --git a/examples/DQN/train.py b/examples/DQN/train.py index a7cb2b32fce145e8f2cab66d860f50cde7969bea..fb91dc48d96b3dd498977dbe2b473bb014049302 100644 --- a/examples/DQN/train.py +++ b/examples/DQN/train.py @@ -13,49 +13,47 @@ # limitations under the License. import argparse -import cv2 import gym import paddle.fluid as fluid import numpy as np import os -from atari import AtariPlayer from atari_agent import AtariAgent from atari_model import AtariModel -from atari_wrapper import FrameStack, MapState, FireResetEnv, LimitLength from collections import deque from datetime import datetime -from expreplay import ReplayMemory, Experience +from replay_memory import ReplayMemory, Experience from parl.algorithms import DQN from parl.utils import logger from tqdm import tqdm +from utils import get_player MEMORY_SIZE = 1e6 MEMORY_WARMUP_SIZE = MEMORY_SIZE // 20 IMAGE_SIZE = (84, 84) CONTEXT_LEN = 4 -ACTION_REPEAT = 4 # aka FRAME_SKIP +FRAME_SKIP = 4 UPDATE_FREQ = 4 GAMMA = 0.99 LEARNING_RATE = 1e-3 * 0.5 -def run_train_episode(agent, env, exp): +def run_train_episode(env, agent, rpm): total_reward = 0 all_cost = [] state = env.reset() step = 0 while True: step += 1 - context = exp.recent_state() + context = rpm.recent_state() context.append(state) context = np.stack(context, axis=0) action = agent.sample(context) next_state, reward, isOver, _ = env.step(action) - exp.append(Experience(state, action, reward, isOver)) + rpm.append(Experience(state, action, reward, isOver)) # start training - if len(exp) > MEMORY_WARMUP_SIZE: + if rpm.size() > MEMORY_WARMUP_SIZE: if step % UPDATE_FREQ == 0: - batch_all_state, batch_action, batch_reward, batch_isOver = exp.sample_batch( + batch_all_state, batch_action, batch_reward, batch_isOver = rpm.sample_batch( args.batch_size) batch_state = batch_all_state[:, :CONTEXT_LEN, :, :] batch_next_state = batch_all_state[:, 1:, :, :] @@ -71,43 +69,27 @@ def run_train_episode(agent, env, exp): return total_reward, step -def get_player(rom, viz=False, train=False): - env = AtariPlayer( - rom, - frame_skip=ACTION_REPEAT, - viz=viz, - live_lost_as_eoe=train, - max_num_frames=60000) - env = FireResetEnv(env) - env = MapState(env, lambda im: cv2.resize(im, IMAGE_SIZE)) - if not train: - # in training, context is taken care of in expreplay buffer - env = FrameStack(env, CONTEXT_LEN) - return env - - -def eval_agent(agent, env): 
- episode_reward = [] - for _ in tqdm(range(30), desc='eval agent'): - state = env.reset() - total_reward = 0 - step = 0 - while True: - step += 1 - action = agent.predict(state) - state, reward, isOver, info = env.step(action) - total_reward += reward - if isOver: - break - episode_reward.append(total_reward) - eval_reward = np.mean(episode_reward) - return eval_reward - - -def train_agent(): - env = get_player(args.rom, train=True) - test_env = get_player(args.rom) - exp = ReplayMemory(MEMORY_SIZE, IMAGE_SIZE, CONTEXT_LEN) +def run_evaluate_episode(env, agent): + state = env.reset() + total_reward = 0 + while True: + action = agent.predict(state) + state, reward, isOver, info = env.step(action) + total_reward += reward + if isOver: + break + return total_reward + + +def main(): + env = get_player( + args.rom, image_size=IMAGE_SIZE, train=True, frame_skip=FRAME_SKIP) + test_env = get_player( + args.rom, + image_size=IMAGE_SIZE, + frame_skip=FRAME_SKIP, + context_len=CONTEXT_LEN) + rpm = ReplayMemory(MEMORY_SIZE, IMAGE_SIZE, CONTEXT_LEN) action_dim = env.action_space.n hyperparas = { @@ -115,13 +97,13 @@ def train_agent(): 'lr': LEARNING_RATE, 'gamma': GAMMA } - model = AtariModel(IMAGE_SIZE[0], IMAGE_SIZE[1], action_dim) + model = AtariModel(action_dim) algorithm = DQN(model, hyperparas) agent = AtariAgent(algorithm, action_dim) with tqdm(total=MEMORY_WARMUP_SIZE) as pbar: - while len(exp) < MEMORY_WARMUP_SIZE: - total_reward, step = run_train_episode(agent, env, exp) + while rpm.size() < MEMORY_WARMUP_SIZE: + total_reward, step = run_train_episode(env, agent, rpm) pbar.update(step) # train @@ -132,18 +114,23 @@ def train_agent(): max_reward = None while True: # start epoch - total_reward, step = run_train_episode(agent, env, exp) + total_reward, step = run_train_episode(env, agent, rpm) total_step += step pbar.set_description('[train]exploration:{}'.format(agent.exploration)) pbar.update(step) if total_step // args.test_every_steps == test_flag: pbar.write("testing") - eval_reward = eval_agent(agent, test_env) + eval_rewards = [] + for _ in tqdm(range(30), desc='eval agent'): + eval_reward = run_evaluate_episode(test_env, agent) + eval_rewards.append(eval_reward) test_flag += 1 logger.info( "eval_agent done, (steps, eval_reward): ({}, {})".format( - total_step, eval_reward)) + total_step, np.mean(eval_rewards))) + if total_step > 1e8: + break pbar.close() @@ -159,4 +146,4 @@ if __name__ == '__main__': default=100000, help='every steps number to run test') args = parser.parse_args() - train_agent() + main() diff --git a/examples/DQN/utils.py b/examples/DQN/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b938819c9db56a032abb555ef5e78b240c82a6c7 --- /dev/null +++ b/examples/DQN/utils.py @@ -0,0 +1,37 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
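+
+# Environment factory shared by training and evaluation. get_player() wraps
+# AtariPlayer with fire-reset handling and frame resizing; frame stacking is
+# applied only when train=False, because during training the context of
+# recent frames is assembled by the replay memory instead.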
+ +import cv2 +from atari import AtariPlayer +from atari_wrapper import FrameStack, MapState, FireResetEnv + + +def get_player(rom, + image_size, + viz=False, + train=False, + frame_skip=1, + context_len=1): + env = AtariPlayer( + rom, + frame_skip=frame_skip, + viz=viz, + live_lost_as_eoe=train, + max_num_frames=60000) + env = FireResetEnv(env) + env = MapState(env, lambda im: cv2.resize(im, image_size)) + if not train: + # in training, context is taken care of in expreplay buffer + env = FrameStack(env, context_len) + return env diff --git a/examples/QuickStart/README.md b/examples/QuickStart/README.md index 0bd8241a5e129c5cce75537c3f4cc335c7e2e3dd..c3558e42723a60abda0bfde6a27ce8fac8c3eb82 100644 --- a/examples/QuickStart/README.md +++ b/examples/QuickStart/README.md @@ -24,6 +24,7 @@ pip install . cd examples/QuickStart/ python train.py # Or visualize when evaluating: python train.py --eval_vis +``` ### Result After training, you will see the agent get the best score (200 points). diff --git a/examples/QuickStart/cartpole_agent.py b/examples/QuickStart/cartpole_agent.py index e5431e2b7af0b6dfb93831d34417fa8d9d4c6921..a1edad7dc8bf98ceba09e10fa92df808debb2d95 100644 --- a/examples/QuickStart/cartpole_agent.py +++ b/examples/QuickStart/cartpole_agent.py @@ -19,15 +19,19 @@ from parl.framework.agent_base import Agent class CartpoleAgent(Agent): - def __init__(self, algorithm, obs_dim, act_dim): + def __init__(self, algorithm, obs_dim, act_dim, seed=1): self.obs_dim = obs_dim self.act_dim = act_dim + self.seed = seed super(CartpoleAgent, self).__init__(algorithm) def build_program(self): self.pred_program = fluid.Program() self.train_program = fluid.Program() + fluid.default_startup_program().random_seed = self.seed + self.train_program.random_seed = self.seed + with fluid.program_guard(self.pred_program): obs = layers.data( name='obs', shape=[self.obs_dim], dtype='float32') diff --git a/examples/QuickStart/train.py b/examples/QuickStart/train.py index d1e65d6c0803a2951d6845769edea602cd914d3f..7553cd9cb106ef424d87fe3b87e66e9c32469f89 100644 --- a/examples/QuickStart/train.py +++ b/examples/QuickStart/train.py @@ -19,11 +19,13 @@ from cartpole_agent import CartpoleAgent from cartpole_model import CartpoleModel from parl.algorithms import PolicyGradient from parl.utils import logger +from utils import calc_discount_norm_reward OBS_DIM = 4 ACT_DIM = 2 GAMMA = 0.99 LEARNING_RATE = 1e-3 +SEED = 1 def run_train_episode(env, agent): @@ -56,32 +58,21 @@ def run_evaluate_episode(env, agent): return all_reward -def calc_discount_norm_reward(reward_list): - discount_norm_reward = np.zeros_like(reward_list) - - discount_cumulative_reward = 0 - for i in reversed(range(0, len(reward_list))): - discount_cumulative_reward = ( - GAMMA * discount_cumulative_reward + reward_list[i]) - discount_norm_reward[i] = discount_cumulative_reward - discount_norm_reward = discount_norm_reward - np.mean(discount_norm_reward) - discount_norm_reward = discount_norm_reward / np.std(discount_norm_reward) - return discount_norm_reward - - def main(): env = gym.make("CartPole-v0") + env.seed(SEED) + np.random.seed(SEED) model = CartpoleModel(act_dim=ACT_DIM) alg = PolicyGradient(model, hyperparas={'lr': LEARNING_RATE}) - agent = CartpoleAgent(alg, obs_dim=OBS_DIM, act_dim=ACT_DIM) + agent = CartpoleAgent(alg, obs_dim=OBS_DIM, act_dim=ACT_DIM, seed=SEED) - for i in range(500): + for i in range(1000): obs_list, action_list, reward_list = run_train_episode(env, agent) logger.info("Episode {}, Reward Sum {}.".format(i, 
sum(reward_list))) batch_obs = np.array(obs_list) batch_action = np.array(action_list) - batch_reward = calc_discount_norm_reward(reward_list) + batch_reward = calc_discount_norm_reward(reward_list, GAMMA) agent.learn(batch_obs, batch_action, batch_reward) if (i + 1) % 100 == 0: diff --git a/examples/QuickStart/utils.py b/examples/QuickStart/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f200b8f366cb8f050344d39971a31dc9496bb7be --- /dev/null +++ b/examples/QuickStart/utils.py @@ -0,0 +1,28 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + + +def calc_discount_norm_reward(reward_list, gamma): + discount_norm_reward = np.zeros_like(reward_list) + + discount_cumulative_reward = 0 + for i in reversed(range(0, len(reward_list))): + discount_cumulative_reward = ( + gamma * discount_cumulative_reward + reward_list[i]) + discount_norm_reward[i] = discount_cumulative_reward + discount_norm_reward = discount_norm_reward - np.mean(discount_norm_reward) + discount_norm_reward = discount_norm_reward / np.std(discount_norm_reward) + return discount_norm_reward diff --git a/parl/algorithms/__init__.py b/parl/algorithms/__init__.py index 8dec5ecdbd3c7b781b101a8cbe05f2f3dea596c6..182e401897ba73a7d2d104e09aa7f4954e6f5e04 100644 --- a/parl/algorithms/__init__.py +++ b/parl/algorithms/__init__.py @@ -14,3 +14,4 @@ from parl.algorithms.dqn import * from parl.algorithms.policy_gradient import * +from parl.algorithms.ddpg import * diff --git a/parl/algorithms/ddpg.py b/parl/algorithms/ddpg.py new file mode 100644 index 0000000000000000000000000000000000000000..5e66d085c9eefadcf2a4900cb80e4aed6f70cb8d --- /dev/null +++ b/parl/algorithms/ddpg.py @@ -0,0 +1,77 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
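+
+# DDPG keeps a slowly-updated target copy of the actor-critic model. The
+# critic is regressed towards
+#     target_Q = reward + (1 - terminal) * gamma * Q_target(next_obs, policy_target(next_obs))
+# the actor is updated to maximise Q(obs, policy(obs)), and sync_target()
+# moves the target parameters towards the learned ones (decay defaults to
+# 1 - tau; passing decay=0 copies them outright).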
+ +import parl.layers as layers +from copy import deepcopy +from paddle import fluid +from parl.framework.algorithm_base import Algorithm + +__all__ = ['DDPG'] + + +class DDPG(Algorithm): + def __init__(self, model, hyperparas): + """ model: should implement the function get_actor_params() + """ + Algorithm.__init__(self, model, hyperparas) + self.model = model + self.target_model = deepcopy(model) + + # fetch hyper parameters + self.gamma = hyperparas['gamma'] + self.tau = hyperparas['tau'] + self.actor_lr = hyperparas['actor_lr'] + self.critic_lr = hyperparas['critic_lr'] + + def define_predict(self, obs): + """ use actor model of self.model to predict the action + """ + return self.model.policy(obs) + + def define_learn(self, obs, action, reward, next_obs, terminal): + """ update actor and critic model with DDPG algorithm + """ + actor_cost = self._actor_learn(obs) + critic_cost = self._critic_learn(obs, action, reward, next_obs, + terminal) + return actor_cost, critic_cost + + def _actor_learn(self, obs): + action = self.model.policy(obs) + Q = self.model.value(obs, action) + cost = layers.reduce_mean(-1.0 * Q) + optimizer = fluid.optimizer.AdamOptimizer(self.actor_lr) + optimizer.minimize(cost, parameter_list=self.model.get_actor_params()) + return cost + + def _critic_learn(self, obs, action, reward, next_obs, terminal): + next_action = self.target_model.policy(next_obs) + next_Q = self.target_model.value(next_obs, next_action) + + terminal = layers.cast(terminal, dtype='float32') + target_Q = reward + (1.0 - terminal) * self.gamma * next_Q + target_Q.stop_gradient = True + + Q = self.model.value(obs, action) + cost = layers.square_error_cost(Q, target_Q) + cost = layers.reduce_mean(cost) + optimizer = fluid.optimizer.AdamOptimizer(self.critic_lr) + optimizer.minimize(cost) + return cost + + def sync_target(self, gpu_id, decay=None): + if decay is None: + decay = 1.0 - self.tau + self.model.sync_params_to( + self.target_model, gpu_id=gpu_id, decay=decay) diff --git a/parl/utils/gputils.py b/parl/utils/gputils.py index e0056b5e08209d9009db65de975ddf6bccc42378..92ea83e505513ac80125056e217c4c37942689cf 100644 --- a/parl/utils/gputils.py +++ b/parl/utils/gputils.py @@ -32,19 +32,21 @@ def get_gpu_count(): if env_cuda_devices is not None: assert isinstance(env_cuda_devices, str) try: + if not env_cuda_devices: + return 0 gpu_count = len( [x for x in env_cuda_devices.split(',') if int(x) >= 0]) logger.info( 'CUDA_VISIBLE_DEVICES found gpu count: {}'.format(gpu_count)) - except Exception as e: - logger.error(e.message) + except: + logger.warn('Cannot find available GPU devices, using CPU now.') gpu_count = 0 else: try: gpu_count = str(subprocess.check_output(["nvidia-smi", "-L"])).count('UUID') logger.info('nvidia-smi -L found gpu count: {}'.format(gpu_count)) - except Exception as e: - logger.error(e.message) + except: + logger.warn('Cannot find available GPU devices, using CPU now.') gpu_count = 0 return gpu_count