# -*- coding: utf-8 -*-
# File: DQN.py

from agent import Model
import gym
import argparse
from tqdm import tqdm
from expreplay import ReplayMemory, Experience
import numpy as np
import os

# number of environment steps between successive training updates
UPDATE_FREQ = 4

# minimum number of transitions in the replay memory before training starts
MEMORY_WARMUP_SIZE = 1000


def run_episode(agent, env, exp, train_or_test):
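    """Run one episode in `env`; in 'train' mode, store transitions in `exp` and update the agent."""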
    assert train_or_test in ['train', 'test'], train_or_test
    total_reward = 0
    state = env.reset()
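    # each episode is capped at 200 steps (the time limit of the classic-control tasks used here)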
    for step in range(200):
        action = agent.act(state, train_or_test)
        next_state, reward, isOver, _ = env.step(action)
        if train_or_test == 'train':
            exp.append(Experience(state, action, reward, isOver))
            # start training once the replay memory holds enough transitions
            if len(exp) > MEMORY_WARMUP_SIZE:
                if step % UPDATE_FREQ == 0:
                    # sample a random minibatch of transitions and update the Q-network
                    batch_idx = np.random.randint(
                        len(exp) - 1, size=args.batch_size)
                    batch_state, batch_action, batch_reward, \
                        batch_next_state, batch_isOver = exp.sample(batch_idx)
                    agent.train(batch_state, batch_action, batch_reward,
                                batch_next_state, batch_isOver)
        total_reward += reward
        state = next_state
        if isOver:
            break
    return total_reward


def train_agent():
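    """Build the environment, replay memory and agent, then alternate training and evaluation episodes."""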
    env = gym.make(args.env)
    state_shape = env.observation_space.shape
    exp = ReplayMemory(args.mem_size, state_shape)
    action_dim = env.action_space.n
    agent = Model(state_shape[0], action_dim, gamma=args.gamma)

    # warm up: collect transitions until the replay memory can serve training batches
    while len(exp) < MEMORY_WARMUP_SIZE:
        run_episode(agent, env, exp, train_or_test='train')

    max_episode = 4000

    # train
    total_episode = 0
    pbar = tqdm(total=max_episode)
    recent_100_reward = []
    for episode in range(max_episode):
        # run one training episode (with exploration)
        run_episode(agent, env, exp, train_or_test='train')
        pbar.set_description('[train]exploration:{}'.format(agent.exploration))
        pbar.update()

        # evaluate with a greedy policy; track the mean reward of the last 100 test episodes
        total_reward = run_episode(agent, env, exp, train_or_test='test')
        recent_100_reward.append(total_reward)
        if len(recent_100_reward) > 100:
            recent_100_reward = recent_100_reward[1:]
        pbar.write("episode:{}    test_reward:{}".format(
                    episode, np.mean(recent_100_reward)))

    pbar.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='MountainCar-v0',
                        help='environment to train the DQN model on, e.g. CartPole-v0')
    parser.add_argument('--gamma', type=float, default=0.99,
                        help='discount factor for accumulated reward computation')
    parser.add_argument('--mem_size', type=int, default=500000,
                        help='memory size for experience replay')
    parser.add_argument('--batch_size', type=int, default=192,
                        help='batch size for training')
    args = parser.parse_args()

    train_agent()