提交 f8de849b 编写于 作者: H Hongsheng Zeng 提交者: Bo Zhou

add PPO example (#39)

* add PPO example

* Update Readme

* Update Readme

* fix codestyle

* Update Readme

* refine action mapping

* add more unitest case

* remove unnecessary params initialize, add more comments, add benchmark result

* rename

* remove PARL dependence in readme of examples
上级 bd37f473
......@@ -2,6 +2,8 @@
<img src=".github/PARL-logo.png" alt="PARL" width="500"/>
</p>
> PARL is a flexible and high-efficient reinforcement learning framework based on [PaddlePaddle](https://github.com/PaddlePaddle/Paddle).
# Features
**Reproducible**. We provide algorithms that stably reproduce the result of many influential reinforcement learning algorithms
......@@ -76,5 +78,5 @@ pip install --upgrade git+https://github.com/PaddlePaddle/PARL.git
- [QuickStart](examples/QuickStart/)
- [DQN](examples/DQN/)
- [DDPG](examples/DDPG/)
- PPO
- [PPO](examples/PPO/)
- [Winning Solution for NIPS2018: AI for Prosthetics Challenge](examples/NeurIPS2018-AI-for-Prosthetics-Challenge/)
......@@ -11,7 +11,6 @@ Please see [here](https://github.com/openai/mujoco-py) to know more about Mujoco
## How to use
### Dependencies:
+ python2.7 or python3.5+
+ [PARL](https://github.com/PaddlePaddle/PARL)
+ [paddlepaddle>=1.0.0](https://github.com/PaddlePaddle/Paddle)
+ gym
+ tqdm
......
......@@ -18,8 +18,8 @@ from parl.framework.model_base import Model
class MujocoModel(Model):
def __init__(self, act_dim, act_bound):
self.actor_model = ActorModel(act_dim, act_bound)
def __init__(self, act_dim):
self.actor_model = ActorModel(act_dim)
self.critic_model = CriticModel()
def policy(self, obs):
......@@ -33,8 +33,7 @@ class MujocoModel(Model):
class ActorModel(Model):
def __init__(self, act_dim, act_bound):
self.act_bound = act_bound
def __init__(self, act_dim):
hid1_size = 400
hid2_size = 300
......@@ -46,7 +45,7 @@ class ActorModel(Model):
hid1 = self.fc1(obs)
hid2 = self.fc2(hid1)
means = self.fc3(hid2)
means = means * self.act_bound
means = means
return means
......
......@@ -19,7 +19,7 @@ import time
from mujoco_agent import MujocoAgent
from mujoco_model import MujocoModel
from parl.algorithms import DDPG
from parl.utils import logger
from parl.utils import logger, action_mapping
from replay_memory import ReplayMemory
MAX_EPISODES = 5000
......@@ -36,7 +36,7 @@ REWARD_SCALE = 0.1
ENV_SEED = 1
def run_train_episode(env, agent, rpm, act_bound):
def run_train_episode(env, agent, rpm):
obs = env.reset()
total_reward = 0
for j in range(MAX_STEPS_EACH_EPISODE):
......@@ -44,9 +44,10 @@ def run_train_episode(env, agent, rpm, act_bound):
action = agent.predict(batch_obs.astype('float32'))
action = np.squeeze(action)
# Add exploration noise
action = np.clip(
np.random.normal(action, act_bound), -act_bound, act_bound)
# Add exploration noise, and clip to [-1.0, 1.0]
action = np.clip(np.random.normal(action, 1.0), -1.0, 1.0)
action = action_mapping(action, env.action_space.low[0],
env.action_space.high[0])
next_obs, reward, done, info = env.step(action)
......@@ -73,6 +74,8 @@ def run_evaluate_episode(env, agent):
batch_obs = np.expand_dims(obs, axis=0)
action = agent.predict(batch_obs.astype('float32'))
action = np.squeeze(action)
action = action_mapping(action, env.action_space.low[0],
env.action_space.high[0])
next_obs, reward, done, info = env.step(action)
......@@ -90,9 +93,8 @@ def main():
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]
act_bound = env.action_space.high[0]
model = MujocoModel(act_dim, act_bound)
model = MujocoModel(act_dim)
algorithm = DDPG(
model,
hyperparas={
......@@ -106,7 +108,7 @@ def main():
rpm = ReplayMemory(MEMORY_SIZE, obs_dim, act_dim)
for i in range(MAX_EPISODES):
train_reward = run_train_episode(env, agent, rpm, act_bound)
train_reward = run_train_episode(env, agent, rpm)
logger.info('Episode: {} Reward: {}'.format(i, train_reward))
if (i + 1) % TEST_EVERY_EPISODES == 0:
evaluate_reward = run_evaluate_episode(env, agent)
......
......@@ -11,7 +11,6 @@ Please see [here](https://gym.openai.com/envs/#atari) to know more about Atari g
## How to use
### Dependencies:
+ python2.7 or python3.5+
+ [PARL](https://github.com/PaddlePaddle/PARL)
+ [paddlepaddle>=1.0.0](https://github.com/PaddlePaddle/Paddle)
+ gym
+ tqdm
......
......@@ -5,7 +5,6 @@ This folder will contains the code used to train the winning models for the [Neu
### Dependencies
- python3.6
- [paddlepaddle>=1.2.0](https://github.com/PaddlePaddle/Paddle)
- [PARL](https://github.com/PaddlePaddle/PARL)
- [osim-rl](https://github.com/stanfordnmbl/osim-rl)
### Start Testing best models
......
## Reproduce PPO with PARL
Based on PARL, the PPO model of deep reinforcement learning is reproduced, and the same level of indicators of the paper is reproduced in the classic Mujoco game.
Include following approach:
+ Clipped Surrogate Objective
+ Adaptive KL Penalty Coefficient
> PPO in
[Proximal Policy Optimization Algorithms](https://arxiv.org/abs/1707.06347)
### Mujoco games introduction
Please see [here](https://github.com/openai/mujoco-py) to know more about Mujoco game.
### Benchmark result
- HalfCheetah-v2
<img src=".benchmark/PPO_HalfCheetah-v2.png"/>
## How to use
### Dependencies:
+ python2.7 or python3.5+
+ [paddlepaddle>=1.0.0](https://github.com/PaddlePaddle/Paddle)
+ gym
+ tqdm
+ mujoco-py>=1.50.1.0
### Start Training:
```
# To train an agent for HalfCheetah-v2 game (default: CLIP loss)
python train.py
# To train for different game and different loss type
# python train.py --env [ENV_NAME] --loss_type [CLIP|KLPEN]
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import parl.layers as layers
from paddle import fluid
from sklearn.utils import shuffle
from parl.framework.agent_base import Agent
from parl.utils import logger
class MujocoAgent(Agent):
def __init__(self,
algorithm,
obs_dim,
act_dim,
kl_targ,
loss_type,
beta=1.0,
epsilon=0.2,
policy_learn_times=20,
value_learn_times=10,
value_batch_size=256):
self.alg = algorithm
self.obs_dim = obs_dim
self.act_dim = act_dim
assert loss_type == 'CLIP' or loss_type == 'KLPEN'
self.loss_type = loss_type
super(MujocoAgent, self).__init__(algorithm)
self.policy_learn_times = policy_learn_times
# Adaptive kl penalty coefficient
self.beta = beta
self.kl_targ = kl_targ
self.value_learn_times = value_learn_times
self.value_batch_size = value_batch_size
self.value_learn_buffer = None
def build_program(self):
self.policy_predict_program = fluid.Program()
self.policy_sample_program = fluid.Program()
self.policy_learn_program = fluid.Program()
self.value_predict_program = fluid.Program()
self.value_learn_program = fluid.Program()
with fluid.program_guard(self.policy_sample_program):
obs = layers.data(
name='obs', shape=[self.obs_dim], dtype='float32')
sampled_act = self.alg.define_sample(obs)
self.policy_sample_output = [sampled_act]
with fluid.program_guard(self.policy_predict_program):
obs = layers.data(
name='obs', shape=[self.obs_dim], dtype='float32')
means = self.alg.define_predict(obs)
self.policy_predict_output = [means]
with fluid.program_guard(self.policy_learn_program):
obs = layers.data(
name='obs', shape=[self.obs_dim], dtype='float32')
actions = layers.data(
name='actions', shape=[self.act_dim], dtype='float32')
advantages = layers.data(
name='advantages', shape=[1], dtype='float32')
if self.loss_type == 'KLPEN':
beta = layers.data(name='beta', shape=[], dtype='float32')
loss, kl = self.alg.define_policy_learn(
obs, actions, advantages, beta)
else:
loss, kl = self.alg.define_policy_learn(
obs, actions, advantages)
self.policy_learn_output = [loss, kl]
with fluid.program_guard(self.value_predict_program):
obs = layers.data(
name='obs', shape=[self.obs_dim], dtype='float32')
value = self.alg.define_value_predict(obs)
self.value_predict_output = [value]
with fluid.program_guard(self.value_learn_program):
obs = layers.data(
name='obs', shape=[self.obs_dim], dtype='float32')
val = layers.data(name='val', shape=[], dtype='float32')
value_loss = self.alg.define_value_learn(obs, val)
self.value_learn_output = [value_loss]
def policy_sample(self, obs):
feed = {'obs': obs}
sampled_act = self.fluid_executor.run(
self.policy_sample_program,
feed=feed,
fetch_list=self.policy_sample_output)[0]
return sampled_act
def policy_predict(self, obs):
feed = {'obs': obs}
means = self.fluid_executor.run(
self.policy_predict_program,
feed=feed,
fetch_list=self.policy_predict_output)[0]
return means
def value_predict(self, obs):
feed = {'obs': obs}
value = self.fluid_executor.run(
self.value_predict_program,
feed=feed,
fetch_list=self.value_predict_output)[0]
return value
def _batch_policy_learn(self, obs, actions, advantages):
if self.loss_type == 'KLPEN':
feed = {
'obs': obs,
'actions': actions,
'advantages': advantages,
'beta': self.beta
}
else:
feed = {'obs': obs, 'actions': actions, 'advantages': advantages}
[loss, kl] = self.fluid_executor.run(
self.policy_learn_program,
feed=feed,
fetch_list=self.policy_learn_output)
return loss, kl
def _batch_value_learn(self, obs, val):
feed = {'obs': obs, 'val': val}
value_loss = self.fluid_executor.run(
self.value_learn_program,
feed=feed,
fetch_list=self.value_learn_output)[0]
return value_loss
def policy_learn(self, obs, actions, advantages):
""" Learn policy:
1. Sync parameters of policy model to old policy model
2. Fix old policy model, and learn policy model multi times
3. if use KLPEN loss, Adjust kl loss coefficient: beta
"""
self.alg.sync_old_policy(self.gpu_id)
all_loss, all_kl = [], []
for _ in range(self.policy_learn_times):
loss, kl = self._batch_policy_learn(obs, actions, advantages)
all_loss.append(loss)
all_kl.append(kl)
if self.loss_type == 'KLPEN':
# Adative KL penalty coefficient
if kl > self.kl_targ * 2:
self.beta = 1.5 * self.beta
elif kl < self.kl_targ / 2:
self.beta = self.beta / 1.5
return np.mean(all_loss), np.mean(all_kl)
def value_learn(self, obs, value):
""" Fit model to current data batch + previous data batch
"""
data_size = obs.shape[0]
if self.value_learn_buffer is None:
obs_train, value_train = obs, value
else:
obs_train = np.concatenate([obs, self.value_learn_buffer[0]])
value_train = np.concatenate([value, self.value_learn_buffer[1]])
self.value_learn_buffer = (obs, value)
all_loss = []
for _ in range(self.value_learn_times):
obs_train, value_train = shuffle(obs_train, value_train)
start = 0
while start < data_size:
end = start + self.value_batch_size
value_loss = self._batch_value_learn(obs_train[start:end, :],
value_train[start:end])
all_loss.append(value_loss)
start += self.value_batch_size
return np.mean(all_loss)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import parl.layers as layers
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from parl.framework.model_base import Model
class MujocoModel(Model):
def __init__(self, obs_dim, act_dim, init_logvar=-1.0):
self.policy_model = PolicyModel(obs_dim, act_dim, init_logvar)
self.value_model = ValueModel(obs_dim, act_dim)
self.policy_lr = self.policy_model.lr
self.value_lr = self.value_model.lr
def policy(self, obs):
return self.policy_model.policy(obs)
def policy_sample(self, obs):
return self.policy_model.sample(obs)
def value(self, obs):
return self.value_model.value(obs)
class PolicyModel(Model):
def __init__(self, obs_dim, act_dim, init_logvar):
self.obs_dim = obs_dim
self.act_dim = act_dim
hid1_size = obs_dim * 10
hid3_size = act_dim * 10
hid2_size = int(np.sqrt(hid1_size * hid3_size))
self.lr = 9e-4 / np.sqrt(hid2_size)
self.fc1 = layers.fc(size=hid1_size, act='tanh')
self.fc2 = layers.fc(size=hid2_size, act='tanh')
self.fc3 = layers.fc(size=hid3_size, act='tanh')
self.fc4 = layers.fc(size=act_dim, act='tanh')
self.logvars = layers.create_parameter(
shape=[act_dim],
dtype='float32',
default_initializer=fluid.initializer.ConstantInitializer(
init_logvar))
def policy(self, obs):
hid1 = self.fc1(obs)
hid2 = self.fc2(hid1)
hid3 = self.fc3(hid2)
means = self.fc4(hid3)
logvars = self.logvars()
return means, logvars
def sample(self, obs):
means, logvars = self.policy(obs)
sampled_act = means + (
layers.exp(logvars / 2.0) * # stddev
layers.gaussian_random(shape=(self.act_dim, ), dtype='float32'))
return sampled_act
class ValueModel(Model):
def __init__(self, obs_dim, act_dim):
super(ValueModel, self).__init__()
hid1_size = obs_dim * 10
hid3_size = 5
hid2_size = int(np.sqrt(hid1_size * hid3_size))
self.lr = 1e-2 / np.sqrt(hid2_size)
self.fc1 = layers.fc(size=hid1_size, act='tanh')
self.fc2 = layers.fc(size=hid2_size, act='tanh')
self.fc3 = layers.fc(size=hid3_size, act='tanh')
self.fc4 = layers.fc(size=1)
def value(self, obs):
hid1 = self.fc1(obs)
hid2 = self.fc2(hid1)
hid3 = self.fc3(hid2)
V = self.fc4(hid3)
V = layers.squeeze(V, axes=[])
return V
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gym
import numpy as np
from mujoco_agent import MujocoAgent
from mujoco_model import MujocoModel
from parl.algorithms import PPO
from parl.utils import logger, action_mapping
from utils import *
def run_train_episode(env, agent, scaler):
obs = env.reset()
observes, actions, rewards, unscaled_obs = [], [], [], []
done = False
step = 0.0
scale, offset = scaler.get()
scale[-1] = 1.0 # don't scale time step feature
offset[-1] = 0.0 # don't offset time step feature
while not done:
obs = obs.reshape((1, -1))
obs = np.append(obs, [[step]], axis=1) # add time step feature
unscaled_obs.append(obs)
obs = (obs - offset) * scale # center and scale observations
obs = obs.astype('float32')
observes.append(obs)
action = agent.policy_sample(obs)
action = np.clip(action, -1.0, 1.0)
action = action_mapping(action, env.action_space.low[0],
env.action_space.high[0])
action = action.reshape((1, -1)).astype('float32')
actions.append(action)
obs, reward, done, _ = env.step(np.squeeze(action))
rewards.append(reward)
step += 1e-3 # increment time step feature
return (np.concatenate(observes), np.concatenate(actions),
np.array(rewards, dtype='float32'), np.concatenate(unscaled_obs))
def run_evaluate_episode(env, agent, scaler):
obs = env.reset()
rewards = []
step = 0.0
scale, offset = scaler.get()
scale[-1] = 1.0 # don't scale time step feature
offset[-1] = 0.0 # don't offset time step feature
while True:
obs = obs.reshape((1, -1))
obs = np.append(obs, [[step]], axis=1) # add time step feature
obs = (obs - offset) * scale # center and scale observations
obs = obs.astype('float32')
action = agent.policy_predict(obs)
action = action_mapping(action, env.action_space.low[0],
env.action_space.high[0])
obs, reward, done, _ = env.step(np.squeeze(action))
rewards.append(reward)
step += 1e-3 # increment time step feature
if done:
break
return np.sum(rewards)
def collect_trajectories(env, agent, scaler, episodes):
all_obs, all_actions, all_rewards, all_unscaled_obs = [], [], [], []
for e in range(episodes):
obs, actions, rewards, unscaled_obs = run_train_episode(
env, agent, scaler)
all_obs.append(obs)
all_actions.append(actions)
all_rewards.append(rewards)
all_unscaled_obs.append(unscaled_obs)
scaler.update(np.concatenate(all_unscaled_obs)
) # update running statistics for scaling observations
return np.concatenate(all_obs), np.concatenate(
all_actions), np.concatenate(all_rewards)
def main():
env = gym.make(args.env)
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]
obs_dim += 1 # add 1 to obs dim for time step feature
scaler = Scaler(obs_dim)
model = MujocoModel(obs_dim, act_dim)
hyperparas = {
'act_dim': act_dim,
'policy_lr': model.policy_lr,
'value_lr': model.value_lr
}
alg = PPO(model, hyperparas)
agent = MujocoAgent(
alg, obs_dim, act_dim, args.kl_targ, loss_type=args.loss_type)
# run a few episodes to initialize scaler
collect_trajectories(env, agent, scaler, episodes=5)
episode = 0
while episode < args.num_episodes:
obs, actions, rewards = collect_trajectories(
env, agent, scaler, episodes=args.episodes_per_batch)
episode += args.episodes_per_batch
pred_values = agent.value_predict(obs)
# scale rewards
scale_rewards = rewards * (1 - args.gamma)
discount_sum_rewards = calc_discount_sum_rewards(
scale_rewards, args.gamma)
discount_sum_rewards = discount_sum_rewards.astype('float32')
advantages = calc_gae(scale_rewards, pred_values, args.gamma, args.lam)
# normalize advantages
advantages = (advantages - advantages.mean()) / (
advantages.std() + 1e-6)
advantages = advantages.astype('float32')
policy_loss, kl = agent.policy_learn(obs, actions, advantages)
value_loss = agent.value_learn(obs, discount_sum_rewards)
logger.info(
'Episode {}, Train reward: {}, Policy loss: {}, KL: {}, Value loss: {}'
.format(episode,
np.sum(rewards) / args.episodes_per_batch, policy_loss, kl,
value_loss))
if episode % (args.episodes_per_batch * 5) == 0:
eval_reward = run_evaluate_episode(env, agent, scaler)
logger.info('Episode {}, Evaluate reward: {}'.format(
episode, eval_reward))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
'--env',
type=str,
help='Mujoco environment name',
default='HalfCheetah-v2')
parser.add_argument(
'--num_episodes',
type=int,
help='Number of episodes to run',
default=10000)
parser.add_argument(
'--gamma', type=float, help='Discount factor', default=0.995)
parser.add_argument(
'--lam',
type=float,
help='Lambda for Generalized Advantage Estimation',
default=0.98)
parser.add_argument(
'--kl_targ', type=float, help='D_KL target value', default=0.003)
parser.add_argument(
'--episodes_per_batch',
type=int,
help='Number of episodes per training batch',
default=5)
parser.add_argument(
'--loss_type',
type=str,
help="Choose loss type of PPO algorithm, 'CLIP' or 'KLPEN'",
default='CLIP')
args = parser.parse_args()
import time
logger.set_dir('./log_dir/{}_{}'.format(args.loss_type, time.time()))
main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import scipy.signal
__all__ = ['calc_discount_sum_rewards', 'calc_gae', 'Scaler']
"""
The following code are copied or modified from:
https://github.com/pat-coady/trpo
Written by Patrick Coady (pat-coady.github.io)
"""
def calc_discount_sum_rewards(rewards, gamma):
""" Calculate discounted forward sum of a sequence at each point """
return scipy.signal.lfilter([1.0], [1.0, -gamma], rewards[::-1])[::-1]
def calc_gae(rewards, values, gamma, lam):
""" Calculate generalized advantage estimator.
See: https://arxiv.org/pdf/1506.02438.pdf
"""
# temporal differences
tds = rewards - values + np.append(values[1:] * gamma, 0)
advantages = calc_discount_sum_rewards(tds, gamma * lam)
return advantages
class Scaler(object):
""" Generate scale and offset based on running mean and stddev along axis=0
offset = running mean
scale = 1 / (stddev + 0.1) / 3 (i.e. 3x stddev = +/- 1.0)
"""
def __init__(self, obs_dim):
"""
Args:
obs_dim: dimension of axis=1
"""
self.vars = np.zeros(obs_dim)
self.means = np.zeros(obs_dim)
self.cnt = 0
self.first_pass = True
def update(self, x):
""" Update running mean and variance (this is an exact method)
Args:
x: NumPy array, shape = (N, obs_dim)
see: https://stats.stackexchange.com/questions/43159/how-to-calculate-pooled-
variance-of-two-groups-given-known-group-variances-mean
"""
if self.first_pass:
self.means = np.mean(x, axis=0)
self.vars = np.var(x, axis=0)
self.cnt = x.shape[0]
self.first_pass = False
else:
n = x.shape[0]
new_data_var = np.var(x, axis=0)
new_data_mean = np.mean(x, axis=0)
new_data_mean_sq = np.square(new_data_mean)
new_means = (
(self.means * self.cnt) + (new_data_mean * n)) / (self.cnt + n)
self.vars = (((self.cnt * (self.vars + np.square(self.means))) +
(n * (new_data_var + new_data_mean_sq))) /
(self.cnt + n) - np.square(new_means))
self.vars = np.maximum(
0.0, self.vars) # occasionally goes negative, clip
self.means = new_means
self.cnt += n
def get(self):
""" returns 2-tuple: (scale, offset) """
return 1 / (np.sqrt(self.vars) + 0.1) / 3, self.means
......@@ -5,7 +5,6 @@ Based on PARL, train a agent to play CartPole game with policy gradient algorith
### Dependencies:
+ python2.7 or python3.5+
+ [PARL](https://github.com/PaddlePaddle/PARL)
+ [paddlepaddle>=1.0.0](https://github.com/PaddlePaddle/Paddle)
+ gym
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from parl.algorithms.ddpg import *
from parl.algorithms.dqn import *
from parl.algorithms.policy_gradient import *
from parl.algorithms.ddpg import *
from parl.algorithms.ppo import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import parl.layers as layers
from copy import deepcopy
from paddle import fluid
from parl.framework.algorithm_base import Algorithm
__all__ = ['PPO']
class PPO(Algorithm):
def __init__(self, model, hyperparas):
Algorithm.__init__(self, model, hyperparas)
# Used to calculate probability of action in old policy
self.old_policy_model = deepcopy(model.policy_model)
# fetch hyper parameters
self.act_dim = hyperparas['act_dim']
self.policy_lr = hyperparas['policy_lr']
self.value_lr = hyperparas['value_lr']
if 'epsilon' in hyperparas:
self.epsilon = hyperparas['epsilon']
else:
self.epsilon = 0.2 # default
def _calc_logprob(self, actions, means, logvars):
""" Calculate log probabilities of actions, when given means and logvars
of normal distribution.
The constant sqrt(2 * pi) is omitted, which will be eliminated in later.
Args:
actions: shape (batch_size, act_dim)
means: shape (batch_size, act_dim)
logvars: shape (act_dim)
Returns:
logprob: shape (batch_size)
"""
exp_item = layers.elementwise_div(
layers.square(actions - means), layers.exp(logvars), axis=1)
exp_item = -0.5 * layers.reduce_sum(exp_item, dim=1)
vars_item = -0.5 * layers.reduce_sum(logvars)
logprob = exp_item + vars_item
return logprob
def _calc_kl(self, means, logvars, old_means, old_logvars):
""" Calculate KL divergence between old and new distributions
See: https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Kullback.E2.80.93Leibler_divergence
Args:
means: shape (batch_size, act_dim)
logvars: shape (act_dim)
old_means: shape (batch_size, act_dim)
old_logvars: shape (act_dim)
Returns:
kl: shape (batch_size)
"""
log_det_cov_old = layers.reduce_sum(old_logvars)
log_det_cov_new = layers.reduce_sum(logvars)
tr_old_new = layers.reduce_sum(layers.exp(old_logvars - logvars))
kl = 0.5 * (layers.reduce_sum(
layers.square(means - old_means) / layers.exp(logvars), dim=1) + (
log_det_cov_new - log_det_cov_old) + tr_old_new - self.act_dim)
return kl
def define_predict(self, obs):
""" Use policy model of self.model to predict means and logvars of actions
"""
means, logvars = self.model.policy(obs)
return means
def define_sample(self, obs):
""" Use policy model of self.model to sample actions
"""
sampled_act = self.model.policy_sample(obs)
return sampled_act
def define_policy_learn(self, obs, actions, advantages, beta=None):
""" Learn policy model with:
1. CLIP loss: Clipped Surrogate Objective
2. KLPEN loss: Adaptive KL Penalty Objective
See: https://arxiv.org/pdf/1707.02286.pdf
Args:
obs: Tensor, (batch_size, obs_dim)
actions: Tensor, (batch_size, act_dim)
advantages: Tensor (batch_size, )
beta: Tensor (1) or None
if None, use CLIP Loss; else, use KLPEN loss.
"""
old_means, old_logvars = self.old_policy_model.policy(obs)
old_means.stop_gradient = True
old_logvars.stop_gradient = True
old_logprob = self._calc_logprob(actions, old_means, old_logvars)
means, logvars = self.model.policy(obs)
logprob = self._calc_logprob(actions, means, logvars)
kl = self._calc_kl(means, logvars, old_means, old_logvars)
kl = layers.reduce_mean(kl)
if beta is None: # Clipped Surrogate Objective
pg_ratio = layers.exp(logprob - old_logprob)
clipped_pg_ratio = layers.clip(pg_ratio, 1 - self.epsilon,
1 + self.epsilon)
surrogate_loss = layers.elementwise_min(
advantages * pg_ratio, advantages * clipped_pg_ratio)
loss = 0 - layers.reduce_mean(surrogate_loss)
else: # Adaptive KL Penalty Objective
# policy gradient loss
loss1 = 0 - layers.reduce_mean(
advantages * layers.exp(logprob - old_logprob))
# adaptive kl loss
loss2 = kl * beta
loss = loss1 + loss2
optimizer = fluid.optimizer.AdamOptimizer(self.policy_lr)
optimizer.minimize(loss)
return loss, kl
def define_value_predict(self, obs):
""" Use value model of self.model to predict value of obs
"""
return self.model.value(obs)
def define_value_learn(self, obs, val):
""" Learn value model with square error cost
"""
predict_val = self.model.value(obs)
loss = layers.square_error_cost(predict_val, val)
loss = layers.reduce_mean(loss)
optimizer = fluid.optimizer.AdamOptimizer(self.value_lr)
optimizer.minimize(loss)
return loss
def sync_old_policy(self, gpu_id):
""" Synchronize parameters of self.model.policy_model to self.old_policy_model
"""
self.model.policy_model.sync_params_to(
self.old_policy_model, gpu_id=gpu_id)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import unittest
from parl.utils import action_mapping
class TestUtils(unittest.TestCase):
def test_action_mapping(self):
origin_act = np.array([-1.0, 0.0, 1.0])
mapped_act = action_mapping(origin_act, 0.0, 1.0)
self.assertListEqual(list(mapped_act), [0.0, 0.5, 1.0])
mapped_act = action_mapping(origin_act, -2.0, 2.0)
self.assertListEqual(list(mapped_act), [-2.0, 0.0, 2.0])
mapped_act = action_mapping(origin_act, -5.0, 10.0)
self.assertListEqual(list(mapped_act), [-5.0, 2.5, 10.0])
if __name__ == '__main__':
unittest.main()
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ['has_func']
__all__ = ['has_func', 'action_mapping']
def has_func(obj, fun):
......@@ -26,3 +26,21 @@ def has_func(obj, fun):
"""
check_fun = getattr(obj, fun, None)
return callable(check_fun)
def action_mapping(model_output_act, low_bound, high_bound):
""" mapping action space [-1, 1] of model output
to new action space [low_bound, high_bound].
Args:
model_output_act: np.array, which value is in [-1, 1]
low_bound: float, low bound of env action space
high_bound: float, high bound of env action space
Returns:
action: np.array, which value is in [low_bound, high_bound]
"""
assert high_bound > low_bound
action = low_bound + (model_output_act - (-1.0)) * (
(high_bound - low_bound) / 2.0)
return action
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册