From cdd4622a77f9f0e338042b24872b6a5774927bc0 Mon Sep 17 00:00:00 2001 From: Hongsheng Zeng Date: Thu, 6 Dec 2018 21:45:47 -0500 Subject: [PATCH] Add QuickStart example (#35) * add QuickStart example, refine DQN example * add examples link * refine the naming, and add quick start training result --- README.md | 6 +- examples/DQN/README.md | 4 +- examples/DQN/atari_agent.py | 2 +- examples/DQN/atari_model.py | 1 - examples/DQN/train.py | 8 +- examples/QuickStart/README.md | 29 ++++++ examples/QuickStart/cartpole_agent.py | 72 +++++++++++++ examples/QuickStart/cartpole_model.py | 31 ++++++ examples/QuickStart/train.py | 100 +++++++++++++++++++ parl/algorithms/__init__.py | 4 +- parl/algorithms/{dqn_algorithm.py => dqn.py} | 18 ++-- parl/algorithms/policy_gradient.py | 42 ++++++++ 12 files changed, 297 insertions(+), 20 deletions(-) create mode 100644 examples/QuickStart/README.md create mode 100644 examples/QuickStart/cartpole_agent.py create mode 100644 examples/QuickStart/cartpole_model.py create mode 100644 examples/QuickStart/train.py rename parl/algorithms/{dqn_algorithm.py => dqn.py} (80%) create mode 100644 parl/algorithms/policy_gradient.py diff --git a/README.md b/README.md index 7566023..d714a2b 100644 --- a/README.md +++ b/README.md @@ -73,8 +73,8 @@ pip install --upgrade git+https://github.com/PaddlePaddle/PARL.git ``` # Examples - -- DQN +- [QuickStart](examples/QuickStart/) +- [DQN](examples/DQN/) - DDPG - PPO -- Winning Solution for NIPS2018: AI for Prosthetics Challenge +- [Winning Solution for NIPS2018: AI for Prosthetics Challenge](examples/NeurIPS2018-AI-for-Prosthetics-Challenge/) diff --git a/examples/DQN/README.md b/examples/DQN/README.md index 570ae96..2901737 100644 --- a/examples/DQN/README.md +++ b/examples/DQN/README.md @@ -21,7 +21,7 @@ Please see [here](https://gym.openai.com/envs/#atari) to know more about Atari g ### Start Training: ``` -# To train a model for Pong game with CUDA -python train.py --rom ./rom_files/pong.bin --use_cuda +# To train a model for Pong game +python train.py --rom ./rom_files/pong.bin ``` > To train more games, you can install more rom files from [here](https://github.com/openai/atari-py/tree/master/atari_py/atari_roms). diff --git a/examples/DQN/atari_agent.py b/examples/DQN/atari_agent.py index bcdf3c0..2219c62 100644 --- a/examples/DQN/atari_agent.py +++ b/examples/DQN/atari_agent.py @@ -16,7 +16,6 @@ import numpy as np import paddle.fluid as fluid import parl.layers as layers from parl.framework.agent_base import Agent -from parl.utils import logger IMAGE_SIZE = (84, 84) CONTEXT_LEN = 4 @@ -34,6 +33,7 @@ class AtariAgent(Agent): def build_program(self): self.pred_program = fluid.Program() self.train_program = fluid.Program() + with fluid.program_guard(self.pred_program): obs = layers.data( name='obs', diff --git a/examples/DQN/atari_model.py b/examples/DQN/atari_model.py index 0ea1de9..5426ece 100644 --- a/examples/DQN/atari_model.py +++ b/examples/DQN/atari_model.py @@ -15,7 +15,6 @@ import paddle.fluid as fluid import parl.layers as layers from parl.framework.model_base import Model -from parl.utils import logger class AtariModel(Model): diff --git a/examples/DQN/train.py b/examples/DQN/train.py index 3d2d8d4..a7cb2b3 100644 --- a/examples/DQN/train.py +++ b/examples/DQN/train.py @@ -25,7 +25,7 @@ from atari_wrapper import FrameStack, MapState, FireResetEnv, LimitLength from collections import deque from datetime import datetime from expreplay import ReplayMemory, Experience -from parl.algorithms import DQNAlgorithm +from parl.algorithms import DQN from parl.utils import logger from tqdm import tqdm @@ -36,7 +36,7 @@ CONTEXT_LEN = 4 ACTION_REPEAT = 4 # aka FRAME_SKIP UPDATE_FREQ = 4 GAMMA = 0.99 -LEARNING_RATE = 1e-3 +LEARNING_RATE = 1e-3 * 0.5 def run_train_episode(agent, env, exp): @@ -116,7 +116,7 @@ def train_agent(): 'gamma': GAMMA } model = AtariModel(IMAGE_SIZE[0], IMAGE_SIZE[1], action_dim) - algorithm = DQNAlgorithm(model, hyperparas) + algorithm = DQN(model, hyperparas) agent = AtariAgent(algorithm, action_dim) with tqdm(total=MEMORY_WARMUP_SIZE) as pbar: @@ -151,8 +151,6 @@ def train_agent(): if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--rom', help='atari rom', required=True) - parser.add_argument( - '--use_cuda', action='store_true', help='if set, use cuda') parser.add_argument( '--batch_size', type=int, default=64, help='batch size for training') parser.add_argument( diff --git a/examples/QuickStart/README.md b/examples/QuickStart/README.md new file mode 100644 index 0000000..0bd8241 --- /dev/null +++ b/examples/QuickStart/README.md @@ -0,0 +1,29 @@ +## Quick Start Example +Based on PARL, train a agent to play CartPole game with policy gradient algorithm in a few minutes. + +## How to use +### Dependencies: + ++ python2.7 or python3.5+ ++ [PARL](https://github.com/PaddlePaddle/PARL) ++ [paddlepaddle>=1.0.0](https://github.com/PaddlePaddle/Paddle) ++ gym + +### Start Training: +``` +# Install dependencies +pip install paddlepaddle +# Or use Cuda: pip install paddlepaddle-gpu + +pip install gym +git clone https://github.com/PaddlePaddle/PARL.git +cd PARL +pip install . + +# Train model +cd examples/QuickStart/ +python train.py +# Or visualize when evaluating: python train.py --eval_vis + +### Result +After training, you will see the agent get the best score (200 points). diff --git a/examples/QuickStart/cartpole_agent.py b/examples/QuickStart/cartpole_agent.py new file mode 100644 index 0000000..e5431e2 --- /dev/null +++ b/examples/QuickStart/cartpole_agent.py @@ -0,0 +1,72 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle.fluid as fluid +import parl.layers as layers +from parl.framework.agent_base import Agent + + +class CartpoleAgent(Agent): + def __init__(self, algorithm, obs_dim, act_dim): + self.obs_dim = obs_dim + self.act_dim = act_dim + super(CartpoleAgent, self).__init__(algorithm) + + def build_program(self): + self.pred_program = fluid.Program() + self.train_program = fluid.Program() + + with fluid.program_guard(self.pred_program): + obs = layers.data( + name='obs', shape=[self.obs_dim], dtype='float32') + self.act_prob = self.alg.define_predict(obs) + + with fluid.program_guard(self.train_program): + obs = layers.data( + name='obs', shape=[self.obs_dim], dtype='float32') + act = layers.data(name='act', shape=[1], dtype='int64') + reward = layers.data(name='reward', shape=[], dtype='float32') + self.cost = self.alg.define_learn(obs, act, reward) + + def sample(self, obs): + obs = np.expand_dims(obs, axis=0) + act_prob = self.fluid_executor.run( + self.pred_program, + feed={'obs': obs.astype('float32')}, + fetch_list=[self.act_prob])[0] + act_prob = np.squeeze(act_prob, axis=0) + act = np.random.choice(range(self.act_dim), p=act_prob) + return act + + def predict(self, obs): + obs = np.expand_dims(obs, axis=0) + act_prob = self.fluid_executor.run( + self.pred_program, + feed={'obs': obs.astype('float32')}, + fetch_list=[self.act_prob])[0] + act_prob = np.squeeze(act_prob, axis=0) + act = np.argmax(act_prob) + return act + + def learn(self, obs, act, reward): + act = np.expand_dims(act, axis=-1) + feed = { + 'obs': obs.astype('float32'), + 'act': act.astype('int64'), + 'reward': reward.astype('float32') + } + cost = self.fluid_executor.run( + self.train_program, feed=feed, fetch_list=[self.cost])[0] + return cost diff --git a/examples/QuickStart/cartpole_model.py b/examples/QuickStart/cartpole_model.py new file mode 100644 index 0000000..7b2fc65 --- /dev/null +++ b/examples/QuickStart/cartpole_model.py @@ -0,0 +1,31 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as fluid +import parl.layers as layers +from parl.framework.model_base import Model + + +class CartpoleModel(Model): + def __init__(self, act_dim): + act_dim = act_dim + hid1_size = act_dim * 10 + + self.fc1 = layers.fc(size=hid1_size, act='tanh') + self.fc2 = layers.fc(size=act_dim, act='softmax') + + def policy(self, obs): + out = self.fc1(obs) + out = self.fc2(out) + return out diff --git a/examples/QuickStart/train.py b/examples/QuickStart/train.py new file mode 100644 index 0000000..5fdf660 --- /dev/null +++ b/examples/QuickStart/train.py @@ -0,0 +1,100 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import gym +import numpy as np +from cartpole_agent import CartpoleAgent +from cartpole_model import CartpoleModel +from parl.algorithms import PolicyGradient +from parl.utils import logger + +OBS_DIM = 4 +ACT_DIM = 2 +GAMMA = 0.99 +LEARNING_RATE = 1e-3 + + +def run_train_episode(env, agent): + obs_list, action_list, reward_list = [], [], [] + obs = env.reset() + while True: + obs_list.append(obs) + action = agent.sample(obs) + action_list.append(action) + + obs, reward, done, info = env.step(action) + reward_list.append(reward) + + if done: + break + return obs_list, action_list, reward_list + + +def run_evaluate_episode(env, agent): + obs = env.reset() + all_reward = 0 + while True: + if args.eval_vis: + env.render() + action = agent.predict(obs) + obs, reward, done, info = env.step(action) + all_reward += reward + if done: + break + return all_reward + + +def calc_discount_norm_reward(reward_list): + discount_norm_reward = np.zeros_like(reward_list) + + discount_cumulative_reward = 0 + for i in reversed(range(0, len(reward_list))): + discount_cumulative_reward = ( + GAMMA * discount_cumulative_reward + reward_list[i]) + discount_norm_reward[i] = discount_cumulative_reward + discount_norm_reward = discount_norm_reward - np.mean(discount_norm_reward) + discount_norm_reward = discount_norm_reward / np.std(discount_norm_reward) + return discount_norm_reward + + +def main(): + env = gym.make("CartPole-v0") + model = CartpoleModel(act_dim=ACT_DIM) + alg = PolicyGradient(model, hyperparas={'lr': LEARNING_RATE}) + agent = CartpoleAgent(alg, obs_dim=OBS_DIM, act_dim=ACT_DIM) + + for i in range(500): + obs_list, action_list, reward_list = run_train_episode(env, agent) + logger.info("Episode {}, Reward Sum {}.".format(i, sum(reward_list))) + + batch_obs = np.array(obs_list) + batch_action = np.array(action_list) + batch_reward = calc_discount_norm_reward(reward_list) + + agent.learn(batch_obs, batch_action, batch_reward) + if i % 100 == 0: + all_reward = run_evaluate_episode(env, agent) + logger.info('Test reward: {}'.format(all_reward)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--eval_vis', + action='store_true', + help='if set, will visualize the game when evaluating') + args = parser.parse_args() + + main() diff --git a/parl/algorithms/__init__.py b/parl/algorithms/__init__.py index e665bc4..8dec5ec 100644 --- a/parl/algorithms/__init__.py +++ b/parl/algorithms/__init__.py @@ -11,4 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from parl.algorithms.dqn_algorithm import * + +from parl.algorithms.dqn import * +from parl.algorithms.policy_gradient import * diff --git a/parl/algorithms/dqn_algorithm.py b/parl/algorithms/dqn.py similarity index 80% rename from parl/algorithms/dqn_algorithm.py rename to parl/algorithms/dqn.py index 3088a81..42345fd 100644 --- a/parl/algorithms/dqn_algorithm.py +++ b/parl/algorithms/dqn.py @@ -17,8 +17,10 @@ from parl.framework.algorithm_base import Algorithm import parl.layers as layers import copy +__all__ = ['DQN'] -class DQNAlgorithm(Algorithm): + +class DQN(Algorithm): def __init__(self, model, hyperparas): Algorithm.__init__(self, model, hyperparas) self.model = model @@ -29,30 +31,32 @@ class DQNAlgorithm(Algorithm): self.lr = hyperparas['lr'] def define_predict(self, obs): + """ use value model self.model to predict the action value + """ return self.model.value(obs) def define_learn(self, obs, action, reward, next_obs, terminal): + """ update value model self.model with DQN algorithm + """ + pred_value = self.model.value(obs) - #fluid.layers.Print(pred_value, summarize=10, message='pred_value') next_pred_value = self.target_model.value(next_obs) - #fluid.layers.Print(next_pred_value, summarize=10, message='next_pred_value') best_v = layers.reduce_max(next_pred_value, dim=1) best_v.stop_gradient = True - #fluid.layers.Print(best_v, summarize=10, message='best_v') target = reward + ( 1.0 - layers.cast(terminal, dtype='float32')) * self.gamma * best_v - #fluid.layers.Print(target, summarize=10, message='target') action_onehot = layers.one_hot(action, self.action_dim) action_onehot = layers.cast(action_onehot, dtype='float32') pred_action_value = layers.reduce_sum( layers.elementwise_mul(action_onehot, pred_value), dim=1) - #fluid.layers.Print(pred_action_value, summarize=10, message='pred_action_value') cost = layers.square_error_cost(pred_action_value, target) cost = layers.reduce_mean(cost) - optimizer = fluid.optimizer.Adam(self.lr * 0.5, epsilon=1e-3) + optimizer = fluid.optimizer.Adam(self.lr, epsilon=1e-3) optimizer.minimize(cost) return cost def sync_target(self, gpu_id): + """ sync parameters of self.target_model with self.model + """ self.model.sync_params_to(self.target_model, gpu_id=gpu_id) diff --git a/parl/algorithms/policy_gradient.py b/parl/algorithms/policy_gradient.py new file mode 100644 index 0000000..68ca44f --- /dev/null +++ b/parl/algorithms/policy_gradient.py @@ -0,0 +1,42 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as fluid +from parl.framework.algorithm_base import Algorithm +import parl.layers as layers + +__all__ = ['PolicyGradient'] + + +class PolicyGradient(Algorithm): + def __init__(self, model, hyperparas): + Algorithm.__init__(self, model, hyperparas) + self.model = model + self.lr = hyperparas['lr'] + + def define_predict(self, obs): + """ use policy model self.model to predict the action probability + """ + return self.model.policy(obs) + + def define_learn(self, obs, action, reward): + """ update policy model self.model with policy gradient algorithm + """ + act_prob = self.model.policy(obs) + log_prob = layers.cross_entropy(act_prob, action) + cost = log_prob * reward + cost = layers.reduce_mean(cost) + optimizer = fluid.optimizer.Adam(self.lr) + optimizer.minimize(cost) + return cost -- GitLab