#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import gym
import numpy as np
import time
from mujoco_agent import MujocoAgent
from mujoco_model import MujocoModel
from parl.algorithms import DDPG
from parl.utils import logger, action_mapping, ReplayMemory

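# DDPG hyperparameters: TAU is the soft target-update rate, MIN_LEARN_SIZE is the
# number of warm-up transitions collected before learning starts, and REWARD_SCALE
# rescales rewards before they are stored in the replay memory.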
MAX_EPISODES = 5000
TEST_EVERY_EPISODES = 20
ACTOR_LR = 1e-4
CRITIC_LR = 1e-3
GAMMA = 0.99
TAU = 0.001
MEMORY_SIZE = int(1e6)
MIN_LEARN_SIZE = 1e4
BATCH_SIZE = 128
REWARD_SCALE = 0.1
ENV_SEED = 1


def run_train_episode(env, agent, rpm):
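    """Run one training episode with exploration noise; store transitions in the
    replay memory and learn from sampled batches. Returns (total_reward, steps)."""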
    obs = env.reset()
    total_reward = 0
    steps = 0
    while True:
        steps += 1
        batch_obs = np.expand_dims(obs, axis=0)
        action = agent.predict(batch_obs.astype('float32'))
        action = np.squeeze(action)

        # Add exploration noise, and clip to [-1.0, 1.0]
        action = np.clip(np.random.normal(action, 1.0), -1.0, 1.0)
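        # Map the normalized action from [-1, 1] to the environment's action range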
        action = action_mapping(action, env.action_space.low[0],
                                env.action_space.high[0])

        next_obs, reward, done, info = env.step(action)

        rpm.append(obs, action, REWARD_SCALE * reward, next_obs, done)

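        # Learn only after the replay memory has collected MIN_LEARN_SIZE warm-up samples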
        if rpm.size() > MIN_LEARN_SIZE:
            batch_obs, batch_action, batch_reward, batch_next_obs, batch_terminal = rpm.sample_batch(
                BATCH_SIZE)
            agent.learn(batch_obs, batch_action, batch_reward, batch_next_obs,
                        batch_terminal)

        obs = next_obs
        total_reward += reward

        if done:
            break
    return total_reward, steps


def run_evaluate_episode(env, agent):
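    """Run one evaluation episode without exploration noise and return the total reward."""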
    obs = env.reset()
    total_reward = 0
    while True:
        batch_obs = np.expand_dims(obs, axis=0)
        action = agent.predict(batch_obs.astype('float32'))
        action = np.squeeze(action)
        action = action_mapping(action, env.action_space.low[0],
                                env.action_space.high[0])

        next_obs, reward, done, info = env.step(action)

        obs = next_obs
        total_reward += reward

        if done:
            break
    return total_reward


def main():
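    """Build the environment, model, DDPG algorithm and agent, then train with periodic evaluation."""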
    env = gym.make(args.env)
    env.seed(ENV_SEED)

    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]

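    # PARL stack: the model defines the networks, DDPG defines the update rule,
    # and the agent ties them to environment interaction and learning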
    model = MujocoModel(act_dim)
    algorithm = DDPG(
        model,
        hyperparas={
            'gamma': GAMMA,
            'tau': TAU,
            'actor_lr': ACTOR_LR,
            'critic_lr': CRITIC_LR
        })
    agent = MujocoAgent(algorithm, obs_dim, act_dim)

    rpm = ReplayMemory(MEMORY_SIZE, obs_dim, act_dim)

    test_flag = 0
    total_steps = 0
    while total_steps < args.train_total_steps:
        train_reward, steps = run_train_episode(env, agent, rpm)
        total_steps += steps
        logger.info('Steps: {} Reward: {}'.format(total_steps, train_reward))

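        # Evaluate whenever another test_every_steps of training steps has elapsed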
        if total_steps // args.test_every_steps >= test_flag:
            while total_steps // args.test_every_steps >= test_flag:
                test_flag += 1
            evaluate_reward = run_evaluate_episode(env, agent)
            logger.info('Steps {}, Evaluate reward: {}'.format(
                total_steps, evaluate_reward))


if __name__ == '__main__':
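    # Example usage:
    #   python train.py --env HalfCheetah-v2 --train_total_steps 1000000 --test_every_steps 10000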
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--env', help='Mujoco environment name', default='HalfCheetah-v2')
    parser.add_argument(
        '--train_total_steps',
        type=int,
        default=int(1e7),
        help='maximum training steps')
    parser.add_argument(
        '--test_every_steps',
        type=int,
        default=int(1e4),
        help='the step interval between two consecutive evaluations')

    args = parser.parse_args()

    main()