#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gym
import numpy as np
import parl
from cartpole_agent import CartpoleAgent
from cartpole_model import CartpoleModel
from parl.utils import logger
from utils import calc_discount_norm_reward

# Passed to CartpoleAgent as ``obs_dim`` in main().
OBS_DIM = 4
# Passed as ``act_dim`` to both CartpoleModel and CartpoleAgent in main().
ACT_DIM = 2
# Second argument to calc_discount_norm_reward in main()
# (presumably the reward discount factor -- confirm against utils).
GAMMA = 0.99
# Passed to parl.algorithms.PolicyGradient as ``lr`` in main().
LEARNING_RATE = 1e-3


def run_episode(env, agent, train_or_test='train'):
    """Roll out one episode in ``env`` and record the trajectory.

    Args:
        env: Environment object exposing ``reset()``, whose return value is
            used as the first observation, and ``step(action)``, which must
            return a 4-tuple ``(obs, reward, done, info)``.
        agent: Object exposing ``sample(obs)`` and ``predict(obs)``.
        train_or_test: ``'train'`` chooses actions with ``agent.sample``;
            any other value chooses them with ``agent.predict``.

    Returns:
        tuple: ``(obs_list, action_list, reward_list)``, three lists of
        equal length. Entry ``t`` holds the observation the action was
        chosen from, that action, and the reward ``env.step`` returned for
        it. The observation returned together with ``done`` is not
        recorded.
    """
    # The mode never changes during an episode, so pick the method once.
    choose_action = agent.sample if train_or_test == 'train' else agent.predict

    obs_list, action_list, reward_list = [], [], []
    obs = env.reset()
    while True:
        obs_list.append(obs)
        action = choose_action(obs)
        action_list.append(action)

        # The fourth element (info) is not used here.
        obs, reward, done, _ = env.step(action)
        reward_list.append(reward)

        if done:
            break
    return obs_list, action_list, reward_list


def main():
    """Train a policy-gradient agent on CartPole-v0 with periodic test runs.

    Builds the environment, model, algorithm and agent, then runs 1000
    training episodes. After each episode the recorded observations and
    actions, together with the output of
    ``calc_discount_norm_reward(reward_list, GAMMA)``, are passed to
    ``agent.learn``. After every 100th episode one extra episode is run in
    ``'test'`` mode and its total reward is logged.
    """
    env = gym.make("CartPole-v0")
    model = CartpoleModel(act_dim=ACT_DIM)
    alg = parl.algorithms.PolicyGradient(model, lr=LEARNING_RATE)
    agent = CartpoleAgent(alg, obs_dim=OBS_DIM, act_dim=ACT_DIM)

    for i in range(1000):
        # Default mode is 'train', so actions come from agent.sample.
        obs_list, action_list, reward_list = run_episode(env, agent)
        logger.info("Episode {}, Reward Sum {}.".format(i, sum(reward_list)))

        batch_obs = np.array(obs_list)
        batch_action = np.array(action_list)
        batch_reward = calc_discount_norm_reward(reward_list, GAMMA)

        agent.learn(batch_obs, batch_action, batch_reward)

        # Periodic evaluation: actions come from agent.predict.
        if (i + 1) % 100 == 0:
            _, _, reward_list = run_episode(env, agent, train_or_test='test')
            total_reward = np.sum(reward_list)
            logger.info('Test reward: {}'.format(total_reward))


# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()