# -*- coding: utf-8 -*-

import argparse
import os
import numpy as np
import paddle.fluid as fluid

from train import get_player
from tqdm import tqdm


def predict_action(exe, state, predict_program, feed_names, fetch_targets,
                   action_dim):
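    """Choose an action epsilon-greedily: with probability 0.01 take a
    random action, otherwise act greedily on the network's predicted Q-values."""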
    if np.random.random() < 0.01:
        act = np.random.randint(action_dim)
    else:
        state = np.expand_dims(state, axis=0)
        pred_Q = exe.run(predict_program,
                         feed={feed_names[0]: state.astype('float32')},
                         fetch_list=fetch_targets)[0]
        pred_Q = np.squeeze(pred_Q, axis=0)
        act = np.argmax(pred_Q)
    return act


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--use_cuda', action='store_true', help='if set, use cuda')
    parser.add_argument('--rom', type=str, required=True, help='atari rom')
    parser.add_argument(
        '--model_path', type=str, required=True, help='dirname to load model')
    parser.add_argument(
        '--viz',
        type=float,
        default=0,
        help='''visualization setting:
                0 to disable rendering;
                a positive number is used as the delay between rendered frames.
             ''')
    args = parser.parse_args()

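    # Build the Atari environment; a positive viz value renders frames with that delay.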
    env = get_player(args.rom, viz=args.viz)

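    # Run inference on GPU if requested, inside an isolated scope so loaded
    # variables do not pollute the global scope.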
    place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)
    inference_scope = fluid.core.Scope()
    with fluid.scope_guard(inference_scope):
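        # Load the saved inference program along with its feed names and fetch targets.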
        [predict_program, feed_names,
         fetch_targets] = fluid.io.load_inference_model(args.model_path, exe)

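        # Evaluate the agent over 30 episodes and report the mean return.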
        episode_reward = []
        for _ in tqdm(range(30), desc='eval agent'):
            state = env.reset()
            total_reward = 0
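            # Step the environment until the episode terminates.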
            while True:
                action = predict_action(exe, state, predict_program, feed_names,
                                        fetch_targets, env.action_space.n)
                state, reward, isOver, info = env.step(action)
                total_reward += reward
                if isOver:
                    break
            episode_reward.append(total_reward)
        eval_reward = np.mean(episode_reward)
        print('Average reward of 30 episodes: {}'.format(eval_reward))