diff --git a/examples/DQN/README.md b/examples/DQN/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..570ae9619ed754fbf146dfdf86248681a7ee5d46
--- /dev/null
+++ b/examples/DQN/README.md
@@ -0,0 +1,27 @@
+## Reproduce DQN with PARL
+Based on PARL, we reproduce the DQN deep reinforcement learning algorithm and reach the same level of performance as reported in the paper on classic Atari games.
+
++ DQN in
+[Human-level Control Through Deep Reinforcement Learning](http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html)
+
+### Atari games introduction
+Please see [here](https://gym.openai.com/envs/#atari) to learn more about Atari games.
+
+
+## How to use
+### Dependencies:
++ python2.7 or python3.5+
++ [PARL](https://github.com/PaddlePaddle/PARL)
++ [paddlepaddle>=1.0.0](https://github.com/PaddlePaddle/Paddle)
++ gym
++ tqdm
++ opencv-python
++ ale_python_interface
+
+
+### Start Training:
+```
+# To train a model for the Pong game with CUDA
+python train.py --rom ./rom_files/pong.bin --use_cuda
+```
+> To train on more games, you can download more rom files from [here](https://github.com/openai/atari-py/tree/master/atari_py/atari_roms).
diff --git a/examples/DQN/atari.py b/examples/DQN/atari.py
new file mode 100644
index 0000000000000000000000000000000000000000..35a6801e9a04af919ad20756b353711caa99ad33
--- /dev/null
+++ b/examples/DQN/atari.py
@@ -0,0 +1,168 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import cv2
+import gym
+import numpy as np
+import os
+import threading
+from atari_py import ALEInterface
+from gym import spaces
+from gym.envs.atari.atari_env import ACTION_MEANING
+
+__all__ = ['AtariPlayer']
+
+ROM_URL = "https://github.com/openai/atari-py/tree/master/atari_py/atari_roms"
+_ALE_LOCK = threading.Lock()
+"""
+The following AtariPlayer is copied or modified from tensorpack/tensorpack:
+    https://github.com/tensorpack/tensorpack/blob/master/examples/DeepQNetwork/atari.py
+"""
+
+
+class AtariPlayer(gym.Env):
+    """
+    A wrapper for the ALE emulator, with configurations to mimic DeepMind DQN settings.
+    Info:
+        score: the accumulated reward in the current game
+        gameOver: True when the current game is over
+    """
+
+    def __init__(self,
+                 rom_file,
+                 viz=0,
+                 frame_skip=4,
+                 nullop_start=30,
+                 live_lost_as_eoe=True,
+                 max_num_frames=0):
+        """
+        Args:
+            rom_file: path to the rom
+            frame_skip: skip every k frames and repeat the action
+            viz: visualization to be done.
+                Set to 0 to disable.
+                Set to a positive number to be the delay between frames to show.
+                Set to a string to be a directory to store frames.
+            nullop_start: start with a random number of null ops.
+            live_lost_as_eoe: consider loss of lives as end of episode. Useful for training.
+            max_num_frames: maximum number of frames per episode.
+        """
+        super(AtariPlayer, self).__init__()
+        assert os.path.isfile(rom_file), \
+            "rom {} not found. 
Please download at {}".format(rom_file, ROM_URL) + + try: + ALEInterface.setLoggerMode(ALEInterface.Logger.Error) + except AttributeError: + print("You're not using latest ALE") + + # avoid simulator bugs: https://github.com/mgbellemare/Arcade-Learning-Environment/issues/86 + with _ALE_LOCK: + self.ale = ALEInterface() + self.ale.setInt(b"random_seed", np.random.randint(0, 30000)) + self.ale.setInt(b"max_num_frames_per_episode", max_num_frames) + self.ale.setBool(b"showinfo", False) + + self.ale.setInt(b"frame_skip", 1) + self.ale.setBool(b'color_averaging', False) + # manual.pdf suggests otherwise. + self.ale.setFloat(b'repeat_action_probability', 0.0) + + # viz setup + if isinstance(viz, str): + assert os.path.isdir(viz), viz + self.ale.setString(b'record_screen_dir', viz) + viz = 0 + if isinstance(viz, int): + viz = float(viz) + self.viz = viz + if self.viz and isinstance(self.viz, float): + self.windowname = os.path.basename(rom_file) + cv2.startWindowThread() + cv2.namedWindow(self.windowname) + + self.ale.loadROM(rom_file.encode('utf-8')) + self.width, self.height = self.ale.getScreenDims() + self.actions = self.ale.getMinimalActionSet() + + self.live_lost_as_eoe = live_lost_as_eoe + self.frame_skip = frame_skip + self.nullop_start = nullop_start + + self.action_space = spaces.Discrete(len(self.actions)) + self.observation_space = spaces.Box( + low=0, high=255, shape=(self.height, self.width), dtype=np.uint8) + self._restart_episode() + + def get_action_meanings(self): + return [ACTION_MEANING[i] for i in self.actions] + + def _grab_raw_image(self): + """ + :returns: the current 3-channel image + """ + m = self.ale.getScreenRGB() + return m.reshape((self.height, self.width, 3)) + + def _current_state(self): + """ + returns: a gray-scale (h, w) uint8 image + """ + ret = self._grab_raw_image() + # avoid missing frame issue: max-pooled over the last screen + ret = np.maximum(ret, self.last_raw_screen) + if self.viz: + if isinstance(self.viz, float): + cv2.imshow(self.windowname, ret) + cv2.waitKey(int(self.viz * 1000)) + ret = ret.astype('float32') + # 0.299,0.587.0.114. same as rgb2y in torch/image + ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY) + return ret.astype('uint8') # to save some memory + + def _restart_episode(self): + with _ALE_LOCK: + self.ale.reset_game() + + # random null-ops start + n = np.random.randint(self.nullop_start) + self.last_raw_screen = self._grab_raw_image() + for k in range(n): + if k == n - 1: + self.last_raw_screen = self._grab_raw_image() + self.ale.act(0) + + def reset(self): + if self.ale.game_over(): + self._restart_episode() + return self._current_state() + + def step(self, act): + oldlives = self.ale.lives() + r = 0 + for k in range(self.frame_skip): + if k == self.frame_skip - 1: + self.last_raw_screen = self._grab_raw_image() + r += self.ale.act(self.actions[act]) + newlives = self.ale.lives() + if self.ale.game_over() or \ + (self.live_lost_as_eoe and newlives < oldlives): + break + + isOver = self.ale.game_over() + if self.live_lost_as_eoe: + isOver = isOver or newlives < oldlives + + info = {'ale.lives': newlives} + return self._current_state(), r, isOver, info diff --git a/examples/DQN/atari_agent.py b/examples/DQN/atari_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..bcdf3c0e4c42c11b80b00f006a041c1708364805 --- /dev/null +++ b/examples/DQN/atari_agent.py @@ -0,0 +1,103 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 
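AtariPlayer._current_state above de-flickers the screen by taking the element-wise maximum of the two most recent raw frames before converting to grayscale. Below is a minimal NumPy/OpenCV sketch of that transform, detached from ALE; the function name and the random test frames are illustrative only:

```python
import cv2
import numpy as np


def to_gray_state(last_screen, screen):
    """Illustrative re-implementation of the core of AtariPlayer._current_state."""
    # max over the last two screens, so sprites missing from a single frame are kept
    merged = np.maximum(screen, last_screen).astype('float32')
    # RGB -> luma with the usual 0.299/0.587/0.114 weights (same as rgb2y in torch/image)
    gray = cv2.cvtColor(merged, cv2.COLOR_RGB2GRAY)
    return gray.astype('uint8')  # back to uint8 to save replay-memory space


# toy usage with two random 210x160 RGB frames
frames = np.random.randint(0, 256, size=(2, 210, 160, 3), dtype=np.uint8)
print(to_gray_state(frames[0], frames[1]).shape)  # (210, 160)
```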
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle.fluid as fluid +import parl.layers as layers +from parl.framework.agent_base import Agent +from parl.utils import logger + +IMAGE_SIZE = (84, 84) +CONTEXT_LEN = 4 + + +class AtariAgent(Agent): + def __init__(self, algorithm, action_dim): + super(AtariAgent, self).__init__(algorithm) + + self.exploration = 1.1 + self.action_dim = action_dim + self.global_step = 0 + self.update_target_steps = 10000 // 4 + + def build_program(self): + self.pred_program = fluid.Program() + self.train_program = fluid.Program() + with fluid.program_guard(self.pred_program): + obs = layers.data( + name='obs', + shape=[CONTEXT_LEN, IMAGE_SIZE[0], IMAGE_SIZE[1]], + dtype='float32') + self.value = self.alg.define_predict(obs) + + with fluid.program_guard(self.train_program): + obs = layers.data( + name='obs', + shape=[CONTEXT_LEN, IMAGE_SIZE[0], IMAGE_SIZE[1]], + dtype='float32') + action = layers.data(name='act', shape=[1], dtype='int32') + reward = layers.data(name='reward', shape=[], dtype='float32') + next_obs = layers.data( + name='next_obs', + shape=[CONTEXT_LEN, IMAGE_SIZE[0], IMAGE_SIZE[1]], + dtype='float32') + terminal = layers.data(name='terminal', shape=[], dtype='bool') + self.cost = self.alg.define_learn(obs, action, reward, next_obs, + terminal) + + def sample(self, obs): + sample = np.random.random() + if sample < self.exploration: + act = np.random.randint(self.action_dim) + else: + if np.random.random() < 0.01: + act = np.random.randint(self.action_dim) + else: + obs = np.expand_dims(obs, axis=0) + pred_Q = self.fluid_executor.run( + self.pred_program, + feed={'obs': obs.astype('float32')}, + fetch_list=[self.value])[0] + pred_Q = np.squeeze(pred_Q, axis=0) + act = np.argmax(pred_Q) + self.exploration = max(0.1, self.exploration - 1e-6) + return act + + def predict(self, obs): + obs = np.expand_dims(obs, axis=0) + pred_Q = self.fluid_executor.run( + self.pred_program, + feed={'obs': obs.astype('float32')}, + fetch_list=[self.value])[0] + pred_Q = np.squeeze(pred_Q, axis=0) + act = np.argmax(pred_Q) + return act + + def learn(self, obs, act, reward, next_obs, terminal): + if self.global_step % self.update_target_steps == 0: + self.alg.sync_target(self.gpu_id) + self.global_step += 1 + + act = np.expand_dims(act, -1) + reward = np.clip(reward, -1, 1) + feed = { + 'obs': obs.astype('float32'), + 'act': act.astype('int32'), + 'reward': reward, + 'next_obs': next_obs.astype('float32'), + 'terminal': terminal + } + cost = self.fluid_executor.run( + self.train_program, feed=feed, fetch_list=[self.cost])[0] + return cost diff --git a/examples/DQN/atari_model.py b/examples/DQN/atari_model.py new file mode 100644 index 0000000000000000000000000000000000000000..0ea1de97a251dbd4651c794bfe3b6a2bf6b2ae6f --- /dev/null +++ b/examples/DQN/atari_model.py @@ -0,0 +1,51 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 
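AtariAgent.sample above is epsilon-greedy: exploration starts at 1.1, is decayed by 1e-6 per call down to a floor of 0.1, and even the greedy branch keeps a 1% chance of acting randomly. A standalone sketch of that selection rule (the function name is illustrative; real Q-values come from pred_program):

```python
import numpy as np


def sample_action(q_values, exploration):
    """Epsilon-greedy choice as in AtariAgent.sample (illustration only)."""
    act_dim = len(q_values)
    if np.random.random() < exploration:
        return np.random.randint(act_dim)  # explore
    if np.random.random() < 0.01:          # residual exploration in the greedy branch
        return np.random.randint(act_dim)
    return int(np.argmax(q_values))        # exploit


exploration = 1.1
act = sample_action(np.random.rand(4), exploration)
exploration = max(0.1, exploration - 1e-6)  # linear decay towards the 0.1 floor
```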
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as fluid +import parl.layers as layers +from parl.framework.model_base import Model +from parl.utils import logger + + +class AtariModel(Model): + def __init__(self, img_height, img_width, act_dim): + self.img_height = img_height + self.img_width = img_width + self.act_dim = act_dim + + self.conv1 = layers.conv2d( + num_filters=32, filter_size=5, stride=1, padding=2, act='relu') + self.conv2 = layers.conv2d( + num_filters=32, filter_size=5, stride=1, padding=2, act='relu') + self.conv3 = layers.conv2d( + num_filters=64, filter_size=4, stride=1, padding=1, act='relu') + self.conv4 = layers.conv2d( + num_filters=64, filter_size=3, stride=1, padding=1, act='relu') + self.fc1 = layers.fc(size=act_dim) + + def value(self, obs): + obs = obs / 255.0 + out = self.conv1(obs) + out = layers.pool2d( + input=out, pool_size=2, pool_stride=2, pool_type='max') + out = self.conv2(out) + out = layers.pool2d( + input=out, pool_size=2, pool_stride=2, pool_type='max') + out = self.conv3(out) + out = layers.pool2d( + input=out, pool_size=2, pool_stride=2, pool_type='max') + out = self.conv4(out) + out = layers.flatten(out, axis=1) + out = self.fc1(out) + return out diff --git a/examples/DQN/atari_wrapper.py b/examples/DQN/atari_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..4195075faab5eb480942a8afdbd393183fac1fca --- /dev/null +++ b/examples/DQN/atari_wrapper.py @@ -0,0 +1,115 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
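For reference, the convolution and pooling stack in AtariModel above maps the 84x84 input planes down to a 10x10x64 feature map before the final fully connected layer that emits one Q-value per action. A small illustrative script (not part of the model) that checks the arithmetic with the standard output-size formula:

```python
def out_size(size, kernel, stride=1, padding=0):
    """Output spatial size of a conv or pool layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 84                            # IMAGE_SIZE after resizing
size = out_size(size, 5, 1, 2) // 2  # conv1 keeps 84, 2x2 max-pool -> 42
size = out_size(size, 5, 1, 2) // 2  # conv2 keeps 42, pool -> 21
size = out_size(size, 4, 1, 1) // 2  # conv3 -> 20, pool -> 10
size = out_size(size, 3, 1, 1)       # conv4 -> 10, no pool afterwards
print(size, size * size * 64)        # 10 6400: inputs to the final fc layer
```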
+ +import gym +import numpy as np +from collections import deque +from gym import spaces + +_v0, _v1 = gym.__version__.split('.')[:2] +assert int(_v0) > 0 or int(_v1) >= 10, gym.__version__ +""" +The following wrappers are copied or modified from openai/baselines: +https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py +""" + + +class MapState(gym.ObservationWrapper): + def __init__(self, env, map_func): + gym.ObservationWrapper.__init__(self, env) + self._func = map_func + + def observation(self, obs): + return self._func(obs) + + +class FrameStack(gym.Wrapper): + def __init__(self, env, k): + """Buffer observations and stack across channels (last axis).""" + gym.Wrapper.__init__(self, env) + self.k = k + self.frames = deque([], maxlen=k) + shp = env.observation_space.shape + chan = 1 if len(shp) == 2 else shp[2] + self.observation_space = spaces.Box( + low=0, high=255, shape=(shp[0], shp[1], chan * k), dtype=np.uint8) + + def reset(self): + """Clear buffer and re-fill by duplicating the first observation.""" + ob = self.env.reset() + for _ in range(self.k - 1): + self.frames.append(np.zeros_like(ob)) + self.frames.append(ob) + return self.observation() + + def step(self, action): + ob, reward, done, info = self.env.step(action) + self.frames.append(ob) + return self.observation(), reward, done, info + + def observation(self): + assert len(self.frames) == self.k + return np.stack(self.frames, axis=0) + + +class _FireResetEnv(gym.Wrapper): + def __init__(self, env): + """Take action on reset for environments that are fixed until firing.""" + gym.Wrapper.__init__(self, env) + assert env.unwrapped.get_action_meanings()[1] == 'FIRE' + assert len(env.unwrapped.get_action_meanings()) >= 3 + + def reset(self): + self.env.reset() + obs, _, done, _ = self.env.step(1) + if done: + self.env.reset() + obs, _, done, _ = self.env.step(2) + if done: + self.env.reset() + return obs + + def step(self, action): + return self.env.step(action) + + +def FireResetEnv(env): + if isinstance(env, gym.Wrapper): + baseenv = env.unwrapped + else: + baseenv = env + if 'FIRE' in baseenv.get_action_meanings(): + return _FireResetEnv(env) + return env + + +class LimitLength(gym.Wrapper): + def __init__(self, env, k): + gym.Wrapper.__init__(self, env) + self.k = k + + def reset(self): + # This assumes that reset() will really reset the env. + # If the underlying env tries to be smart about reset + # (e.g. end-of-life), the assumption doesn't hold. + ob = self.env.reset() + self.cnt = 0 + return ob + + def step(self, action): + ob, r, done, info = self.env.step(action) + self.cnt += 1 + if self.cnt == self.k: + done = True + return ob, r, done, info diff --git a/examples/DQN/expreplay.py b/examples/DQN/expreplay.py new file mode 100644 index 0000000000000000000000000000000000000000..6e7e2ba25bed33ab538ca6d7d383f588af651975 --- /dev/null +++ b/examples/DQN/expreplay.py @@ -0,0 +1,111 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
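FrameStack above provides the evaluation-time context of recent frames, while during training the same role is played by ReplayMemory.recent_state in the file that follows; both pad with blank frames right after a reset. A toy NumPy illustration of the stacking, assuming the 84x84 grayscale frames used in this example:

```python
import numpy as np
from collections import deque

CONTEXT_LEN = 4
frames = deque(maxlen=CONTEXT_LEN)

first_obs = np.zeros((84, 84), dtype='uint8')  # stand-in for the frame returned by env.reset()
for _ in range(CONTEXT_LEN - 1):               # pad with blanks, as FrameStack.reset does
    frames.append(np.zeros_like(first_obs))
frames.append(first_obs)

context = np.stack(frames, axis=0)             # (4, 84, 84): the tensor the Q-network consumes
assert context.shape == (CONTEXT_LEN, 84, 84)
```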
+ +import numpy as np +import copy +from collections import deque, namedtuple + +Experience = namedtuple('Experience', ['state', 'action', 'reward', 'isOver']) + + +class ReplayMemory(object): + def __init__(self, max_size, state_shape, context_len): + self.max_size = int(max_size) + self.state_shape = state_shape + self.context_len = int(context_len) + + self.state = np.zeros((self.max_size, ) + state_shape, dtype='uint8') + self.action = np.zeros((self.max_size, ), dtype='int32') + self.reward = np.zeros((self.max_size, ), dtype='float32') + self.isOver = np.zeros((self.max_size, ), dtype='bool') + + self._curr_size = 0 + self._curr_pos = 0 + self._context = deque(maxlen=context_len - 1) + + def append(self, exp): + """append a new experience into replay memory + """ + if self._curr_size < self.max_size: + self._assign(self._curr_pos, exp) + self._curr_size += 1 + else: + self._assign(self._curr_pos, exp) + self._curr_pos = (self._curr_pos + 1) % self.max_size + if exp.isOver: + self._context.clear() + else: + self._context.append(exp) + + def recent_state(self): + """ maintain recent state for training""" + lst = list(self._context) + states = [np.zeros(self.state_shape, dtype='uint8')] * \ + (self._context.maxlen - len(lst)) + states.extend([k.state for k in lst]) + return states + + def sample(self, idx): + """ return state, action, reward, isOver, + note that some frames in state may be generated from last episode, + they should be removed from state + """ + state = np.zeros( + (self.context_len + 1, ) + self.state_shape, dtype=np.uint8) + state_idx = np.arange(idx, + idx + self.context_len + 1) % self._curr_size + + # confirm that no frame was generated from last episode + has_last_episode = False + for k in range(self.context_len - 2, -1, -1): + to_check_idx = state_idx[k] + if self.isOver[to_check_idx]: + has_last_episode = True + state_idx = state_idx[k + 1:] + state[k + 1:] = self.state[state_idx] + break + + if not has_last_episode: + state = self.state[state_idx] + + real_idx = (idx + self.context_len - 1) % self._curr_size + action = self.action[real_idx] + reward = self.reward[real_idx] + isOver = self.isOver[real_idx] + return state, reward, action, isOver + + def __len__(self): + return self._curr_size + + def _assign(self, pos, exp): + self.state[pos] = exp.state + self.reward[pos] = exp.reward + self.action[pos] = exp.action + self.isOver[pos] = exp.isOver + + def sample_batch(self, batch_size): + """sample a batch from replay memory for training + """ + batch_idx = np.random.randint( + self._curr_size - self.context_len - 1, size=batch_size) + batch_idx = (self._curr_pos + batch_idx) % self._curr_size + batch_exp = [self.sample(i) for i in batch_idx] + return self._process_batch(batch_exp) + + def _process_batch(self, batch_exp): + state = np.asarray([e[0] for e in batch_exp], dtype='uint8') + reward = np.asarray([e[1] for e in batch_exp], dtype='float32') + action = np.asarray([e[2] for e in batch_exp], dtype='int8') + isOver = np.asarray([e[3] for e in batch_exp], dtype='bool') + return [state, action, reward, isOver] diff --git a/examples/DQN/rom_files/breakout.bin b/examples/DQN/rom_files/breakout.bin new file mode 100644 index 0000000000000000000000000000000000000000..abab5a8c0a1890461a11b78d4265f1b794327793 Binary files /dev/null and b/examples/DQN/rom_files/breakout.bin differ diff --git a/examples/DQN/rom_files/pong.bin b/examples/DQN/rom_files/pong.bin new file mode 100644 index 0000000000000000000000000000000000000000..14a5bdfc72548613c059938bdf712efdbb5d3806 
Binary files /dev/null and b/examples/DQN/rom_files/pong.bin differ diff --git a/examples/DQN/train.py b/examples/DQN/train.py new file mode 100644 index 0000000000000000000000000000000000000000..3d2d8d48ccd37be367d346f3acd665eacaae5d37 --- /dev/null +++ b/examples/DQN/train.py @@ -0,0 +1,164 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import cv2 +import gym +import paddle.fluid as fluid +import numpy as np +import os +from atari import AtariPlayer +from atari_agent import AtariAgent +from atari_model import AtariModel +from atari_wrapper import FrameStack, MapState, FireResetEnv, LimitLength +from collections import deque +from datetime import datetime +from expreplay import ReplayMemory, Experience +from parl.algorithms import DQNAlgorithm +from parl.utils import logger +from tqdm import tqdm + +MEMORY_SIZE = 1e6 +MEMORY_WARMUP_SIZE = MEMORY_SIZE // 20 +IMAGE_SIZE = (84, 84) +CONTEXT_LEN = 4 +ACTION_REPEAT = 4 # aka FRAME_SKIP +UPDATE_FREQ = 4 +GAMMA = 0.99 +LEARNING_RATE = 1e-3 + + +def run_train_episode(agent, env, exp): + total_reward = 0 + all_cost = [] + state = env.reset() + step = 0 + while True: + step += 1 + context = exp.recent_state() + context.append(state) + context = np.stack(context, axis=0) + action = agent.sample(context) + next_state, reward, isOver, _ = env.step(action) + exp.append(Experience(state, action, reward, isOver)) + # start training + if len(exp) > MEMORY_WARMUP_SIZE: + if step % UPDATE_FREQ == 0: + batch_all_state, batch_action, batch_reward, batch_isOver = exp.sample_batch( + args.batch_size) + batch_state = batch_all_state[:, :CONTEXT_LEN, :, :] + batch_next_state = batch_all_state[:, 1:, :, :] + cost = agent.learn(batch_state, batch_action, batch_reward, + batch_next_state, batch_isOver) + all_cost.append(float(cost)) + total_reward += reward + state = next_state + if isOver: + break + logger.info('[Train]total_reward: {}, mean_cost: {}'.format( + total_reward, np.mean(all_cost))) + return total_reward, step + + +def get_player(rom, viz=False, train=False): + env = AtariPlayer( + rom, + frame_skip=ACTION_REPEAT, + viz=viz, + live_lost_as_eoe=train, + max_num_frames=60000) + env = FireResetEnv(env) + env = MapState(env, lambda im: cv2.resize(im, IMAGE_SIZE)) + if not train: + # in training, context is taken care of in expreplay buffer + env = FrameStack(env, CONTEXT_LEN) + return env + + +def eval_agent(agent, env): + episode_reward = [] + for _ in tqdm(range(30), desc='eval agent'): + state = env.reset() + total_reward = 0 + step = 0 + while True: + step += 1 + action = agent.predict(state) + state, reward, isOver, info = env.step(action) + total_reward += reward + if isOver: + break + episode_reward.append(total_reward) + eval_reward = np.mean(episode_reward) + return eval_reward + + +def train_agent(): + env = get_player(args.rom, train=True) + test_env = get_player(args.rom) + exp = ReplayMemory(MEMORY_SIZE, IMAGE_SIZE, CONTEXT_LEN) + action_dim = 
env.action_space.n + + hyperparas = { + 'action_dim': action_dim, + 'lr': LEARNING_RATE, + 'gamma': GAMMA + } + model = AtariModel(IMAGE_SIZE[0], IMAGE_SIZE[1], action_dim) + algorithm = DQNAlgorithm(model, hyperparas) + agent = AtariAgent(algorithm, action_dim) + + with tqdm(total=MEMORY_WARMUP_SIZE) as pbar: + while len(exp) < MEMORY_WARMUP_SIZE: + total_reward, step = run_train_episode(agent, env, exp) + pbar.update(step) + + # train + test_flag = 0 + pbar = tqdm(total=1e8) + recent_100_reward = [] + total_step = 0 + max_reward = None + while True: + # start epoch + total_reward, step = run_train_episode(agent, env, exp) + total_step += step + pbar.set_description('[train]exploration:{}'.format(agent.exploration)) + pbar.update(step) + + if total_step // args.test_every_steps == test_flag: + pbar.write("testing") + eval_reward = eval_agent(agent, test_env) + test_flag += 1 + logger.info( + "eval_agent done, (steps, eval_reward): ({}, {})".format( + total_step, eval_reward)) + + pbar.close() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--rom', help='atari rom', required=True) + parser.add_argument( + '--use_cuda', action='store_true', help='if set, use cuda') + parser.add_argument( + '--batch_size', type=int, default=64, help='batch size for training') + parser.add_argument( + '--test_every_steps', + type=int, + default=100000, + help='every steps number to run test') + args = parser.parse_args() + train_agent() diff --git a/parl/algorithms/__init__.py b/parl/algorithms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e665bc49297fd637fe53df9b6222ab991b0b0576 --- /dev/null +++ b/parl/algorithms/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from parl.algorithms.dqn_algorithm import * diff --git a/parl/algorithms/dqn_algorithm.py b/parl/algorithms/dqn_algorithm.py new file mode 100644 index 0000000000000000000000000000000000000000..3088a8128241c26bedc35daa782dfaba6d9ba50f --- /dev/null +++ b/parl/algorithms/dqn_algorithm.py @@ -0,0 +1,58 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
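DQNAlgorithm.define_learn below assembles the standard DQN loss: a Bellman target built from the target network's maximum Q-value of the next observation, masked by the terminal flag, and a mean squared error on the Q-value of the action actually taken. The same computation written out in NumPy as a readability sketch (not the fluid graph):

```python
import numpy as np


def dqn_cost(pred_value, next_pred_value, action, reward, terminal, gamma=0.99):
    """NumPy version of the loss assembled in DQNAlgorithm.define_learn.

    pred_value:       (batch, act_dim) Q(s, a) from the online model
    next_pred_value:  (batch, act_dim) Q(s', a) from the target model
    action:           (batch,) int actions taken
    reward, terminal: (batch,) clipped rewards and done flags
    """
    best_v = next_pred_value.max(axis=1)                   # reduce_max over actions
    target = reward + (1.0 - terminal.astype('float32')) * gamma * best_v
    act_dim = pred_value.shape[1]
    onehot = np.eye(act_dim, dtype='float32')[action]      # one_hot(action)
    pred_action_value = (onehot * pred_value).sum(axis=1)
    return np.mean(np.square(pred_action_value - target))  # square_error_cost + reduce_mean
```

In the fluid version, setting best_v.stop_gradient = True is what treats the target as a constant, so gradients only flow through the online network.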
+ +import paddle.fluid as fluid +from parl.framework.algorithm_base import Algorithm +import parl.layers as layers +import copy + + +class DQNAlgorithm(Algorithm): + def __init__(self, model, hyperparas): + Algorithm.__init__(self, model, hyperparas) + self.model = model + self.target_model = copy.deepcopy(model) + # fetch hyper parameters + self.action_dim = hyperparas['action_dim'] + self.gamma = hyperparas['gamma'] + self.lr = hyperparas['lr'] + + def define_predict(self, obs): + return self.model.value(obs) + + def define_learn(self, obs, action, reward, next_obs, terminal): + pred_value = self.model.value(obs) + #fluid.layers.Print(pred_value, summarize=10, message='pred_value') + next_pred_value = self.target_model.value(next_obs) + #fluid.layers.Print(next_pred_value, summarize=10, message='next_pred_value') + best_v = layers.reduce_max(next_pred_value, dim=1) + best_v.stop_gradient = True + #fluid.layers.Print(best_v, summarize=10, message='best_v') + target = reward + ( + 1.0 - layers.cast(terminal, dtype='float32')) * self.gamma * best_v + + #fluid.layers.Print(target, summarize=10, message='target') + action_onehot = layers.one_hot(action, self.action_dim) + action_onehot = layers.cast(action_onehot, dtype='float32') + pred_action_value = layers.reduce_sum( + layers.elementwise_mul(action_onehot, pred_value), dim=1) + #fluid.layers.Print(pred_action_value, summarize=10, message='pred_action_value') + cost = layers.square_error_cost(pred_action_value, target) + cost = layers.reduce_mean(cost) + optimizer = fluid.optimizer.Adam(self.lr * 0.5, epsilon=1e-3) + optimizer.minimize(cost) + return cost + + def sync_target(self, gpu_id): + self.model.sync_params_to(self.target_model, gpu_id=gpu_id) diff --git a/parl/framework/agent_base.py b/parl/framework/agent_base.py index 3c3232792840837f17dc68248f9aede1b6b50623..6fbe81d6c744ae2cfd00b98f8d5372bac8e4e6de 100644 --- a/parl/framework/agent_base.py +++ b/parl/framework/agent_base.py @@ -16,6 +16,7 @@ import paddle.fluid as fluid import parl.layers as layers from parl.framework.algorithm_base import Algorithm from parl.framework.model_base import Model +from parl.utils import get_gpu_count __all__ = ['Agent'] @@ -31,10 +32,23 @@ class Agent(object): c. define a Agent with the algorithm """ - def __init__(self, algorithm): + def __init__(self, algorithm, gpu_id=None): + """ build program and run initialization for default_startup_program + + Created object: + self.alg: parl.framework.Algorithm + self.gpu_id: int + self.fluid_executor: fluid.Executor + """ assert isinstance(algorithm, Algorithm) self.alg = algorithm + self.build_program() + + if gpu_id is None: + gpu_id = 0 if get_gpu_count() > 0 else -1 + self.gpu_id = gpu_id + place = fluid.CUDAPlace(gpu_id) if gpu_id >= 0 else fluid.CPUPlace() self.fluid_executor = fluid.Executor(place) self.fluid_executor.run(fluid.default_startup_program()) diff --git a/parl/framework/tests/agent_base_test.py b/parl/framework/tests/agent_base_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f00485aeaf80308c57ac3e8a894a609b848467cc --- /dev/null +++ b/parl/framework/tests/agent_base_test.py @@ -0,0 +1,83 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import parl.layers as layers +import unittest +from paddle import fluid +from parl.framework.agent_base import Agent +from parl.framework.algorithm_base import Algorithm +from parl.framework.model_base import Model +from parl.utils import gputils + + +class TestModel(Model): + def __init__(self): + self.fc1 = layers.fc(size=256) + self.fc2 = layers.fc(size=128) + + def policy(self, obs): + out = self.fc1(obs) + out = self.fc2(out) + return out + + +class TestAlgorithm(Algorithm): + def __init__(self, model, hyperparas=None): + super(TestAlgorithm, self).__init__(model, hyperparas) + + def define_predict(self, obs): + return self.model.policy(obs) + + +class TestAgent(Agent): + def __init__(self, algorithm, gpu_id=None): + super(TestAgent, self).__init__(algorithm, gpu_id) + + def build_program(self): + self.predict_program = fluid.Program() + with fluid.program_guard(self.predict_program): + obs = layers.data(name='obs', shape=[10], dtype='float32') + output = self.alg.define_predict(obs) + self.predict_output = [output] + + def predict(self, obs): + output_np = self.fluid_executor.run( + self.predict_program, + feed={'obs': obs}, + fetch_list=self.predict_output)[0] + return output_np + + +class AgentBaseTest(unittest.TestCase): + def setUp(self): + self.model = TestModel() + self.algorithm = TestAlgorithm(self.model) + + def test_agent_with_gpu(self): + if gputils.get_gpu_count() > 0: + agent = TestAgent(self.algorithm, gpu_id=0) + obs = np.random.random([3, 10]).astype('float32') + output_np = agent.predict(obs) + self.assertIsNotNone(output_np) + + def test_agent_with_cpu(self): + agent = TestAgent(self.algorithm, gpu_id=0) + obs = np.random.random([3, 10]).astype('float32') + output_np = agent.predict(obs) + self.assertIsNotNone(output_np) + + +if __name__ == '__main__': + unittest.main() diff --git a/parl/utils/gputils.py b/parl/utils/gputils.py index 60b6b84860ea79b229a3220f6ca413acbcdb8d0c..e0056b5e08209d9009db65de975ddf6bccc42378 100644 --- a/parl/utils/gputils.py +++ b/parl/utils/gputils.py @@ -31,15 +31,20 @@ def get_gpu_count(): env_cuda_devices = os.environ.get('CUDA_VISIBLE_DEVICES', None) if env_cuda_devices is not None: assert isinstance(env_cuda_devices, str) - gpu_count = len(env_cuda_devices.split(',')) - logger.info( - 'CUDA_VISIBLE_DEVICES found gpu count: {}'.format(gpu_count)) + try: + gpu_count = len( + [x for x in env_cuda_devices.split(',') if int(x) >= 0]) + logger.info( + 'CUDA_VISIBLE_DEVICES found gpu count: {}'.format(gpu_count)) + except Exception as e: + logger.error(e.message) + gpu_count = 0 else: try: gpu_count = str(subprocess.check_output(["nvidia-smi", "-L"])).count('UUID') logger.info('nvidia-smi -L found gpu count: {}'.format(gpu_count)) except Exception as e: - logger.warn(e.message) + logger.error(e.message) gpu_count = 0 return gpu_count
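The gputils change above makes get_gpu_count tolerant of CUDA_VISIBLE_DEVICES values such as "-1" or an empty string. A standalone illustration of the new parsing rule (the sample values are made up):

```python
def count_visible_gpus(env_cuda_devices):
    """Mirror the patched parsing in parl.utils.gputils.get_gpu_count."""
    try:
        return len([x for x in env_cuda_devices.split(',') if int(x) >= 0])
    except Exception:
        return 0  # unparsable value -> behave as if no GPU is visible


print(count_visible_gpus('0,1'))  # 2
print(count_visible_gpus('-1'))   # 0 (GPUs explicitly hidden)
print(count_visible_gpus(''))     # 0 (int('') raises, falls into the except branch)
```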