Unverified commit 117b1c38, authored by LI Yunxiang, committed by GitHub

add simple dqn demo (#254)

* add simple dqn

* Update README.md

* Update train.py

* update

* update image in README

* update readme

* simplify

* yapf

* Update README.md

* Update README.md

* Update README.md

* Update train.py

* yapf
Parent 0698534b
## Reproduce DQN with PARL
Based on PARL, we provide a simple demonstration of DQN.

+ DQN in [Human-level Control Through Deep Reinforcement Learning](http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html)

### Result

Performance of DQN playing CartPole-v1

<p align="left">
<img src="../QuickStart/performance.gif" alt="result" height="175"/>
<img src="cartpole.jpg" alt="result" height="175"/>
</p>

## How to use
### Dependencies:
+ [paddlepaddle>=1.6.1](https://github.com/PaddlePaddle/Paddle)
+ [parl](https://github.com/PaddlePaddle/PARL)
+ gym
+ tqdm

### Start Training:
```
# To train a model for the CartPole-v1 game
python train.py
```

## DQN-Variants
For DQN variants such as Double DQN and Dueling DQN, please check [here](https://github.com/PaddlePaddle/PARL/tree/develop/examples/DQN_variant).
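Before launching train.py it can be worth confirming that gym's CartPole-v1 environment is importable. This is a minimal sanity check and not part of the example code:

```
import gym

env = gym.make('CartPole-v1')
obs = env.reset()
# CartPole-v1 exposes a 4-dimensional observation and 2 discrete actions
print(env.observation_space.shape, env.action_space.n)  # (4,) 2
```

The agent added by this commit, cartpole_agent.py, is listed next.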
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle.fluid as fluid
import parl
from parl import layers


class CartpoleAgent(parl.Agent):
    def __init__(self,
                 algorithm,
                 state_dim,
                 act_dim,
                 e_greed=0.1,
                 e_greed_decrement=0):
        assert isinstance(state_dim, int)
        assert isinstance(act_dim, int)
        self.state_dim = state_dim
        self.act_dim = act_dim
        super(CartpoleAgent, self).__init__(algorithm)

        self.global_step = 0
        self.update_target_steps = 200  # sync the target network every 200 learn steps

        self.e_greed = e_greed
        self.e_greed_decrement = e_greed_decrement

    def build_program(self):
        # separate fluid programs for inference and for the DQN update
        self.pred_program = fluid.Program()
        self.learn_program = fluid.Program()

        with fluid.program_guard(self.pred_program):
            obs = layers.data(
                name='obs', shape=[self.state_dim], dtype='float32')
            self.value = self.alg.predict(obs)

        with fluid.program_guard(self.learn_program):
            obs = layers.data(
                name='obs', shape=[self.state_dim], dtype='float32')
            action = layers.data(name='act', shape=[1], dtype='int32')
            reward = layers.data(name='reward', shape=[], dtype='float32')
            next_obs = layers.data(
                name='next_obs', shape=[self.state_dim], dtype='float32')
            terminal = layers.data(name='terminal', shape=[], dtype='bool')
            lr = layers.data(
                name='lr', shape=[1], dtype='float32', append_batch_size=False)
            self.cost = self.alg.learn(obs, action, reward, next_obs,
                                       terminal, lr)

    def sample(self, obs):
        # epsilon-greedy: act randomly with probability e_greed, otherwise greedily
        sample = np.random.rand()
        if sample < self.e_greed:
            act = np.random.randint(self.act_dim)
        else:
            act = self.predict(obs)
        self.e_greed = max(0.01, self.e_greed - self.e_greed_decrement)
        return act

    def predict(self, obs):
        obs = np.expand_dims(obs, axis=0)
        pred_Q = self.fluid_executor.run(
            self.pred_program,
            feed={'obs': obs.astype('float32')},
            fetch_list=[self.value])[0]
        pred_Q = np.squeeze(pred_Q, axis=0)
        act = np.argmax(pred_Q)
        return act

    def learn(self, obs, act, reward, next_obs, terminal, lr):
        # periodically copy the behaviour network parameters to the target network
        if self.global_step % self.update_target_steps == 0:
            self.alg.sync_target()
        self.global_step += 1

        act = np.expand_dims(act, -1)
        feed = {
            'obs': obs.astype('float32'),
            'act': act.astype('int32'),
            'reward': reward,
            'next_obs': next_obs.astype('float32'),
            'terminal': terminal,
            'lr': np.float32([lr]),
        }
        cost = self.fluid_executor.run(
            self.learn_program, feed=feed, fetch_list=[self.cost])[0]
        return cost
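With the values train.py passes (e_greed=0.1, e_greed_decrement=1e-6, floor 0.01), the exploration rate drops by 1e-6 after every sampled action, so it reaches its floor after (0.1 - 0.01) / 1e-6 = 90,000 sampled steps. A tiny sketch of that arithmetic, independent of the agent class:

```
# Hedged sketch of the epsilon schedule implemented in CartpoleAgent.sample()
e_greed, e_greed_decrement, e_floor = 0.1, 1e-6, 0.01
print(int((e_greed - e_floor) / e_greed_decrement))  # 90000 steps until epsilon stops decaying
```

The next listing, cartpole_model.py, defines the Q-network that the agent wraps.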
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import parl
from parl import layers


class CartpoleModel(parl.Model):
    def __init__(self, act_dim):
        hid1_size = 128
        hid2_size = 128
        self.fc1 = layers.fc(size=hid1_size, act='relu')
        self.fc2 = layers.fc(size=hid2_size, act='relu')
        self.fc3 = layers.fc(size=act_dim, act=None)

    def value(self, obs):
        h1 = self.fc1(obs)
        h2 = self.fc2(h1)
        Q = self.fc3(h2)
        return Q
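CartpoleModel only maps an observation to one Q-value per action; the update rule comes from parl.algorithms.DQN, which fits those Q-values toward the bootstrapped target from the Nature paper cited in the README. As a hedged, framework-free sketch of that target (NumPy only, not PARL's actual implementation):

```
import numpy as np

def dqn_target(reward, next_q, terminal, gamma=0.99):
    # Bellman target: r + gamma * max_a Q_target(s', a), cut off at episode end
    return reward + gamma * (1.0 - terminal) * next_q.max(axis=-1)

# toy batch: 2 transitions, 2 actions (as in CartPole)
reward = np.array([1.0, 1.0], dtype='float32')
next_q = np.array([[0.5, 2.0], [1.5, 0.0]], dtype='float32')
terminal = np.array([0.0, 1.0], dtype='float32')
print(dqn_target(reward, next_q, terminal))  # [2.98 1.  ]
```

The replay buffer that feeds such batches, replay_memory.py, follows.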
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Modified from https://github.com/seungeunrho/minimalRL/blob/master/dqn.py

import random
import collections
import numpy as np


class ReplayMemory(object):
    def __init__(self, max_size):
        self.buffer = collections.deque(maxlen=max_size)

    def append(self, exp):
        self.buffer.append(exp)

    def sample(self, batch_size):
        mini_batch = random.sample(self.buffer, batch_size)
        state_batch, action_batch, reward_batch, next_state_batch, done_batch = [], [], [], [], []

        for experience in mini_batch:
            s, a, r, s_p, done = experience
            state_batch.append(s)
            action_batch.append(a)
            reward_batch.append(r)
            next_state_batch.append(s_p)
            done_batch.append(done)

        return np.array(state_batch).astype('float32'), \
            np.array(action_batch).astype('float32'), np.array(reward_batch).astype('float32'), \
            np.array(next_state_batch).astype('float32'), np.array(done_batch).astype('float32')

    def __len__(self):
        return len(self.buffer)
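A short usage sketch of the buffer above (assuming it is saved as replay_memory.py, as the import in train.py suggests): transitions go in as 5-tuples and come out as five float32 arrays of length batch_size.

```
import numpy as np
from replay_memory import ReplayMemory

rpm = ReplayMemory(max_size=100)
for i in range(40):
    obs = np.random.rand(4)        # placeholder CartPole observation (4 dimensions)
    next_obs = np.random.rand(4)
    rpm.append((obs, i % 2, 1.0, next_obs, False))

states, actions, rewards, next_states, dones = rpm.sample(32)
print(states.shape, actions.shape)  # (32, 4) (32,)
```

The training entry point, train.py, follows.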
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gym
import numpy as np
import parl
from parl.utils import logger

from cartpole_model import CartpoleModel
from cartpole_agent import CartpoleAgent
from replay_memory import ReplayMemory

LEARN_FREQ = 5  # update parameters every 5 steps
MEMORY_SIZE = 20000  # replay memory size
MEMORY_WARMUP_SIZE = 200  # store some experiences in the replay memory in advance
BATCH_SIZE = 32
LEARNING_RATE = 0.0005
GAMMA = 0.99  # discount factor of reward


def run_episode(agent, env, rpm):
    total_reward = 0
    state = env.reset()
    step = 0
    while True:
        step += 1
        action = agent.sample(state)
        next_state, reward, isOver, _ = env.step(action)
        rpm.append((state, action, reward, next_state, isOver))

        # train the model once enough experience has been collected
        if (len(rpm) > MEMORY_WARMUP_SIZE) and (step % LEARN_FREQ == 0):
            (batch_state, batch_action, batch_reward, batch_next_state,
             batch_isOver) = rpm.sample(BATCH_SIZE)
            train_loss = agent.learn(batch_state, batch_action, batch_reward,
                                     batch_next_state, batch_isOver,
                                     LEARNING_RATE)

        total_reward += reward
        state = next_state
        if isOver:
            break
    return total_reward


def evaluate(agent, env, render=False):
    # test part: run 5 episodes and average the rewards
    eval_reward = []
    for i in range(5):
        state = env.reset()
        episode_reward = 0
        isOver = False
        while not isOver:
            action = agent.predict(state)
            if render:
                env.render()
            state, reward, isOver, _ = env.step(action)
            episode_reward += reward
        eval_reward.append(episode_reward)
    return np.mean(eval_reward)


def main():
    env = gym.make('CartPole-v1')
    action_dim = env.action_space.n
    state_shape = env.observation_space.shape

    rpm = ReplayMemory(MEMORY_SIZE)

    model = CartpoleModel(act_dim=action_dim)
    algorithm = parl.algorithms.DQN(model, act_dim=action_dim, gamma=GAMMA)
    agent = CartpoleAgent(
        algorithm,
        state_dim=state_shape[0],
        act_dim=action_dim,
        e_greed=0.1,  # explore
        e_greed_decrement=1e-6
    )  # probability of exploring is decreasing during training

    while len(rpm) < MEMORY_WARMUP_SIZE:  # warm up replay memory
        run_episode(agent, env, rpm)

    max_episode = 2000

    # start training
    episode = 0
    while episode < max_episode:
        # train part
        for i in range(0, 50):
            total_reward = run_episode(agent, env, rpm)
            episode += 1

        # test part
        eval_reward = evaluate(agent, env)
        logger.info('episode:{}    test_reward:{}'.format(
            episode, eval_reward))


if __name__ == '__main__':
    main()
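main() never passes render=True, so training runs headless. To watch the trained policy afterwards, the existing evaluate() helper can be called with rendering turned on; a hedged sketch of a few lines that could be appended at the end of main(), assuming a display is available:

```
    # after the training loop in main():
    final_reward = evaluate(agent, env, render=True)  # opens the CartPole window
    logger.info('final test_reward:{}'.format(final_reward))
    env.close()
```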