Unverified commit 117b1c38, authored by LI Yunxiang, committed by GitHub

add simple dqn demo (#254)

* add simple dqn

* Update README.md

* Update train.py

* update

* update image in README

* update readme

* simplify

* yapf

* Update README.md

* Update README.md

* Update README.md

* Update train.py

* yapf
Parent 0698534b
## Reproduce DQN with PARL
Based on PARL, we provide a simple demonstration of DQN.

+ DQN in
[Human-level Control Through Deep Reinforcement Learning](http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html)

### Result

Performance of DQN playing CartPole-v1

<p align="left">
<img src="../QuickStart/performance.gif" alt="result" height="175"/>
<img src="cartpole.jpg" alt="result" height="175"/>
</p>
## How to use
### Dependencies:
+ [paddlepaddle>=1.6.1](https://github.com/PaddlePaddle/Paddle)
+ [parl](https://github.com/PaddlePaddle/PARL)
+ gym
+ tqdm

### Start Training:
```
# To train a model for CartPole-v1 game
python train.py
```

## DQN-Variants
For DQN variants such as Double DQN and Dueling DQN, please check [here](https://github.com/PaddlePaddle/PARL/tree/develop/examples/DQN_variant)
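For orientation, here is a minimal sketch of how the pieces added in this commit fit together. It mirrors what `train.py` below does; the variable names are illustrative, and the hyperparameters are the ones used in this example.

```python
# Minimal wiring sketch, mirroring train.py in this example (illustrative only)
import gym
import parl

from cartpole_model import CartpoleModel
from cartpole_agent import CartpoleAgent

env = gym.make('CartPole-v1')
act_dim = env.action_space.n

model = CartpoleModel(act_dim=act_dim)                         # Q-network
alg = parl.algorithms.DQN(model, act_dim=act_dim, gamma=0.99)  # DQN algorithm
agent = CartpoleAgent(
    alg, state_dim=env.observation_space.shape[0], act_dim=act_dim)

obs = env.reset()
action = agent.sample(obs)  # epsilon-greedy action used during training
```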
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle.fluid as fluid
import parl
from parl import layers


class CartpoleAgent(parl.Agent):
    def __init__(self,
                 algorithm,
                 state_dim,
                 act_dim,
                 e_greed=0.1,
                 e_greed_decrement=0):
        assert isinstance(state_dim, int)
        assert isinstance(act_dim, int)
        self.state_dim = state_dim
        self.act_dim = act_dim
        super(CartpoleAgent, self).__init__(algorithm)

        self.global_step = 0
        self.update_target_steps = 200  # sync the target network every 200 learn steps

        self.e_greed = e_greed  # probability of taking a random action
        self.e_greed_decrement = e_greed_decrement

    def build_program(self):
        self.pred_program = fluid.Program()
        self.learn_program = fluid.Program()

        with fluid.program_guard(self.pred_program):
            obs = layers.data(
                name='obs', shape=[self.state_dim], dtype='float32')
            self.value = self.alg.predict(obs)

        with fluid.program_guard(self.learn_program):
            obs = layers.data(
                name='obs', shape=[self.state_dim], dtype='float32')
            action = layers.data(name='act', shape=[1], dtype='int32')
            reward = layers.data(name='reward', shape=[], dtype='float32')
            next_obs = layers.data(
                name='next_obs', shape=[self.state_dim], dtype='float32')
            terminal = layers.data(name='terminal', shape=[], dtype='bool')
            lr = layers.data(
                name='lr', shape=[1], dtype='float32', append_batch_size=False)
            self.cost = self.alg.learn(obs, action, reward, next_obs,
                                       terminal, lr)

    def sample(self, obs):
        # epsilon-greedy exploration: act randomly with probability e_greed
        sample = np.random.rand()
        if sample < self.e_greed:
            act = np.random.randint(self.act_dim)
        else:
            act = self.predict(obs)
        self.e_greed = max(0.01, self.e_greed - self.e_greed_decrement)
        return act

    def predict(self, obs):
        # greedy action: argmax over the predicted Q-values
        obs = np.expand_dims(obs, axis=0)
        pred_Q = self.fluid_executor.run(
            self.pred_program,
            feed={'obs': obs.astype('float32')},
            fetch_list=[self.value])[0]
        pred_Q = np.squeeze(pred_Q, axis=0)
        act = np.argmax(pred_Q)
        return act

    def learn(self, obs, act, reward, next_obs, terminal, lr):
        # periodically copy the online network's parameters to the target network
        if self.global_step % self.update_target_steps == 0:
            self.alg.sync_target()
        self.global_step += 1

        act = np.expand_dims(act, -1)
        feed = {
            'obs': obs.astype('float32'),
            'act': act.astype('int32'),
            'reward': reward,
            'next_obs': next_obs.astype('float32'),
            'terminal': terminal,
            'lr': np.float32([lr]),
        }
        cost = self.fluid_executor.run(
            self.learn_program, feed=feed, fetch_list=[self.cost])[0]
        return cost
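For reference, `self.alg.learn(...)` above delegates to PARL's DQN algorithm. In the standard formulation of the paper linked in the README (the exact implementation lives in `parl.algorithms.DQN`), the objective is the squared temporal-difference error against a periodically synced target network:

$$L(\theta) = \mathbb{E}_{(s,a,r,s',d)\sim\mathcal{D}}\Big[\big(r + \gamma\,(1-d)\,\max_{a'} Q_{\theta^-}(s', a') - Q_{\theta}(s, a)\big)^2\Big]$$

where $\theta^-$ are the target-network parameters, refreshed here by `sync_target()` every `update_target_steps` learn calls.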
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import parl
from parl import layers


class CartpoleModel(parl.Model):
    def __init__(self, act_dim):
        hid1_size = 128
        hid2_size = 128
        # A small fully connected Q-network: obs -> 128 -> 128 -> act_dim
        self.fc1 = layers.fc(size=hid1_size, act='relu')
        self.fc2 = layers.fc(size=hid2_size, act='relu')
        self.fc3 = layers.fc(size=act_dim, act=None)

    def value(self, obs):
        h1 = self.fc1(obs)
        h2 = self.fc2(h1)
        Q = self.fc3(h2)
        return Q
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Modified from https://github.com/seungeunrho/minimalRL/blob/master/dqn.py

import random
import collections

import numpy as np


class ReplayMemory(object):
    def __init__(self, max_size):
        self.buffer = collections.deque(maxlen=max_size)

    def append(self, exp):
        """append a new experience into replay memory
        """
        self.buffer.append(exp)

    def sample(self, batch_size):
        mini_batch = random.sample(self.buffer, batch_size)
        state_batch, action_batch, reward_batch, next_state_batch, done_batch = [], [], [], [], []

        for experience in mini_batch:
            s, a, r, s_p, done = experience
            state_batch.append(s)
            action_batch.append(a)
            reward_batch.append(r)
            next_state_batch.append(s_p)
            done_batch.append(done)

        return np.array(state_batch).astype('float32'), \
            np.array(action_batch).astype('float32'), np.array(reward_batch).astype('float32'), \
            np.array(next_state_batch).astype('float32'), np.array(done_batch).astype('float32')

    def __len__(self):
        return len(self.buffer)
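A small, self-contained usage sketch of the replay memory above. It assumes `replay_memory.py` from this commit is importable; the transitions are dummy values for illustration only.

```python
# Hypothetical usage of the ReplayMemory defined above
import numpy as np
from replay_memory import ReplayMemory

rpm = ReplayMemory(max_size=20000)

# store a few fake CartPole-style transitions: (state, action, reward, next_state, done)
for _ in range(100):
    s = np.random.randn(4).astype('float32')
    s_next = np.random.randn(4).astype('float32')
    rpm.append((s, 0, 1.0, s_next, False))

states, actions, rewards, next_states, dones = rpm.sample(32)
print(states.shape, actions.shape)  # (32, 4) (32,)
```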
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gym
import numpy as np
import parl
from parl.utils import logger

from cartpole_model import CartpoleModel
from cartpole_agent import CartpoleAgent
from replay_memory import ReplayMemory

LEARN_FREQ = 5  # update parameters every 5 steps
MEMORY_SIZE = 20000  # replay memory size
MEMORY_WARMUP_SIZE = 200  # store some experiences in the replay memory in advance
BATCH_SIZE = 32
LEARNING_RATE = 0.0005
GAMMA = 0.99  # discount factor of reward


def run_episode(agent, env, rpm):
    total_reward = 0
    state = env.reset()
    step = 0
    while True:
        step += 1
        action = agent.sample(state)
        next_state, reward, isOver, _ = env.step(action)
        rpm.append((state, action, reward, next_state, isOver))

        # train model
        if (len(rpm) > MEMORY_WARMUP_SIZE) and (step % LEARN_FREQ == 0):
            (batch_state, batch_action, batch_reward, batch_next_state,
             batch_isOver) = rpm.sample(BATCH_SIZE)
            train_loss = agent.learn(batch_state, batch_action, batch_reward,
                                     batch_next_state, batch_isOver,
                                     LEARNING_RATE)

        total_reward += reward
        state = next_state
        if isOver:
            break
    return total_reward


def evaluate(agent, env, render=False):
    # test part, run 5 episodes and average
    eval_reward = []
    for i in range(5):
        state = env.reset()
        episode_reward = 0
        isOver = False
        while not isOver:
            action = agent.predict(state)
            if render:
                env.render()
            state, reward, isOver, _ = env.step(action)
            episode_reward += reward
        eval_reward.append(episode_reward)
    return np.mean(eval_reward)


def main():
    env = gym.make('CartPole-v1')
    action_dim = env.action_space.n
    state_shape = env.observation_space.shape

    rpm = ReplayMemory(MEMORY_SIZE)

    model = CartpoleModel(act_dim=action_dim)
    algorithm = parl.algorithms.DQN(model, act_dim=action_dim, gamma=GAMMA)
    agent = CartpoleAgent(
        algorithm,
        state_dim=state_shape[0],
        act_dim=action_dim,
        e_greed=0.1,  # explore
        e_greed_decrement=1e-6
    )  # probability of exploring is decreasing during training

    while len(rpm) < MEMORY_WARMUP_SIZE:  # warm up replay memory
        run_episode(agent, env, rpm)

    max_episode = 2000

    # start train
    episode = 0
    while episode < max_episode:
        # train part
        for i in range(0, 50):
            total_reward = run_episode(agent, env, rpm)
            episode += 1

        eval_reward = evaluate(agent, env)
        logger.info('episode:{} test_reward:{}'.format(
            episode, eval_reward))


if __name__ == '__main__':
    main()
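A quick sanity check on the exploration schedule used in `main()` above: with `e_greed=0.1`, `e_greed_decrement=1e-6`, and the 0.01 floor inside `CartpoleAgent.sample()`, epsilon reaches its floor after roughly 90,000 sampled steps.

```python
# back-of-the-envelope: steps until epsilon hits its floor (values from main() above)
e_greed, e_greed_decrement, floor = 0.1, 1e-6, 0.01
steps_to_floor = (e_greed - floor) / e_greed_decrement
print(int(steps_to_floor))  # 90000
```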
## Reproduce DQN with PARL
Based on PARL, we reproduce the DQN algorithm of deep reinforcement learning and achieve scores on the Atari benchmarks comparable to those reported in the paper.
+ DQN in
[Human-level Control Through Deep Reinforcement Learning](http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html)
### Atari games introduction
Please see [here](https://gym.openai.com/envs/#atari) to know more about Atari games.
### Benchmark result
Mean episode rewards for 10 million training steps.
<img src=".benchmark/merge.png" width = "1150" height ="230" alt="pong" />
Performance of DQN on various environments
<p align="center">
<img src=".benchmark/table.png" alt="result" width="700"/>
</p>
## How to use
### Dependencies:
+ [paddlepaddle>=1.6.1](https://github.com/PaddlePaddle/Paddle)
+ [parl](https://github.com/PaddlePaddle/PARL)
+ gym
+ tqdm
+ atari-py
+ [ale_python_interface](https://github.com/mgbellemare/Arcade-Learning-Environment)
### Start Training:
```
# To train a model for Pong game
python train.py --rom ./rom_files/pong.bin
```
> To train more games, you can install more rom files from [here](https://github.com/openai/atari-py/tree/master/atari_py/atari_roms).
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import copy
from collections import deque, namedtuple

Experience = namedtuple('Experience', ['state', 'action', 'reward', 'isOver'])


class ReplayMemory(object):
    def __init__(self, max_size, state_shape, context_len):
        self.max_size = int(max_size)
        self.state_shape = state_shape
        self.context_len = int(context_len)

        self.state = np.zeros((self.max_size, ) + state_shape, dtype='uint8')
        self.action = np.zeros((self.max_size, ), dtype='int32')
        self.reward = np.zeros((self.max_size, ), dtype='float32')
        self.isOver = np.zeros((self.max_size, ), dtype='bool')

        self._curr_size = 0
        self._curr_pos = 0
        self._context = deque(maxlen=context_len - 1)

    def append(self, exp):
        """append a new experience into replay memory
        """
        if self._curr_size < self.max_size:
            self._assign(self._curr_pos, exp)
            self._curr_size += 1
        else:
            self._assign(self._curr_pos, exp)
        self._curr_pos = (self._curr_pos + 1) % self.max_size
        if exp.isOver:
            self._context.clear()
        else:
            self._context.append(exp)

    def recent_state(self):
        """ maintain recent state for training"""
        lst = list(self._context)
        states = [np.zeros(self.state_shape, dtype='uint8')] * \
            (self._context.maxlen - len(lst))
        states.extend([k.state for k in lst])
        return states

    def sample(self, idx):
        """ return state, action, reward, isOver,
            note that some frames in state may be generated from last episode,
            they should be removed from state
        """
        state = np.zeros(
            (self.context_len + 1, ) + self.state_shape, dtype=np.uint8)
        state_idx = np.arange(idx,
                              idx + self.context_len + 1) % self._curr_size

        # confirm that no frame was generated from last episode
        has_last_episode = False
        for k in range(self.context_len - 2, -1, -1):
            to_check_idx = state_idx[k]
            if self.isOver[to_check_idx]:
                has_last_episode = True
                state_idx = state_idx[k + 1:]
                state[k + 1:] = self.state[state_idx]
                break

        if not has_last_episode:
            state = self.state[state_idx]

        real_idx = (idx + self.context_len - 1) % self._curr_size
        action = self.action[real_idx]
        reward = self.reward[real_idx]
        isOver = self.isOver[real_idx]
        return state, reward, action, isOver

    def __len__(self):
        return self._curr_size

    def size(self):
        return self._curr_size

    def _assign(self, pos, exp):
        self.state[pos] = exp.state
        self.reward[pos] = exp.reward
        self.action[pos] = exp.action
        self.isOver[pos] = exp.isOver

    def sample_batch(self, batch_size):
        """sample a batch from replay memory for training
        """
        batch_idx = np.random.randint(
            self._curr_size - self.context_len - 1, size=batch_size)
        batch_idx = (self._curr_pos + batch_idx) % self._curr_size
        batch_exp = [self.sample(i) for i in batch_idx]
        return self._process_batch(batch_exp)

    def _process_batch(self, batch_exp):
        state = np.asarray([e[0] for e in batch_exp], dtype='uint8')
        reward = np.asarray([e[1] for e in batch_exp], dtype='float32')
        action = np.asarray([e[2] for e in batch_exp], dtype='int8')
        isOver = np.asarray([e[3] for e in batch_exp], dtype='bool')
        return [state, action, reward, isOver]
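For contrast with the simple CartPole buffer, here is a hedged sketch of how this frame-stacking Atari replay memory is driven; it mirrors `run_train_episode` in the original `train.py` below, and the sizes and dummy frames are illustrative only.

```python
# Illustrative use of the frame-stacking ReplayMemory (mirrors the original train.py below)
import numpy as np
from replay_memory import ReplayMemory, Experience

rpm = ReplayMemory(max_size=10000, state_shape=(84, 84), context_len=4)

frame = np.zeros((84, 84), dtype='uint8')
for t in range(500):
    rpm.append(Experience(state=frame, action=0, reward=0.0, isOver=(t % 100 == 99)))

# recent_state() returns the last context_len - 1 frames, zero-padded at episode start
context = rpm.recent_state()

# each sampled state stacks context_len + 1 consecutive frames (current and next contexts)
batch_state, batch_action, batch_reward, batch_isOver = rpm.sample_batch(batch_size=32)
print(batch_state.shape)  # (32, 5, 84, 84)
```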
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gym
import paddle.fluid as fluid
import numpy as np
import os
import parl
from atari_agent import AtariAgent
from atari_model import AtariModel
from datetime import datetime
from replay_memory import ReplayMemory, Experience
from parl.utils import tensorboard, logger
from tqdm import tqdm
from utils import get_player

MEMORY_SIZE = 1e6
MEMORY_WARMUP_SIZE = MEMORY_SIZE // 20
IMAGE_SIZE = (84, 84)
CONTEXT_LEN = 4
FRAME_SKIP = 4
UPDATE_FREQ = 4
GAMMA = 0.99
LEARNING_RATE = 3e-4


def run_train_episode(env, agent, rpm):
    total_reward = 0
    all_cost = []
    state = env.reset()
    steps = 0
    while True:
        steps += 1
        context = rpm.recent_state()
        context.append(state)
        context = np.stack(context, axis=0)
        action = agent.sample(context)
        next_state, reward, isOver, _ = env.step(action)
        rpm.append(Experience(state, action, reward, isOver))
        # start training
        if rpm.size() > MEMORY_WARMUP_SIZE:
            if steps % UPDATE_FREQ == 0:
                batch_all_state, batch_action, batch_reward, batch_isOver = rpm.sample_batch(
                    args.batch_size)
                batch_state = batch_all_state[:, :CONTEXT_LEN, :, :]
                batch_next_state = batch_all_state[:, 1:, :, :]
                cost = agent.learn(batch_state, batch_action, batch_reward,
                                   batch_next_state, batch_isOver)
                all_cost.append(float(cost))
        total_reward += reward
        state = next_state
        if isOver:
            break
    if all_cost:
        logger.info('[Train]total_reward: {}, mean_cost: {}'.format(
            total_reward, np.mean(all_cost)))
    return total_reward, steps, np.mean(all_cost)


def run_evaluate_episode(env, agent):
    state = env.reset()
    total_reward = 0
    while True:
        action = agent.predict(state)
        state, reward, isOver, info = env.step(action)
        total_reward += reward
        if isOver:
            break
    return total_reward


def main():
    env = get_player(
        args.rom, image_size=IMAGE_SIZE, train=True, frame_skip=FRAME_SKIP)
    test_env = get_player(
        args.rom,
        image_size=IMAGE_SIZE,
        frame_skip=FRAME_SKIP,
        context_len=CONTEXT_LEN)
    rpm = ReplayMemory(MEMORY_SIZE, IMAGE_SIZE, CONTEXT_LEN)
    act_dim = env.action_space.n

    model = AtariModel(act_dim, args.algo)
    if args.algo == 'Double':
        algorithm = parl.algorithms.DDQN(model, act_dim=act_dim, gamma=GAMMA)
    elif args.algo in ['DQN', 'Dueling']:
        algorithm = parl.algorithms.DQN(model, act_dim=act_dim, gamma=GAMMA)
    agent = AtariAgent(
        algorithm,
        act_dim=act_dim,
        start_lr=LEARNING_RATE,
        total_step=args.train_total_steps,
        update_freq=UPDATE_FREQ)

    with tqdm(
            total=MEMORY_WARMUP_SIZE, desc='[Replay Memory Warm Up]') as pbar:
        while rpm.size() < MEMORY_WARMUP_SIZE:
            total_reward, steps, _ = run_train_episode(env, agent, rpm)
            pbar.update(steps)

    # train
    test_flag = 0
    pbar = tqdm(total=args.train_total_steps)
    total_steps = 0
    max_reward = None
    while total_steps < args.train_total_steps:
        # start epoch
        total_reward, steps, loss = run_train_episode(env, agent, rpm)
        total_steps += steps
        pbar.set_description('[train]exploration:{}'.format(agent.exploration))
        tensorboard.add_scalar('dqn/score', total_reward, total_steps)
        tensorboard.add_scalar('dqn/loss', loss,
                               total_steps)  # mean of total loss
        tensorboard.add_scalar('dqn/exploration', agent.exploration,
                               total_steps)
        pbar.update(steps)

        if total_steps // args.test_every_steps >= test_flag:
            while total_steps // args.test_every_steps >= test_flag:
                test_flag += 1
            pbar.write("testing")
            eval_rewards = []
            for _ in tqdm(range(3), desc='eval agent'):
                eval_reward = run_evaluate_episode(test_env, agent)
                eval_rewards.append(eval_reward)
            logger.info(
                "eval_agent done, (steps, eval_reward): ({}, {})".format(
                    total_steps, np.mean(eval_rewards)))
            eval_test = np.mean(eval_rewards)
            tensorboard.add_scalar('dqn/eval', eval_test, total_steps)

    pbar.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--rom', help='path of the rom of the atari game', required=True)
    parser.add_argument(
        '--batch_size', type=int, default=64, help='batch size for training')
    parser.add_argument(
        '--algo',
        default='DQN',
        help=
        'DQN/DDQN/Dueling, represent DQN, double DQN, and dueling DQN respectively',
    )
    parser.add_argument(
        '--train_total_steps',
        type=int,
        default=int(1e7),
        help='maximum environmental steps of games')
    parser.add_argument(
        '--test_every_steps',
        type=int,
        default=100000,
        help='the step interval between two consecutive evaluations')
    args = parser.parse_args()

    main()
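Given the `--algo` flag defined above (note that the code branches on the values `DQN`, `Double`, and `Dueling`, even though the help string says DDQN), the variants can be selected from the command line, for example:

```
# Train double DQN or dueling DQN on Pong (flags as defined in the argparse block above)
python train.py --rom ./rom_files/pong.bin --algo Double
python train.py --rom ./rom_files/pong.bin --algo Dueling
```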