train.py
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

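"""Train a DQN agent on an Atari game with the PARL framework."""
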
import argparse
import gym
import paddle.fluid as fluid
import numpy as np
import os
from atari_agent import AtariAgent
from atari_model import AtariModel
from collections import deque
from datetime import datetime
from replay_memory import ReplayMemory, Experience
from parl.algorithms import DQN
from parl.utils import logger
from tqdm import tqdm
from utils import get_player

MEMORY_SIZE = int(1e6)
MEMORY_WARMUP_SIZE = MEMORY_SIZE // 20  # fill this many transitions before learning starts
IMAGE_SIZE = (84, 84)
CONTEXT_LEN = 4  # number of stacked frames fed to the Q-network
FRAME_SKIP = 4
UPDATE_FREQ = 4  # run one learning step every UPDATE_FREQ environment steps
GAMMA = 0.99  # discount factor
LEARNING_RATE = 1e-3 * 0.5


def run_train_episode(env, agent, rpm):
    # Play one episode: sample actions with exploration, store transitions in
    # the replay memory, and learn every UPDATE_FREQ steps once warmed up.
    total_reward = 0
    all_cost = []
    state = env.reset()
    step = 0
    while True:
        step += 1
        context = rpm.recent_state()
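        # build the network input: the most recent frames kept by the replay
        # memory plus the current frame (CONTEXT_LEN frames in total)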
        context.append(state)
        context = np.stack(context, axis=0)
        action = agent.sample(context)
        next_state, reward, isOver, _ = env.step(action)
        rpm.append(Experience(state, action, reward, isOver))
        # start training
        if rpm.size() > MEMORY_WARMUP_SIZE:
            if step % UPDATE_FREQ == 0:
                batch_all_state, batch_action, batch_reward, batch_isOver = rpm.sample_batch(
                    args.batch_size)
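                # each sampled state holds CONTEXT_LEN + 1 stacked frames:
                # the first CONTEXT_LEN form the current state, the last
                # CONTEXT_LEN form the next state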
                batch_state = batch_all_state[:, :CONTEXT_LEN, :, :]
                batch_next_state = batch_all_state[:, 1:, :, :]
                cost = agent.learn(batch_state, batch_action, batch_reward,
                                   batch_next_state, batch_isOver)
                all_cost.append(float(cost))
        total_reward += reward
        state = next_state
        if isOver:
            break
    if all_cost:
        logger.info('[Train]total_reward: {}, mean_cost: {}'.format(
            total_reward, np.mean(all_cost)))
    return total_reward, step


def run_evaluate_episode(env, agent):
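    # Greedy evaluation: agent.predict acts without exploration, and the test
    # env already stacks CONTEXT_LEN frames, so states are fed in directly.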
    state = env.reset()
    total_reward = 0
    while True:
        action = agent.predict(state)
        state, reward, isOver, info = env.step(action)
        total_reward += reward
        if isOver:
            break
    return total_reward


def main():
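    # The training env returns single frames (the replay memory stacks them);
    # the evaluation env is built with context_len so it stacks frames itself.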
    env = get_player(
        args.rom, image_size=IMAGE_SIZE, train=True, frame_skip=FRAME_SKIP)
    test_env = get_player(
        args.rom,
        image_size=IMAGE_SIZE,
        frame_skip=FRAME_SKIP,
        context_len=CONTEXT_LEN)
    rpm = ReplayMemory(MEMORY_SIZE, IMAGE_SIZE, CONTEXT_LEN)
    action_dim = env.action_space.n

    hyperparas = {
        'action_dim': action_dim,
        'lr': LEARNING_RATE,
        'gamma': GAMMA
    }
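    # PARL's layered design: the Model defines the Q-network, the DQN
    # algorithm wraps it with the loss and target-network update, and the
    # Agent exposes sample/predict/learn for interacting with the environment.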
    model = AtariModel(action_dim)
    algorithm = DQN(model, hyperparas)
    agent = AtariAgent(algorithm, action_dim)

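    # warm up the replay memory with experience before any learning happens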
    with tqdm(total=MEMORY_WARMUP_SIZE) as pbar:
        while rpm.size() < MEMORY_WARMUP_SIZE:
            total_reward, step = run_train_episode(env, agent, rpm)
            pbar.update(step)

    # train
    test_flag = 0
    pbar = tqdm(total=int(1e8))
    recent_100_reward = []
    total_step = 0
    max_reward = None
    while True:
        # start epoch
        total_reward, step = run_train_episode(env, agent, rpm)
        total_step += step
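        # agent.exploration is the current epsilon of the epsilon-greedy policy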
        pbar.set_description('[train]exploration:{}'.format(agent.exploration))
        pbar.update(step)

        # '>=' (rather than '==') keeps evaluation on schedule even if one
        # episode pushes total_step past a test boundary
        if total_step // args.test_every_steps >= test_flag:
            pbar.write("testing")
            eval_rewards = []
            for _ in tqdm(range(3), desc='eval agent'):
                eval_reward = run_evaluate_episode(test_env, agent)
                eval_rewards.append(eval_reward)
            test_flag += 1
            logger.info(
                "eval_agent done, (steps, eval_reward): ({}, {})".format(
                    total_step, np.mean(eval_rewards)))
        if total_step > 1e8:
            break

    pbar.close()


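# Example invocation (the ROM filename below is only an illustration):
#   python train.py --rom breakout.bin --batch_size 64 --test_every_steps 100000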
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--rom', help='atari rom', required=True)
    parser.add_argument(
        '--batch_size', type=int, default=64, help='batch size for training')
    parser.add_argument(
        '--test_every_steps',
        type=int,
        default=100000,
        help='every steps number to run test')
    args = parser.parse_args()
    main()