#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
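
"""Train a DDPG agent on a MuJoCo continuous-control task with PARL.

Example:
    python train.py --env HalfCheetah-v2
"""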

import argparse
import gym
import numpy as np
import time
from mujoco_agent import MujocoAgent
from mujoco_model import MujocoModel
from parl.algorithms import DDPG
from parl.utils import logger, action_mapping
from replay_memory import ReplayMemory

MAX_EPISODES = 5000
TEST_EVERY_EPISODES = 50
MAX_STEPS_EACH_EPISODE = 1000
ACTOR_LR = 1e-4
CRITIC_LR = 1e-3
GAMMA = 0.99
TAU = 0.001  # soft-update rate for the target networks
MEMORY_SIZE = int(1e6)
MIN_LEARN_SIZE = int(1e4)  # number of warm-up transitions before learning starts
BATCH_SIZE = 128
REWARD_SCALE = 0.1  # rewards are scaled before being stored in the replay memory
ENV_SEED = 1


def run_train_episode(env, agent, rpm):
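    """Run one training episode: act with exploration noise, store transitions,
    and update the agent from replay samples."""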
    obs = env.reset()
    total_reward = 0
    for j in range(MAX_STEPS_EACH_EPISODE):
        batch_obs = np.expand_dims(obs, axis=0)
        action = agent.predict(batch_obs.astype('float32'))
        action = np.squeeze(action)

        # Add exploration noise, and clip to [-1.0, 1.0]
        action = np.clip(np.random.normal(action, 1.0), -1.0, 1.0)
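        # Rescale the normalized action from [-1.0, 1.0] to the env's action range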
        action = action_mapping(action, env.action_space.low[0],
                                env.action_space.high[0])

        next_obs, reward, done, info = env.step(action)

        rpm.append(obs, action, REWARD_SCALE * reward, next_obs, done)

        # Learn only after the replay memory has collected enough warm-up samples
        if rpm.size() > MIN_LEARN_SIZE:
            batch_obs, batch_action, batch_reward, batch_next_obs, batch_terminal = rpm.sample_batch(
                BATCH_SIZE)
            agent.learn(batch_obs, batch_action, batch_reward, batch_next_obs,
                        batch_terminal)

        obs = next_obs
        total_reward += reward

        if done:
            break
    return total_reward


def run_evaluate_episode(env, agent):
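    """Run one evaluation episode using the policy without exploration noise."""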
    obs = env.reset()
    total_reward = 0
    for j in range(MAX_STEPS_EACH_EPISODE):
        batch_obs = np.expand_dims(obs, axis=0)
        action = agent.predict(batch_obs.astype('float32'))
        action = np.squeeze(action)
        action = action_mapping(action, env.action_space.low[0],
                                env.action_space.high[0])

        next_obs, reward, done, info = env.step(action)

        obs = next_obs
        total_reward += reward

        if done:
            break
    return total_reward


def main():
    env = gym.make(args.env)
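    # Fix the environment's random seed for reproducibility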
    env.seed(ENV_SEED)

    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]

    # Build the PARL stack: model (actor/critic networks) -> DDPG algorithm -> agent
    model = MujocoModel(act_dim)
    algorithm = DDPG(
        model,
        hyperparas={
            'gamma': GAMMA,
            'tau': TAU,
            'actor_lr': ACTOR_LR,
            'critic_lr': CRITIC_LR
        })
    agent = MujocoAgent(algorithm, obs_dim, act_dim)

    rpm = ReplayMemory(MEMORY_SIZE, obs_dim, act_dim)

    for i in range(MAX_EPISODES):
        train_reward = run_train_episode(env, agent, rpm)
        logger.info('Episode: {} Reward: {}'.format(i, train_reward))
        # Periodically evaluate the policy without exploration noise
        if (i + 1) % TEST_EVERY_EPISODES == 0:
            evaluate_reward = run_evaluate_episode(env, agent)
            logger.info('Evaluate Reward: {}'.format(evaluate_reward))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--env', help='Mujoco environment name', default='HalfCheetah-v2')
    args = parser.parse_args()
    main()