diff --git a/examples/DQN/atari_agent.py b/examples/DQN/atari_agent.py
index db8a2ebec6306295800480349316d04872b6386a..4309116b8ca9e3750749f508b2dd673a6757b417 100644
--- a/examples/DQN/atari_agent.py
+++ b/examples/DQN/atari_agent.py
@@ -17,20 +17,23 @@
 import paddle.fluid as fluid
 import parl
 from parl import layers
+from parl.utils.scheduler import PiecewiseScheduler, LinearDecayScheduler
+
 IMAGE_SIZE = (84, 84)
 CONTEXT_LEN = 4
 
 
 class AtariAgent(parl.Agent):
-    def __init__(self, algorithm, act_dim):
+    def __init__(self, algorithm, act_dim, start_lr, total_step):
         super(AtariAgent, self).__init__(algorithm)
-        assert isinstance(act_dim, int)
         self.act_dim = act_dim
         self.exploration = 1.1
         self.global_step = 0
         self.update_target_steps = 10000 // 4
 
+        self.lr_scheduler = LinearDecayScheduler(start_lr, total_step)
+
     def build_program(self):
         self.pred_program = fluid.Program()
         self.learn_program = fluid.Program()
@@ -53,8 +56,11 @@ class AtariAgent(parl.Agent):
                 name='next_obs',
                 shape=[CONTEXT_LEN, IMAGE_SIZE[0], IMAGE_SIZE[1]],
                 dtype='float32')
+            lr = layers.data(
+                name='lr', shape=[1], dtype='float32', append_batch_size=False)
             terminal = layers.data(name='terminal', shape=[], dtype='bool')
-            self.cost = self.alg.learn(obs, action, reward, next_obs, terminal)
+            self.cost = self.alg.learn(obs, action, reward, next_obs, terminal,
+                                       lr)
 
     def sample(self, obs):
         sample = np.random.random()
@@ -89,6 +95,8 @@ class AtariAgent(parl.Agent):
             self.alg.sync_target()
         self.global_step += 1
 
+        lr = self.lr_scheduler.step(step_num=obs.shape[0])
+
         act = np.expand_dims(act, -1)
         reward = np.clip(reward, -1, 1)
         feed = {
@@ -96,7 +104,8 @@ class AtariAgent(parl.Agent):
             'act': act.astype('int32'),
             'reward': reward,
             'next_obs': next_obs.astype('float32'),
-            'terminal': terminal
+            'terminal': terminal,
+            'lr': lr
         }
         cost = self.fluid_executor.run(
             self.learn_program, feed=feed, fetch_list=[self.cost])[0]
diff --git a/examples/DQN/atari_model.py b/examples/DQN/atari_model.py
index 00a3c0639a1169767d20106428d8f47ad27de8cd..a5836ead533718e60229910e330d760f814af48b 100644
--- a/examples/DQN/atari_model.py
+++ b/examples/DQN/atari_model.py
@@ -18,7 +18,7 @@ from parl import layers
 
 
 class AtariModel(parl.Model):
-    def __init__(self, act_dim):
+    def __init__(self, act_dim, algo='DQN'):
         self.act_dim = act_dim
 
         self.conv1 = layers.conv2d(
@@ -29,7 +29,15 @@
             num_filters=64, filter_size=4, stride=1, padding=1, act='relu')
         self.conv4 = layers.conv2d(
             num_filters=64, filter_size=3, stride=1, padding=1, act='relu')
-        self.fc1 = layers.fc(size=act_dim)
+
+        self.algo = algo
+        if algo == 'Dueling':
+            self.fc1_adv = layers.fc(size=512, act='relu')
+            self.fc2_adv = layers.fc(size=act_dim)
+            self.fc1_val = layers.fc(size=512, act='relu')
+            self.fc2_val = layers.fc(size=1)
+        else:
+            self.fc1 = layers.fc(size=act_dim)
 
     def value(self, obs):
         obs = obs / 255.0
@@ -44,5 +52,11 @@
             input=out, pool_size=2, pool_stride=2, pool_type='max')
         out = self.conv4(out)
         out = layers.flatten(out, axis=1)
-        out = self.fc1(out)
-        return out
+
+        if self.algo == 'Dueling':
+            As = self.fc2_adv(self.fc1_adv(out))
+            V = self.fc2_val(self.fc1_val(out))
+            Q = As + (V - layers.reduce_mean(As, dim=1, keep_dim=True))
+        else:
+            Q = self.fc1(out)
+        return Q
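The dueling head added above combines a scalar state-value stream V and a per-action advantage stream A into Q-values as Q = A + (V - mean(A)); subtracting the mean advantage keeps the V/A split identifiable, which is why `reduce_mean` runs over the action dimension with `keep_dim=True` for broadcasting. A minimal NumPy sketch of the same aggregation (illustrative only; the names `val`, `adv` and the [batch, act_dim] shapes are assumptions, not part of the patch):

    import numpy as np

    def dueling_q(val, adv):
        # Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)), mirroring AtariModel.value
        return adv + (val - adv.mean(axis=1, keepdims=True))

    val = np.array([[1.0], [2.0]])                        # V(s), shape [batch, 1]
    adv = np.array([[0.5, -0.5, 0.0], [1.0, 0.0, -1.0]])  # A(s, a), shape [batch, act_dim]
    q = dueling_q(val, adv)
    print(q.mean(axis=1))  # each row averages back to V(s): [1. 2.]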
diff --git a/examples/DQN/train.py b/examples/DQN/train.py
index 4343511ea0106bd494988bb8988024a31069728d..a350f785a19816c91c7cc230876ec35cf1e52f6a 100644
--- a/examples/DQN/train.py
+++ b/examples/DQN/train.py
@@ -20,10 +20,9 @@
 import os
 import parl
 from atari_agent import AtariAgent
 from atari_model import AtariModel
-from collections import deque
 from datetime import datetime
 from replay_memory import ReplayMemory, Experience
-from parl.utils import logger
+from parl.utils import tensorboard, logger
 from tqdm import tqdm
 from utils import get_player
@@ -34,7 +33,7 @@ CONTEXT_LEN = 4
 FRAME_SKIP = 4
 UPDATE_FREQ = 4
 GAMMA = 0.99
-LEARNING_RATE = 1e-3 * 0.5
+LEARNING_RATE = 3e-4
 
 
 def run_train_episode(env, agent, rpm):
@@ -67,7 +66,7 @@ def run_train_episode(env, agent, rpm):
     if all_cost:
         logger.info('[Train]total_reward: {}, mean_cost: {}'.format(
             total_reward, np.mean(all_cost)))
-    return total_reward, steps
+    return total_reward, steps, np.mean(all_cost)
 
 
 def run_evaluate_episode(env, agent):
@@ -93,27 +92,38 @@ def main():
     rpm = ReplayMemory(MEMORY_SIZE, IMAGE_SIZE, CONTEXT_LEN)
     act_dim = env.action_space.n
 
-    model = AtariModel(act_dim)
-    algorithm = parl.algorithms.DQN(
-        model, act_dim=act_dim, gamma=GAMMA, lr=LEARNING_RATE)
-    agent = AtariAgent(algorithm, act_dim=act_dim)
-
-    with tqdm(total=MEMORY_WARMUP_SIZE) as pbar:
+    model = AtariModel(act_dim, args.algo)
+    if args.algo == 'Double':
+        algorithm = parl.algorithms.DDQN(model, act_dim=act_dim, gamma=GAMMA)
+    elif args.algo in ['DQN', 'Dueling']:
+        algorithm = parl.algorithms.DQN(model, act_dim=act_dim, gamma=GAMMA)
+    agent = AtariAgent(
+        algorithm,
+        act_dim=act_dim,
+        start_lr=LEARNING_RATE,
+        total_step=args.train_total_steps)
+
+    with tqdm(
+            total=MEMORY_WARMUP_SIZE, desc='[Replay Memory Warm Up]') as pbar:
         while rpm.size() < MEMORY_WARMUP_SIZE:
-            total_reward, steps = run_train_episode(env, agent, rpm)
+            total_reward, steps, _ = run_train_episode(env, agent, rpm)
             pbar.update(steps)
 
     # train
     test_flag = 0
     pbar = tqdm(total=args.train_total_steps)
-    recent_100_reward = []
     total_steps = 0
     max_reward = None
     while total_steps < args.train_total_steps:
         # start epoch
-        total_reward, steps = run_train_episode(env, agent, rpm)
+        total_reward, steps, loss = run_train_episode(env, agent, rpm)
         total_steps += steps
         pbar.set_description('[train]exploration:{}'.format(agent.exploration))
+        tensorboard.add_scalar('dqn/score', total_reward, total_steps)
+        tensorboard.add_scalar('dqn/loss', loss,
+                               total_steps)  # mean of total loss
+        tensorboard.add_scalar('dqn/exploration', agent.exploration,
+                               total_steps)
         pbar.update(steps)
 
         if total_steps // args.test_every_steps >= test_flag:
@@ -127,6 +137,8 @@ def main():
             logger.info(
                 "eval_agent done, (steps, eval_reward): ({}, {})".format(
                     total_steps, np.mean(eval_rewards)))
+            eval_test = np.mean(eval_rewards)
+            tensorboard.add_scalar('dqn/eval', eval_test, total_steps)
 
     pbar.close()
@@ -137,10 +149,16 @@ if __name__ == '__main__':
         '--rom', help='path of the rom of the atari game', required=True)
     parser.add_argument(
         '--batch_size', type=int, default=64, help='batch size for training')
+    parser.add_argument(
+        '--algo',
+        default='DQN',
+        help=
+        'DQN/Double/Dueling, represent DQN, double DQN, and dueling DQN respectively',
+    )
     parser.add_argument(
         '--train_total_steps',
         type=int,
-        default=int(1e8),
+        default=int(1e7),
         help='maximum environmental steps of games')
     parser.add_argument(
         '--test_every_steps',
@@ -149,5 +167,4 @@ if __name__ == '__main__':
         help='the step interval between two consecutive evaluations')
 
     args = parser.parse_args()
-
     main()
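train.py now passes `start_lr=LEARNING_RATE` and `total_step=args.train_total_steps` to the agent, and the agent feeds a freshly scheduled learning rate into the algorithm on every `learn` call. The sketch below shows the behaviour the patch appears to rely on from `LinearDecayScheduler` (a linear decay toward zero, clamped once the step budget is exhausted); it is an illustrative stand-in under that assumption, not PARL's implementation:

    class LinearDecaySketch(object):
        def __init__(self, start_value, max_steps):
            self.start_value = start_value
            self.max_steps = float(max_steps)
            self.cur_steps = 0.0

        def step(self, step_num=1):
            # advance by step_num counted steps and return the decayed value
            self.cur_steps = min(self.cur_steps + step_num, self.max_steps)
            return self.start_value * (1.0 - self.cur_steps / self.max_steps)

    scheduler = LinearDecaySketch(3e-4, int(1e7))
    print(scheduler.step(step_num=64))  # slightly below 3e-4 after one batch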
diff --git a/parl/algorithms/fluid/__init__.py b/parl/algorithms/fluid/__init__.py
index b0afd592cdfd8ae03b7dad94c4af2596195cf98a..bcf8c17c938133005a0da6a298076f2d157614bc 100644
--- a/parl/algorithms/fluid/__init__.py
+++ b/parl/algorithms/fluid/__init__.py
@@ -15,6 +15,7 @@
 from parl.algorithms.fluid.a3c import *
 from parl.algorithms.fluid.ddpg import *
 from parl.algorithms.fluid.dqn import *
+from parl.algorithms.fluid.ddqn import *
 from parl.algorithms.fluid.policy_gradient import *
 from parl.algorithms.fluid.ppo import *
 from parl.algorithms.fluid.impala.impala import *
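The new DDQN algorithm below follows the double DQN update: the online (behavior) network selects the greedy next action while the target network evaluates it, which reduces the overestimation bias of the plain max-based DQN target. Per sample, the regression target it constructs is, in pseudo-code:

    target = reward + (1.0 - terminal) * gamma * Q_target(next_obs)[argmax(Q_online(next_obs))]

A NumPy sketch of the index arithmetic used to gather those target Q-values appears at the end of this section.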
diff --git a/parl/algorithms/fluid/ddqn.py b/parl/algorithms/fluid/ddqn.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ccd4aaafe78d6b698fb04711cdc6b7df48faac8
--- /dev/null
+++ b/parl/algorithms/fluid/ddqn.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+warnings.simplefilter('default')
+
+import copy
+import numpy as np
+import paddle.fluid as fluid
+from parl.core.fluid.algorithm import Algorithm
+from parl.core.fluid import layers
+
+
+class DDQN(Algorithm):
+    def __init__(
+            self,
+            model,
+            act_dim=None,
+            gamma=None,
+    ):
+        """ Double DQN algorithm
+
+        Args:
+            model (parl.Model): model defining forward network of Q function.
+            gamma (float): discounted factor for reward computation.
+        """
+        self.model = model
+        self.target_model = copy.deepcopy(model)
+
+        assert isinstance(act_dim, int)
+        assert isinstance(gamma, float)
+
+        self.act_dim = act_dim
+        self.gamma = gamma
+
+    def predict(self, obs):
+        return self.model.value(obs)
+
+    def learn(self, obs, action, reward, next_obs, terminal, learning_rate):
+        pred_value = self.model.value(obs)
+        action_onehot = layers.one_hot(action, self.act_dim)
+        action_onehot = layers.cast(action_onehot, dtype='float32')
+        pred_action_value = layers.reduce_sum(
+            layers.elementwise_mul(action_onehot, pred_value), dim=1)
+
+        # choose the greedy next action according to the behavior network
+        next_action_value = self.model.value(next_obs)
+        greedy_action = layers.argmax(next_action_value, axis=-1)
+
+        # calculate the target q value with target network
+        batch_size = layers.cast(layers.shape(greedy_action)[0], dtype='int')
+        range_tmp = layers.range(
+            start=0, end=batch_size, step=1, dtype='int64') * self.act_dim
+        a_indices = range_tmp + greedy_action
+        a_indices = layers.cast(a_indices, dtype='int32')
+        next_pred_value = self.target_model.value(next_obs)
+        next_pred_value = layers.reshape(
+            next_pred_value, shape=[
+                -1,
+            ])
+        max_v = layers.gather(next_pred_value, a_indices)
+        max_v = layers.reshape(
+            max_v, shape=[
+                -1,
+            ])
+        max_v.stop_gradient = True
+
+        target = reward + (
+            1.0 - layers.cast(terminal, dtype='float32')) * self.gamma * max_v
+        cost = layers.square_error_cost(pred_action_value, target)
+        cost = layers.reduce_mean(cost)
+        optimizer = fluid.optimizer.Adam(
+            learning_rate=learning_rate, epsilon=1e-3)
+        optimizer.minimize(cost)
+        return cost
+
+    def sync_target(self, gpu_id=None):
+        """ sync weights of self.model to self.target_model
+        """
+        if gpu_id is not None:
+            warnings.warn(
+                "the `gpu_id` argument of `sync_target` function in `parl.Algorithms.DDQN` is deprecated since version 1.2 and will be removed in version 1.3.",
+                DeprecationWarning,
+                stacklevel=2)
+        self.model.sync_weights_to(self.target_model)
diff --git a/parl/algorithms/fluid/dqn.py b/parl/algorithms/fluid/dqn.py
index 94bc92ec520aae9a25678542cc52bd8f3f7650e2..e6e97577d041f77b1899ce460582c29f5bf480a8 100644
--- a/parl/algorithms/fluid/dqn.py
+++ b/parl/algorithms/fluid/dqn.py
@@ -25,12 +25,7 @@ __all__ = ['DQN']
 
 
 class DQN(Algorithm):
-    def __init__(self,
-                 model,
-                 hyperparas=None,
-                 act_dim=None,
-                 gamma=None,
-                 lr=None):
+    def __init__(self, model, hyperparas=None, act_dim=None, gamma=None):
         """ DQN algorithm
 
         Args:
@@ -50,14 +45,11 @@
                 stacklevel=2)
             self.act_dim = hyperparas['action_dim']
             self.gamma = hyperparas['gamma']
-            self.lr = hyperparas['lr']
         else:
             assert isinstance(act_dim, int)
             assert isinstance(gamma, float)
-            assert isinstance(lr, float)
             self.act_dim = act_dim
             self.gamma = gamma
-            self.lr = lr
 
     @deprecated(
         deprecated_in='1.2', removed_in='1.3', replace_function='predict')
@@ -73,10 +65,12 @@
     @deprecated(
         deprecated_in='1.2', removed_in='1.3', replace_function='learn')
-    def define_learn(self, obs, action, reward, next_obs, terminal):
-        return self.learn(obs, action, reward, next_obs, terminal)
+    def define_learn(self, obs, action, reward, next_obs, terminal,
+                     learning_rate):
+        return self.learn(obs, action, reward, next_obs, terminal,
+                          learning_rate)
 
-    def learn(self, obs, action, reward, next_obs, terminal):
+    def learn(self, obs, action, reward, next_obs, terminal, learning_rate):
         """ update value model self.model with DQN algorithm
         """
@@ -93,7 +87,8 @@
             layers.elementwise_mul(action_onehot, pred_value), dim=1)
         cost = layers.square_error_cost(pred_action_value, target)
         cost = layers.reduce_mean(cost)
-        optimizer = fluid.optimizer.Adam(self.lr, epsilon=1e-3)
+        optimizer = fluid.optimizer.Adam(
+            learning_rate=learning_rate, epsilon=1e-3)
         optimizer.minimize(cost)
         return cost
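Because fluid's `layers.gather` indexes along the first axis, `DDQN.learn` above flattens the target network's [batch, act_dim] output and builds flat indices as row_offset * act_dim + greedy_action before gathering. A small NumPy sketch of that indexing trick (illustrative only; `q_target` and `greedy_action` are made-up example arrays, not values from the patch):

    import numpy as np

    act_dim = 4
    q_target = np.arange(12, dtype=np.float32).reshape(3, act_dim)  # [batch, act_dim]
    greedy_action = np.array([2, 0, 3])  # argmax taken from the *online* network

    flat = q_target.reshape(-1)                       # like layers.reshape(..., [-1])
    a_indices = np.arange(3) * act_dim + greedy_action
    max_v = flat[a_indices]                           # like layers.gather(flat, a_indices)

    # same result as direct row-wise indexing
    assert np.allclose(max_v, q_target[np.arange(3), greedy_action])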