#-*- coding: utf-8 -*-
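# Train a DQN / DoubleDQN / DuelingDQN agent on an Atari ROM with PaddlePaddle Fluid.
# Illustrative invocation (the rom filename is an assumption; point --rom at your own .bin file):
#   python train.py --rom breakout.bin --alg DoubleDQN --use_cuda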

from DQN_agent import DQNModel
from DoubleDQN_agent import DoubleDQNModel
from DuelingDQN_agent import DuelingDQNModel
from atari import AtariPlayer
import paddle.fluid as fluid
import gym
import argparse
import cv2
from tqdm import tqdm
from expreplay import ReplayMemory, Experience
import numpy as np
import os

from datetime import datetime
from atari_wrapper import FrameStack, MapState, FireResetEnv, LimitLength
from collections import deque

MEMORY_SIZE = int(1e6)
MEMORY_WARMUP_SIZE = MEMORY_SIZE // 20
IMAGE_SIZE = (84, 84)
CONTEXT_LEN = 4
ACTION_REPEAT = 4  # aka FRAME_SKIP
UPDATE_FREQ = 4
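# The values above roughly match the standard Nature-DQN Atari setup (84x84 frames,
# 4-frame context, frame skip of 4, ~1e6-transition replay memory with a 5% warmup);
# treat that mapping as an assumption, not something documented in this repo.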


def run_train_episode(agent, env, exp):
    total_reward = 0
    state = env.reset()
    step = 0
    while True:
        step += 1
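        # stack the recent frames tracked by the replay memory with the current
        # frame to form the network input used for action selection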
        context = exp.recent_state()
        context.append(state)
        context = np.stack(context, axis=0)
        action = agent.act(context, train_or_test='train')
        next_state, reward, isOver, _ = env.step(action)
        exp.append(Experience(state, action, reward, isOver))
        # start training once the replay memory has warmed up
        if len(exp) > MEMORY_WARMUP_SIZE:
            if step % UPDATE_FREQ == 0:
                batch_all_state, batch_action, batch_reward, batch_isOver = exp.sample_batch(
                    args.batch_size)
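                # each sampled entry holds CONTEXT_LEN + 1 consecutive frames:
                # the first CONTEXT_LEN frames form s_t and the last CONTEXT_LEN
                # (shifted by one) form s_{t+1}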
                batch_state = batch_all_state[:, :CONTEXT_LEN, :, :]
                batch_next_state = batch_all_state[:, 1:, :, :]
                agent.train(batch_state, batch_action, batch_reward,
                            batch_next_state, batch_isOver)
        total_reward += reward
        state = next_state
        if isOver:
            break
    return total_reward, step


def get_player(rom, viz=False, train=False):
    env = AtariPlayer(
        rom,
        frame_skip=ACTION_REPEAT,
        viz=viz,
        live_lost_as_eoe=train,
        max_num_frames=60000)
    env = FireResetEnv(env)
    env = MapState(env, lambda im: cv2.resize(im, IMAGE_SIZE))
    if not train:
        # in training, context is taken care of in expreplay buffer
        env = FrameStack(env, CONTEXT_LEN)
    return env


def eval_agent(agent, env):
    episode_reward = []
    for _ in tqdm(range(30), desc='eval agent'):
        state = env.reset()
        total_reward = 0
        step = 0
        while True:
            step += 1
            action = agent.act(state, train_or_test='test')
            state, reward, isOver, info = env.step(action)
            total_reward += reward
            if isOver:
                break
        episode_reward.append(total_reward)
    eval_reward = np.mean(episode_reward)
    return eval_reward


def train_agent():
    env = get_player(args.rom, train=True)
    test_env = get_player(args.rom)
    exp = ReplayMemory(args.mem_size, IMAGE_SIZE, CONTEXT_LEN)
    action_dim = env.action_space.n

    if args.alg == 'DQN':
        agent = DQNModel(IMAGE_SIZE, action_dim, args.gamma, CONTEXT_LEN,
                         args.use_cuda)
    elif args.alg == 'DoubleDQN':
        agent = DoubleDQNModel(IMAGE_SIZE, action_dim, args.gamma, CONTEXT_LEN,
                               args.use_cuda)
    elif args.alg == 'DuelingDQN':
        agent = DuelingDQNModel(IMAGE_SIZE, action_dim, args.gamma, CONTEXT_LEN,
                                args.use_cuda)
    else:
        print('Unknown algorithm "{}", expected DQN, DoubleDQN or DuelingDQN'.format(
            args.alg))
        return

    with tqdm(total=MEMORY_WARMUP_SIZE) as pbar:
        while len(exp) < MEMORY_WARMUP_SIZE:
            total_reward, step = run_train_episode(agent, env, exp)
            pbar.update(step)

    # train
    test_flag = 0
    save_flag = 0
    pbar = tqdm(total=1e8)
    recent_100_reward = []
    total_step = 0
    max_reward = None
    save_path = os.path.join(args.model_dirname, '{}-{}'.format(
        args.alg, os.path.basename(args.rom).split('.')[0]))
    while True:
        # start epoch
        total_reward, step = run_train_episode(agent, env, exp)
        total_step += step
        pbar.set_description('[train]exploration:{}'.format(agent.exploration))
        pbar.update(step)

        if total_step // args.test_every_steps == test_flag:
            pbar.write("testing")
            eval_reward = eval_agent(agent, test_env)
            test_flag += 1
            print("eval_agent done, (steps, eval_reward): ({}, {})".format(
                total_step, eval_reward))
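
            # export the inference program whenever the evaluation reward
            # improves on the best result seen so far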
            if max_reward is None or eval_reward > max_reward:
                max_reward = eval_reward
                fluid.io.save_inference_model(save_path, ['state'],
                                              agent.pred_value, agent.exe,
                                              agent.predict_program)
    pbar.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--alg',
        type=str,
        default='DQN',
        help='Reinforcement learning algorithm, support: DQN, DoubleDQN, DuelingDQN'
    )
    parser.add_argument(
        '--use_cuda', action='store_true', help='if set, use cuda')
    parser.add_argument(
        '--gamma',
        type=float,
        default=0.99,
        help='discount factor for accumulated reward computation')
    parser.add_argument(
        '--mem_size',
        type=int,
        default=1000000,
        help='memory size for experience replay')
    parser.add_argument(
        '--batch_size', type=int, default=64, help='batch size for training')
    parser.add_argument('--rom', help='atari rom', required=True)
    parser.add_argument(
        '--model_dirname',
        type=str,
        default='saved_model',
        help='dirname to save model')
    parser.add_argument(
        '--test_every_steps',
        type=int,
        default=100000,
        help='every steps number to run test')
    args = parser.parse_args()
    train_agent()