提交 4a4366a5 编写于 作者: H Hongsheng Zeng 提交者: Bo Zhou

DQN example (#33)

* add DQN example, add Agent unittest

* refine readme

* refine  code

* simplify code
上级 5be4ca00
## Reproduce DQN with PARL
Based on PARL, the DQN model of deep reinforcement learning is reproduced, and the same level of indicators of the paper is reproduced in the classic Atari game.
+ DQN in
[Human-level Control Through Deep Reinforcement Learning](http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html)
### Atari games introduction
Please see [here](https://gym.openai.com/envs/#atari) to know more about Atari game.
## How to use
### Dependencies:
+ python2.7 or python3.5+
+ [PARL](https://github.com/PaddlePaddle/PARL)
+ [paddlepaddle>=1.0.0](https://github.com/PaddlePaddle/Paddle)
+ gym
+ tqdm
+ opencv-python
+ ale_python_interface
### Start Training:
```
# To train a model for Pong game with CUDA
python train.py --rom ./rom_files/pong.bin --use_cuda
```
> To train more games, you can install more rom files from [here](https://github.com/openai/atari-py/tree/master/atari_py/atari_roms).
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
import gym
import numpy as np
import os
import threading
from atari_py import ALEInterface
from gym import spaces
from gym.envs.atari.atari_env import ACTION_MEANING
__all__ = ['AtariPlayer']
ROM_URL = "https://github.com/openai/atari-py/tree/master/atari_py/atari_roms"
_ALE_LOCK = threading.Lock()
"""
The following AtariPlayer are copied or modified from tensorpack/tensorpack:
https://github.com/tensorpack/tensorpack/blob/master/examples/DeepQNetwork/atari.py
"""
class AtariPlayer(gym.Env):
"""
A wrapper for ALE emulator, with configurations to mimic DeepMind DQN settings.
Info:
score: the accumulated reward in the current game
gameOver: True when the current game is Over
"""
def __init__(self,
rom_file,
viz=0,
frame_skip=4,
nullop_start=30,
live_lost_as_eoe=True,
max_num_frames=0):
"""
Args:
rom_file: path to the rom
frame_skip: skip every k frames and repeat the action
viz: visualization to be done.
Set to 0 to disable.
Set to a positive number to be the delay between frames to show.
Set to a string to be a directory to store frames.
nullop_start: start with random number of null ops.
live_losts_as_eoe: consider lost of lives as end of episode. Useful for training.
max_num_frames: maximum number of frames per episode.
"""
super(AtariPlayer, self).__init__()
assert os.path.isfile(rom_file), \
"rom {} not found. Please download at {}".format(rom_file, ROM_URL)
try:
ALEInterface.setLoggerMode(ALEInterface.Logger.Error)
except AttributeError:
print("You're not using latest ALE")
# avoid simulator bugs: https://github.com/mgbellemare/Arcade-Learning-Environment/issues/86
with _ALE_LOCK:
self.ale = ALEInterface()
self.ale.setInt(b"random_seed", np.random.randint(0, 30000))
self.ale.setInt(b"max_num_frames_per_episode", max_num_frames)
self.ale.setBool(b"showinfo", False)
self.ale.setInt(b"frame_skip", 1)
self.ale.setBool(b'color_averaging', False)
# manual.pdf suggests otherwise.
self.ale.setFloat(b'repeat_action_probability', 0.0)
# viz setup
if isinstance(viz, str):
assert os.path.isdir(viz), viz
self.ale.setString(b'record_screen_dir', viz)
viz = 0
if isinstance(viz, int):
viz = float(viz)
self.viz = viz
if self.viz and isinstance(self.viz, float):
self.windowname = os.path.basename(rom_file)
cv2.startWindowThread()
cv2.namedWindow(self.windowname)
self.ale.loadROM(rom_file.encode('utf-8'))
self.width, self.height = self.ale.getScreenDims()
self.actions = self.ale.getMinimalActionSet()
self.live_lost_as_eoe = live_lost_as_eoe
self.frame_skip = frame_skip
self.nullop_start = nullop_start
self.action_space = spaces.Discrete(len(self.actions))
self.observation_space = spaces.Box(
low=0, high=255, shape=(self.height, self.width), dtype=np.uint8)
self._restart_episode()
def get_action_meanings(self):
return [ACTION_MEANING[i] for i in self.actions]
def _grab_raw_image(self):
"""
:returns: the current 3-channel image
"""
m = self.ale.getScreenRGB()
return m.reshape((self.height, self.width, 3))
def _current_state(self):
"""
returns: a gray-scale (h, w) uint8 image
"""
ret = self._grab_raw_image()
# avoid missing frame issue: max-pooled over the last screen
ret = np.maximum(ret, self.last_raw_screen)
if self.viz:
if isinstance(self.viz, float):
cv2.imshow(self.windowname, ret)
cv2.waitKey(int(self.viz * 1000))
ret = ret.astype('float32')
# 0.299,0.587.0.114. same as rgb2y in torch/image
ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
return ret.astype('uint8') # to save some memory
def _restart_episode(self):
with _ALE_LOCK:
self.ale.reset_game()
# random null-ops start
n = np.random.randint(self.nullop_start)
self.last_raw_screen = self._grab_raw_image()
for k in range(n):
if k == n - 1:
self.last_raw_screen = self._grab_raw_image()
self.ale.act(0)
def reset(self):
if self.ale.game_over():
self._restart_episode()
return self._current_state()
def step(self, act):
oldlives = self.ale.lives()
r = 0
for k in range(self.frame_skip):
if k == self.frame_skip - 1:
self.last_raw_screen = self._grab_raw_image()
r += self.ale.act(self.actions[act])
newlives = self.ale.lives()
if self.ale.game_over() or \
(self.live_lost_as_eoe and newlives < oldlives):
break
isOver = self.ale.game_over()
if self.live_lost_as_eoe:
isOver = isOver or newlives < oldlives
info = {'ale.lives': newlives}
return self._current_state(), r, isOver, info
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle.fluid as fluid
import parl.layers as layers
from parl.framework.agent_base import Agent
from parl.utils import logger
IMAGE_SIZE = (84, 84)
CONTEXT_LEN = 4
class AtariAgent(Agent):
def __init__(self, algorithm, action_dim):
super(AtariAgent, self).__init__(algorithm)
self.exploration = 1.1
self.action_dim = action_dim
self.global_step = 0
self.update_target_steps = 10000 // 4
def build_program(self):
self.pred_program = fluid.Program()
self.train_program = fluid.Program()
with fluid.program_guard(self.pred_program):
obs = layers.data(
name='obs',
shape=[CONTEXT_LEN, IMAGE_SIZE[0], IMAGE_SIZE[1]],
dtype='float32')
self.value = self.alg.define_predict(obs)
with fluid.program_guard(self.train_program):
obs = layers.data(
name='obs',
shape=[CONTEXT_LEN, IMAGE_SIZE[0], IMAGE_SIZE[1]],
dtype='float32')
action = layers.data(name='act', shape=[1], dtype='int32')
reward = layers.data(name='reward', shape=[], dtype='float32')
next_obs = layers.data(
name='next_obs',
shape=[CONTEXT_LEN, IMAGE_SIZE[0], IMAGE_SIZE[1]],
dtype='float32')
terminal = layers.data(name='terminal', shape=[], dtype='bool')
self.cost = self.alg.define_learn(obs, action, reward, next_obs,
terminal)
def sample(self, obs):
sample = np.random.random()
if sample < self.exploration:
act = np.random.randint(self.action_dim)
else:
if np.random.random() < 0.01:
act = np.random.randint(self.action_dim)
else:
obs = np.expand_dims(obs, axis=0)
pred_Q = self.fluid_executor.run(
self.pred_program,
feed={'obs': obs.astype('float32')},
fetch_list=[self.value])[0]
pred_Q = np.squeeze(pred_Q, axis=0)
act = np.argmax(pred_Q)
self.exploration = max(0.1, self.exploration - 1e-6)
return act
def predict(self, obs):
obs = np.expand_dims(obs, axis=0)
pred_Q = self.fluid_executor.run(
self.pred_program,
feed={'obs': obs.astype('float32')},
fetch_list=[self.value])[0]
pred_Q = np.squeeze(pred_Q, axis=0)
act = np.argmax(pred_Q)
return act
def learn(self, obs, act, reward, next_obs, terminal):
if self.global_step % self.update_target_steps == 0:
self.alg.sync_target(self.gpu_id)
self.global_step += 1
act = np.expand_dims(act, -1)
reward = np.clip(reward, -1, 1)
feed = {
'obs': obs.astype('float32'),
'act': act.astype('int32'),
'reward': reward,
'next_obs': next_obs.astype('float32'),
'terminal': terminal
}
cost = self.fluid_executor.run(
self.train_program, feed=feed, fetch_list=[self.cost])[0]
return cost
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import parl.layers as layers
from parl.framework.model_base import Model
from parl.utils import logger
class AtariModel(Model):
def __init__(self, img_height, img_width, act_dim):
self.img_height = img_height
self.img_width = img_width
self.act_dim = act_dim
self.conv1 = layers.conv2d(
num_filters=32, filter_size=5, stride=1, padding=2, act='relu')
self.conv2 = layers.conv2d(
num_filters=32, filter_size=5, stride=1, padding=2, act='relu')
self.conv3 = layers.conv2d(
num_filters=64, filter_size=4, stride=1, padding=1, act='relu')
self.conv4 = layers.conv2d(
num_filters=64, filter_size=3, stride=1, padding=1, act='relu')
self.fc1 = layers.fc(size=act_dim)
def value(self, obs):
obs = obs / 255.0
out = self.conv1(obs)
out = layers.pool2d(
input=out, pool_size=2, pool_stride=2, pool_type='max')
out = self.conv2(out)
out = layers.pool2d(
input=out, pool_size=2, pool_stride=2, pool_type='max')
out = self.conv3(out)
out = layers.pool2d(
input=out, pool_size=2, pool_stride=2, pool_type='max')
out = self.conv4(out)
out = layers.flatten(out, axis=1)
out = self.fc1(out)
return out
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gym
import numpy as np
from collections import deque
from gym import spaces
_v0, _v1 = gym.__version__.split('.')[:2]
assert int(_v0) > 0 or int(_v1) >= 10, gym.__version__
"""
The following wrappers are copied or modified from openai/baselines:
https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
"""
class MapState(gym.ObservationWrapper):
def __init__(self, env, map_func):
gym.ObservationWrapper.__init__(self, env)
self._func = map_func
def observation(self, obs):
return self._func(obs)
class FrameStack(gym.Wrapper):
def __init__(self, env, k):
"""Buffer observations and stack across channels (last axis)."""
gym.Wrapper.__init__(self, env)
self.k = k
self.frames = deque([], maxlen=k)
shp = env.observation_space.shape
chan = 1 if len(shp) == 2 else shp[2]
self.observation_space = spaces.Box(
low=0, high=255, shape=(shp[0], shp[1], chan * k), dtype=np.uint8)
def reset(self):
"""Clear buffer and re-fill by duplicating the first observation."""
ob = self.env.reset()
for _ in range(self.k - 1):
self.frames.append(np.zeros_like(ob))
self.frames.append(ob)
return self.observation()
def step(self, action):
ob, reward, done, info = self.env.step(action)
self.frames.append(ob)
return self.observation(), reward, done, info
def observation(self):
assert len(self.frames) == self.k
return np.stack(self.frames, axis=0)
class _FireResetEnv(gym.Wrapper):
def __init__(self, env):
"""Take action on reset for environments that are fixed until firing."""
gym.Wrapper.__init__(self, env)
assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
assert len(env.unwrapped.get_action_meanings()) >= 3
def reset(self):
self.env.reset()
obs, _, done, _ = self.env.step(1)
if done:
self.env.reset()
obs, _, done, _ = self.env.step(2)
if done:
self.env.reset()
return obs
def step(self, action):
return self.env.step(action)
def FireResetEnv(env):
if isinstance(env, gym.Wrapper):
baseenv = env.unwrapped
else:
baseenv = env
if 'FIRE' in baseenv.get_action_meanings():
return _FireResetEnv(env)
return env
class LimitLength(gym.Wrapper):
def __init__(self, env, k):
gym.Wrapper.__init__(self, env)
self.k = k
def reset(self):
# This assumes that reset() will really reset the env.
# If the underlying env tries to be smart about reset
# (e.g. end-of-life), the assumption doesn't hold.
ob = self.env.reset()
self.cnt = 0
return ob
def step(self, action):
ob, r, done, info = self.env.step(action)
self.cnt += 1
if self.cnt == self.k:
done = True
return ob, r, done, info
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import copy
from collections import deque, namedtuple
Experience = namedtuple('Experience', ['state', 'action', 'reward', 'isOver'])
class ReplayMemory(object):
def __init__(self, max_size, state_shape, context_len):
self.max_size = int(max_size)
self.state_shape = state_shape
self.context_len = int(context_len)
self.state = np.zeros((self.max_size, ) + state_shape, dtype='uint8')
self.action = np.zeros((self.max_size, ), dtype='int32')
self.reward = np.zeros((self.max_size, ), dtype='float32')
self.isOver = np.zeros((self.max_size, ), dtype='bool')
self._curr_size = 0
self._curr_pos = 0
self._context = deque(maxlen=context_len - 1)
def append(self, exp):
"""append a new experience into replay memory
"""
if self._curr_size < self.max_size:
self._assign(self._curr_pos, exp)
self._curr_size += 1
else:
self._assign(self._curr_pos, exp)
self._curr_pos = (self._curr_pos + 1) % self.max_size
if exp.isOver:
self._context.clear()
else:
self._context.append(exp)
def recent_state(self):
""" maintain recent state for training"""
lst = list(self._context)
states = [np.zeros(self.state_shape, dtype='uint8')] * \
(self._context.maxlen - len(lst))
states.extend([k.state for k in lst])
return states
def sample(self, idx):
""" return state, action, reward, isOver,
note that some frames in state may be generated from last episode,
they should be removed from state
"""
state = np.zeros(
(self.context_len + 1, ) + self.state_shape, dtype=np.uint8)
state_idx = np.arange(idx,
idx + self.context_len + 1) % self._curr_size
# confirm that no frame was generated from last episode
has_last_episode = False
for k in range(self.context_len - 2, -1, -1):
to_check_idx = state_idx[k]
if self.isOver[to_check_idx]:
has_last_episode = True
state_idx = state_idx[k + 1:]
state[k + 1:] = self.state[state_idx]
break
if not has_last_episode:
state = self.state[state_idx]
real_idx = (idx + self.context_len - 1) % self._curr_size
action = self.action[real_idx]
reward = self.reward[real_idx]
isOver = self.isOver[real_idx]
return state, reward, action, isOver
def __len__(self):
return self._curr_size
def _assign(self, pos, exp):
self.state[pos] = exp.state
self.reward[pos] = exp.reward
self.action[pos] = exp.action
self.isOver[pos] = exp.isOver
def sample_batch(self, batch_size):
"""sample a batch from replay memory for training
"""
batch_idx = np.random.randint(
self._curr_size - self.context_len - 1, size=batch_size)
batch_idx = (self._curr_pos + batch_idx) % self._curr_size
batch_exp = [self.sample(i) for i in batch_idx]
return self._process_batch(batch_exp)
def _process_batch(self, batch_exp):
state = np.asarray([e[0] for e in batch_exp], dtype='uint8')
reward = np.asarray([e[1] for e in batch_exp], dtype='float32')
action = np.asarray([e[2] for e in batch_exp], dtype='int8')
isOver = np.asarray([e[3] for e in batch_exp], dtype='bool')
return [state, action, reward, isOver]
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import cv2
import gym
import paddle.fluid as fluid
import numpy as np
import os
from atari import AtariPlayer
from atari_agent import AtariAgent
from atari_model import AtariModel
from atari_wrapper import FrameStack, MapState, FireResetEnv, LimitLength
from collections import deque
from datetime import datetime
from expreplay import ReplayMemory, Experience
from parl.algorithms import DQNAlgorithm
from parl.utils import logger
from tqdm import tqdm
MEMORY_SIZE = 1e6
MEMORY_WARMUP_SIZE = MEMORY_SIZE // 20
IMAGE_SIZE = (84, 84)
CONTEXT_LEN = 4
ACTION_REPEAT = 4 # aka FRAME_SKIP
UPDATE_FREQ = 4
GAMMA = 0.99
LEARNING_RATE = 1e-3
def run_train_episode(agent, env, exp):
total_reward = 0
all_cost = []
state = env.reset()
step = 0
while True:
step += 1
context = exp.recent_state()
context.append(state)
context = np.stack(context, axis=0)
action = agent.sample(context)
next_state, reward, isOver, _ = env.step(action)
exp.append(Experience(state, action, reward, isOver))
# start training
if len(exp) > MEMORY_WARMUP_SIZE:
if step % UPDATE_FREQ == 0:
batch_all_state, batch_action, batch_reward, batch_isOver = exp.sample_batch(
args.batch_size)
batch_state = batch_all_state[:, :CONTEXT_LEN, :, :]
batch_next_state = batch_all_state[:, 1:, :, :]
cost = agent.learn(batch_state, batch_action, batch_reward,
batch_next_state, batch_isOver)
all_cost.append(float(cost))
total_reward += reward
state = next_state
if isOver:
break
logger.info('[Train]total_reward: {}, mean_cost: {}'.format(
total_reward, np.mean(all_cost)))
return total_reward, step
def get_player(rom, viz=False, train=False):
env = AtariPlayer(
rom,
frame_skip=ACTION_REPEAT,
viz=viz,
live_lost_as_eoe=train,
max_num_frames=60000)
env = FireResetEnv(env)
env = MapState(env, lambda im: cv2.resize(im, IMAGE_SIZE))
if not train:
# in training, context is taken care of in expreplay buffer
env = FrameStack(env, CONTEXT_LEN)
return env
def eval_agent(agent, env):
episode_reward = []
for _ in tqdm(range(30), desc='eval agent'):
state = env.reset()
total_reward = 0
step = 0
while True:
step += 1
action = agent.predict(state)
state, reward, isOver, info = env.step(action)
total_reward += reward
if isOver:
break
episode_reward.append(total_reward)
eval_reward = np.mean(episode_reward)
return eval_reward
def train_agent():
env = get_player(args.rom, train=True)
test_env = get_player(args.rom)
exp = ReplayMemory(MEMORY_SIZE, IMAGE_SIZE, CONTEXT_LEN)
action_dim = env.action_space.n
hyperparas = {
'action_dim': action_dim,
'lr': LEARNING_RATE,
'gamma': GAMMA
}
model = AtariModel(IMAGE_SIZE[0], IMAGE_SIZE[1], action_dim)
algorithm = DQNAlgorithm(model, hyperparas)
agent = AtariAgent(algorithm, action_dim)
with tqdm(total=MEMORY_WARMUP_SIZE) as pbar:
while len(exp) < MEMORY_WARMUP_SIZE:
total_reward, step = run_train_episode(agent, env, exp)
pbar.update(step)
# train
test_flag = 0
pbar = tqdm(total=1e8)
recent_100_reward = []
total_step = 0
max_reward = None
while True:
# start epoch
total_reward, step = run_train_episode(agent, env, exp)
total_step += step
pbar.set_description('[train]exploration:{}'.format(agent.exploration))
pbar.update(step)
if total_step // args.test_every_steps == test_flag:
pbar.write("testing")
eval_reward = eval_agent(agent, test_env)
test_flag += 1
logger.info(
"eval_agent done, (steps, eval_reward): ({}, {})".format(
total_step, eval_reward))
pbar.close()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--rom', help='atari rom', required=True)
parser.add_argument(
'--use_cuda', action='store_true', help='if set, use cuda')
parser.add_argument(
'--batch_size', type=int, default=64, help='batch size for training')
parser.add_argument(
'--test_every_steps',
type=int,
default=100000,
help='every steps number to run test')
args = parser.parse_args()
train_agent()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from parl.algorithms.dqn_algorithm import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
from parl.framework.algorithm_base import Algorithm
import parl.layers as layers
import copy
class DQNAlgorithm(Algorithm):
def __init__(self, model, hyperparas):
Algorithm.__init__(self, model, hyperparas)
self.model = model
self.target_model = copy.deepcopy(model)
# fetch hyper parameters
self.action_dim = hyperparas['action_dim']
self.gamma = hyperparas['gamma']
self.lr = hyperparas['lr']
def define_predict(self, obs):
return self.model.value(obs)
def define_learn(self, obs, action, reward, next_obs, terminal):
pred_value = self.model.value(obs)
#fluid.layers.Print(pred_value, summarize=10, message='pred_value')
next_pred_value = self.target_model.value(next_obs)
#fluid.layers.Print(next_pred_value, summarize=10, message='next_pred_value')
best_v = layers.reduce_max(next_pred_value, dim=1)
best_v.stop_gradient = True
#fluid.layers.Print(best_v, summarize=10, message='best_v')
target = reward + (
1.0 - layers.cast(terminal, dtype='float32')) * self.gamma * best_v
#fluid.layers.Print(target, summarize=10, message='target')
action_onehot = layers.one_hot(action, self.action_dim)
action_onehot = layers.cast(action_onehot, dtype='float32')
pred_action_value = layers.reduce_sum(
layers.elementwise_mul(action_onehot, pred_value), dim=1)
#fluid.layers.Print(pred_action_value, summarize=10, message='pred_action_value')
cost = layers.square_error_cost(pred_action_value, target)
cost = layers.reduce_mean(cost)
optimizer = fluid.optimizer.Adam(self.lr * 0.5, epsilon=1e-3)
optimizer.minimize(cost)
return cost
def sync_target(self, gpu_id):
self.model.sync_params_to(self.target_model, gpu_id=gpu_id)
......@@ -16,6 +16,7 @@ import paddle.fluid as fluid
import parl.layers as layers
from parl.framework.algorithm_base import Algorithm
from parl.framework.model_base import Model
from parl.utils import get_gpu_count
__all__ = ['Agent']
......@@ -31,10 +32,23 @@ class Agent(object):
c. define a Agent with the algorithm
"""
def __init__(self, algorithm):
def __init__(self, algorithm, gpu_id=None):
""" build program and run initialization for default_startup_program
Created object:
self.alg: parl.framework.Algorithm
self.gpu_id: int
self.fluid_executor: fluid.Executor
"""
assert isinstance(algorithm, Algorithm)
self.alg = algorithm
self.build_program()
if gpu_id is None:
gpu_id = 0 if get_gpu_count() > 0 else -1
self.gpu_id = gpu_id
place = fluid.CUDAPlace(gpu_id) if gpu_id >= 0 else fluid.CPUPlace()
self.fluid_executor = fluid.Executor(place)
self.fluid_executor.run(fluid.default_startup_program())
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import parl.layers as layers
import unittest
from paddle import fluid
from parl.framework.agent_base import Agent
from parl.framework.algorithm_base import Algorithm
from parl.framework.model_base import Model
from parl.utils import gputils
class TestModel(Model):
def __init__(self):
self.fc1 = layers.fc(size=256)
self.fc2 = layers.fc(size=128)
def policy(self, obs):
out = self.fc1(obs)
out = self.fc2(out)
return out
class TestAlgorithm(Algorithm):
def __init__(self, model, hyperparas=None):
super(TestAlgorithm, self).__init__(model, hyperparas)
def define_predict(self, obs):
return self.model.policy(obs)
class TestAgent(Agent):
def __init__(self, algorithm, gpu_id=None):
super(TestAgent, self).__init__(algorithm, gpu_id)
def build_program(self):
self.predict_program = fluid.Program()
with fluid.program_guard(self.predict_program):
obs = layers.data(name='obs', shape=[10], dtype='float32')
output = self.alg.define_predict(obs)
self.predict_output = [output]
def predict(self, obs):
output_np = self.fluid_executor.run(
self.predict_program,
feed={'obs': obs},
fetch_list=self.predict_output)[0]
return output_np
class AgentBaseTest(unittest.TestCase):
def setUp(self):
self.model = TestModel()
self.algorithm = TestAlgorithm(self.model)
def test_agent_with_gpu(self):
if gputils.get_gpu_count() > 0:
agent = TestAgent(self.algorithm, gpu_id=0)
obs = np.random.random([3, 10]).astype('float32')
output_np = agent.predict(obs)
self.assertIsNotNone(output_np)
def test_agent_with_cpu(self):
agent = TestAgent(self.algorithm, gpu_id=0)
obs = np.random.random([3, 10]).astype('float32')
output_np = agent.predict(obs)
self.assertIsNotNone(output_np)
if __name__ == '__main__':
unittest.main()
......@@ -31,15 +31,20 @@ def get_gpu_count():
env_cuda_devices = os.environ.get('CUDA_VISIBLE_DEVICES', None)
if env_cuda_devices is not None:
assert isinstance(env_cuda_devices, str)
gpu_count = len(env_cuda_devices.split(','))
try:
gpu_count = len(
[x for x in env_cuda_devices.split(',') if int(x) >= 0])
logger.info(
'CUDA_VISIBLE_DEVICES found gpu count: {}'.format(gpu_count))
except Exception as e:
logger.error(e.message)
gpu_count = 0
else:
try:
gpu_count = str(subprocess.check_output(["nvidia-smi",
"-L"])).count('UUID')
logger.info('nvidia-smi -L found gpu count: {}'.format(gpu_count))
except Exception as e:
logger.warn(e.message)
logger.error(e.message)
gpu_count = 0
return gpu_count
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册