diff --git a/benchmark/torch/ppo/arguments.py b/benchmark/torch/ppo/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7d5d33df54b4652a416f0f9bbb49c3d1bd4a522
--- /dev/null
+++ b/benchmark/torch/ppo/arguments.py
@@ -0,0 +1,103 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import torch
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description='RL')
+    parser.add_argument(
+        '--lr', type=float, default=3e-4, help='learning rate (default: 3e-4)')
+    parser.add_argument(
+        '--eps',
+        type=float,
+        default=1e-5,
+        help='Adam optimizer epsilon (default: 1e-5)')
+    parser.add_argument(
+        '--gamma',
+        type=float,
+        default=0.99,
+        help='discount factor for rewards (default: 0.99)')
+    parser.add_argument(
+        '--gae-lambda',
+        type=float,
+        default=0.95,
+        help='gae lambda parameter (default: 0.95)')
+    parser.add_argument(
+        '--entropy-coef',
+        type=float,
+        default=0.,
+        help='entropy term coefficient (default: 0.)')
+    parser.add_argument(
+        '--value-loss-coef',
+        type=float,
+        default=0.5,
+        help='value loss coefficient (default: 0.5)')
+    parser.add_argument(
+        '--max-grad-norm',
+        type=float,
+        default=0.5,
+        help='max norm of gradients (default: 0.5)')
+    parser.add_argument(
+        '--seed', type=int, default=1, help='random seed (default: 1)')
+    parser.add_argument(
+        '--num-steps',
+        type=int,
+        default=2048,
+        help='number of forward steps collected per PPO update (default: 2048)')
+    parser.add_argument(
+        '--ppo-epoch',
+        type=int,
+        default=10,
+        help='number of ppo epochs (default: 10)')
+    parser.add_argument(
+        '--num-mini-batch',
+        type=int,
+        default=32,
+        help='number of mini-batches per ppo epoch (default: 32)')
+    parser.add_argument(
+        '--clip-param',
+        type=float,
+        default=0.2,
+        help='ppo clip parameter (default: 0.2)')
+    parser.add_argument(
+        '--log-interval',
+        type=int,
+        default=1,
+        help='log interval, one log per n updates (default: 1)')
+    parser.add_argument(
+        '--eval-interval',
+        type=int,
+        default=10,
+        help='eval interval, one eval per n updates (default: 10)')
+    parser.add_argument(
+        '--num-env-steps',
+        type=int,
+        default=10e5,
+        help='number of environment steps to train (default: 10e5)')
+    parser.add_argument(
+        '--env-name',
+        default='Hopper-v2',
+        help='environment to train on (default: Hopper-v2)')
+    parser.add_argument(
+        '--use-linear-lr-decay',
+        action='store_true',
+        default=False,
+        help='use a linear schedule on the learning rate')
+    args = parser.parse_args()
+
+    args.cuda = torch.cuda.is_available()
+
+    return args
diff --git a/benchmark/torch/ppo/evaluation.py b/benchmark/torch/ppo/evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..8aa020ca66a0c3a97d8deea55e37dabc4cf7512b
--- /dev/null
+++ b/benchmark/torch/ppo/evaluation.py
@@ -0,0 +1,56 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import torch
+
+import utils
+from wrapper import make_env
+
+
+def evaluate(agent, ob_rms, env_name, seed, device):
+    if seed is not None:
+        seed += 1
+    eval_envs = make_env(env_name, seed, None)
+    vec_norm = utils.get_vec_normalize(eval_envs)
+    if vec_norm is not None:
+        vec_norm.eval()
+        vec_norm.ob_rms = ob_rms
+
+    eval_episode_rewards = []
+
+    obs = eval_envs.reset()
+    eval_masks = torch.zeros(1, 1, device=device)
+
+    while len(eval_episode_rewards) < 10:
+        with torch.no_grad():
+            action = agent.predict(obs)
+
+        # Observe reward and next obs
+        obs, _, done, infos = eval_envs.step(action)
+
+        eval_masks = torch.tensor(
+            [[0.0] if done_ else [1.0] for done_ in done],
+            dtype=torch.float32,
+            device=device)
+
+        for info in infos:
+            if 'episode' in info.keys():
+                eval_episode_rewards.append(info['episode']['r'])
+
+    eval_envs.close()
+
+    print(" Evaluation using {} episodes: mean reward {:.5f}\n".format(
+        len(eval_episode_rewards), np.mean(eval_episode_rewards)))
+    return np.mean(eval_episode_rewards)
diff --git a/benchmark/torch/ppo/mujoco_agent.py b/benchmark/torch/ppo/mujoco_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..096f683f958829c0780ecc59d9ed144367c15f38
--- /dev/null
+++ b/benchmark/torch/ppo/mujoco_agent.py
@@ -0,0 +1,78 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import parl
+import torch
+
+
+class MujocoAgent(parl.Agent):
+    def __init__(self, algorithm, device):
+        self.alg = algorithm
+        self.device = device
+
+    def predict(self, obs):
+        obs = torch.from_numpy(obs).float().to(self.device)
+        action = self.alg.predict(obs)
+        return action.cpu().numpy()
+
+    def sample(self, obs):
+        obs = torch.from_numpy(obs).to(self.device)
+        value, action, action_log_probs = self.alg.sample(obs)
+        return value.cpu().numpy(), action.cpu().numpy(), \
+            action_log_probs.cpu().numpy()
+
+    def learn(self, next_value, gamma, gae_lambda, ppo_epoch, num_mini_batch,
+              rollouts):
+        value_loss_epoch = 0
+        action_loss_epoch = 0
+        dist_entropy_epoch = 0
+
+        for e in range(ppo_epoch):
+            data_generator = rollouts.sample_batch(next_value, gamma,
+                                                   gae_lambda, num_mini_batch)
+
+            for sample in data_generator:
+                obs_batch, actions_batch, \
+                    value_preds_batch, return_batch, old_action_log_probs_batch, \
+                    adv_targ = sample
+
+                # move the sampled batch to the same device as the model
+                obs_batch = torch.from_numpy(obs_batch).to(self.device)
+                actions_batch = torch.from_numpy(actions_batch).to(self.device)
+                value_preds_batch = torch.from_numpy(value_preds_batch).to(
+                    self.device)
+                return_batch = torch.from_numpy(return_batch).to(self.device)
+                old_action_log_probs_batch = torch.from_numpy(
+                    old_action_log_probs_batch).to(self.device)
+                adv_targ = torch.from_numpy(adv_targ).to(self.device)
+
+                value_loss, action_loss, dist_entropy = self.alg.learn(
+                    obs_batch, actions_batch, value_preds_batch, return_batch,
+                    old_action_log_probs_batch, adv_targ)
+
+                value_loss_epoch += value_loss
+                action_loss_epoch += action_loss
+                dist_entropy_epoch += dist_entropy
+
+        num_updates = ppo_epoch * num_mini_batch
+
+        value_loss_epoch /= num_updates
+        action_loss_epoch /= num_updates
+        dist_entropy_epoch /= num_updates
+
+        return value_loss_epoch, action_loss_epoch, dist_entropy_epoch
+
+    def value(self, obs):
+        obs = torch.from_numpy(obs).to(self.device)
+        return self.alg.value(obs).cpu().numpy()
diff --git a/benchmark/torch/ppo/mujoco_model.py b/benchmark/torch/ppo/mujoco_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..83b762da2bd5a922d2a20605df641b6aec0ad949
--- /dev/null
+++ b/benchmark/torch/ppo/mujoco_model.py
@@ -0,0 +1,64 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import parl
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.distributions import Normal
+
+
+class MujocoModel(parl.Model):
+    def __init__(self, obs_dim, act_dim):
+        super(MujocoModel, self).__init__()
+        self.actor = Actor(obs_dim, act_dim)
+        self.critic = Critic(obs_dim)
+
+    def policy(self, obs):
+        return self.actor(obs)
+
+    def value(self, obs):
+        return self.critic(obs)
+
+
+class Actor(parl.Model):
+    def __init__(self, obs_dim, act_dim):
+        super(Actor, self).__init__()
+        self.fc1 = nn.Linear(obs_dim, 64)
+        self.fc2 = nn.Linear(64, 64)
+
+        self.fc_mean = nn.Linear(64, act_dim)
+        self.log_std = nn.Parameter(torch.zeros(act_dim))
+
+    def forward(self, obs):
+        x = torch.tanh(self.fc1(obs))
+        x = torch.tanh(self.fc2(x))
+
+        mean = self.fc_mean(x)
+        return mean, self.log_std
+
+
+class Critic(parl.Model):
+    def __init__(self, obs_dim):
+        super(Critic, self).__init__()
+        self.fc1 = nn.Linear(obs_dim, 64)
+        self.fc2 = nn.Linear(64, 64)
+        self.fc3 = nn.Linear(64, 1)
+
+    def forward(self, obs):
+        x = torch.tanh(self.fc1(obs))
+        x = torch.tanh(self.fc2(x))
+        value = self.fc3(x)
+
+        return value
diff --git a/benchmark/torch/ppo/storage.py b/benchmark/torch/ppo/storage.py
new file mode 100644
index 0000000000000000000000000000000000000000..b986b670d545fb88938785fc812a320103023d5d
--- /dev/null
+++ b/benchmark/torch/ppo/storage.py
@@ -0,0 +1,107 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
+
+
+class RolloutStorage(object):
+    def __init__(self, num_steps, obs_dim, act_dim):
+        self.num_steps = num_steps
+        self.obs_dim = obs_dim
+        self.act_dim = act_dim
+
+        self.obs = np.zeros((num_steps + 1, obs_dim), dtype='float32')
+        self.actions = np.zeros((num_steps, act_dim), dtype='float32')
+        self.value_preds = np.zeros((num_steps + 1, ), dtype='float32')
+        self.returns = np.zeros((num_steps + 1, ), dtype='float32')
+        self.action_log_probs = np.zeros((num_steps, ), dtype='float32')
+        self.rewards = np.zeros((num_steps, ), dtype='float32')
+
+        self.masks = np.ones((num_steps + 1, ), dtype='bool')
+        self.bad_masks = np.ones((num_steps + 1, ), dtype='bool')
+
+        self.step = 0
+
+    def append(self, obs, actions, action_log_probs, value_preds, rewards,
+               masks, bad_masks):
+        self.obs[self.step + 1] = obs
+        self.actions[self.step] = actions
+        self.rewards[self.step] = rewards
+        self.action_log_probs[self.step] = action_log_probs
+        self.value_preds[self.step] = value_preds
+        self.masks[self.step + 1] = masks
+        self.bad_masks[self.step + 1] = bad_masks
+
+        self.step = (self.step + 1) % self.num_steps
+
+    def sample_batch(self,
+                     next_value,
+                     gamma,
+                     gae_lambda,
+                     num_mini_batch,
+                     mini_batch_size=None):
+        # calculate return and advantage first
+        self.compute_returns(next_value, gamma, gae_lambda)
+        advantages = self.returns[:-1] - self.value_preds[:-1]
+        advantages = (advantages - advantages.mean()) / (
+            advantages.std() + 1e-5)
+
+        # generate sample batch
+        if mini_batch_size is None:
+            mini_batch_size = self.num_steps // num_mini_batch
+        sampler = BatchSampler(
+            SubsetRandomSampler(range(self.num_steps)),
+            mini_batch_size,
+            drop_last=True)
+        for indices in sampler:
+            obs_batch = self.obs[:-1][indices]
+            actions_batch = self.actions[indices]
+            value_preds_batch = self.value_preds[:-1][indices]
+            returns_batch = self.returns[:-1][indices]
+            old_action_log_probs_batch = self.action_log_probs[indices]
+
+            value_preds_batch = value_preds_batch.reshape(-1, 1)
+            returns_batch = returns_batch.reshape(-1, 1)
+            old_action_log_probs_batch = old_action_log_probs_batch.reshape(
+                -1, 1)
+
+            adv_targ = advantages[indices]
+            adv_targ = adv_targ.reshape(-1, 1)
+
+            yield obs_batch, actions_batch, value_preds_batch, returns_batch, old_action_log_probs_batch, adv_targ
+
+    def after_update(self):
+        self.obs[0] = np.copy(self.obs[-1])
+        self.masks[0] = np.copy(self.masks[-1])
+        self.bad_masks[0] = np.copy(self.bad_masks[-1])
+
+    def compute_returns(self, next_value, gamma, gae_lambda):
+        self.value_preds[-1] = next_value
+        gae = 0
+        for step in reversed(range(self.rewards.size)):
+            delta = self.rewards[step] + gamma * self.value_preds[
+                step + 1] * self.masks[step + 1] - self.value_preds[step]
+            gae = delta + gamma * gae_lambda * self.masks[step + 1] * gae
+            gae = gae * self.bad_masks[step + 1]
+            self.returns[step] = gae + self.value_preds[step]
diff --git a/benchmark/torch/ppo/train.py b/benchmark/torch/ppo/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bb5dafbf4fbc6b96dc664030910446a7cfd46e1
--- /dev/null
+++ b/benchmark/torch/ppo/train.py
@@ -0,0 +1,128 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# modified from https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail
+
+import copy
+import os
+from collections import deque
+
+import gym
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+
+import utils
+from arguments import get_args
+from wrapper import make_env
+from mujoco_model import MujocoModel
+from parl.algorithms import PPO
+from mujoco_agent import MujocoAgent
+from storage import RolloutStorage
+from evaluation import evaluate
+
+
+def main():
+    args = get_args()
+
+    torch.manual_seed(args.seed)
+    torch.cuda.manual_seed_all(args.seed)
+
+    torch.set_num_threads(1)
+    device = torch.device("cuda:0" if args.cuda else "cpu")
+
+    envs = make_env(args.env_name, args.seed, args.gamma)
+
+    model = MujocoModel(envs.observation_space.shape[0],
+                        envs.action_space.shape[0])
+    model.to(device)
+
+    algorithm = PPO(
+        model,
+        args.clip_param,
+        args.value_loss_coef,
+        args.entropy_coef,
+        initial_lr=args.lr,
+        eps=args.eps,
+        max_grad_norm=args.max_grad_norm)
+
+    agent = MujocoAgent(algorithm, device)
+
+    rollouts = RolloutStorage(args.num_steps, envs.observation_space.shape[0],
+                              envs.action_space.shape[0])
+
+    obs = envs.reset()
+    rollouts.obs[0] = np.copy(obs)
+
+    episode_rewards = deque(maxlen=10)
+
+    num_updates = int(args.num_env_steps) // args.num_steps
+    for j in range(num_updates):
+
+        if args.use_linear_lr_decay:
+            # decrease learning rate linearly
+            utils.update_linear_schedule(algorithm.optimizer, j, num_updates,
+                                         args.lr)
+
+        for step in range(args.num_steps):
+            # Sample actions
+            with torch.no_grad():
+                value, action, action_log_prob = agent.sample(
+                    rollouts.obs[step])  # obs of the current step is stored in rollouts
+
+            # Observe reward and next obs
+            obs, reward, done, infos = envs.step(action)
+
+            for info in infos:
+                if 'episode' in info.keys():
+                    episode_rewards.append(info['episode']['r'])
+
+            # If done then clean the history of observations.
+            masks = torch.FloatTensor(
+                [[0.0] if done_ else [1.0] for done_ in done])
+            bad_masks = torch.FloatTensor(
+                [[0.0] if 'bad_transition' in info.keys() else [1.0]
+                 for info in infos])
+            rollouts.append(obs, action, action_log_prob, value, reward, masks,
+                            bad_masks)
+
+        with torch.no_grad():
+            next_value = agent.value(rollouts.obs[-1])
+
+        value_loss, action_loss, dist_entropy = agent.learn(
+            next_value, args.gamma, args.gae_lambda, args.ppo_epoch,
+            args.num_mini_batch, rollouts)
+
+        rollouts.after_update()
+
+        if j % args.log_interval == 0 and len(episode_rewards) > 1:
+            total_num_steps = (j + 1) * args.num_steps
+            print(
+                "Updates {}, num timesteps {},\n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}, entropy {:.3f}, value loss {:.3f}, action loss {:.3f}\n"
+                .format(j, total_num_steps, len(episode_rewards),
+                        np.mean(episode_rewards), np.median(episode_rewards),
+                        np.min(episode_rewards), np.max(episode_rewards),
+                        dist_entropy, value_loss, action_loss))
+
+        if (args.eval_interval is not None and len(episode_rewards) > 1
+                and j % args.eval_interval == 0):
+            ob_rms = utils.get_vec_normalize(envs).ob_rms
+            eval_mean_reward = evaluate(agent, ob_rms, args.env_name,
+                                        args.seed, device)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/benchmark/torch/ppo/utils.py b/benchmark/torch/ppo/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e276a7f0779cfb55b3ef92012f22a61b7937c62
--- /dev/null
+++ b/benchmark/torch/ppo/utils.py
@@ -0,0 +1,43 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import glob
+import os
+
+import torch
+import torch.nn as nn
+
+from wrapper import VecNormalize
+
+
+def get_vec_normalize(venv):
+    if isinstance(venv, VecNormalize):
+        return venv
+    elif hasattr(venv, 'venv'):
+        return get_vec_normalize(venv.venv)
+
+    return None
+
+
+def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr):
+    """Decreases the learning rate linearly"""
+    lr = initial_lr - (initial_lr * (epoch / float(total_num_epochs)))
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
+
+
+def init(module, weight_init, bias_init, gain=1):
+    weight_init(module.weight.data, gain=gain)
+    bias_init(module.bias.data)
+    return module
diff --git a/benchmark/torch/ppo/wrapper.py b/benchmark/torch/ppo/wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..a890db1d0e5ee2cc2131794d9317a76a55e16e83
--- /dev/null
+++ b/benchmark/torch/ppo/wrapper.py
@@ -0,0 +1,180 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Simplified version of https://github.com/ShangtongZhang/DeepRL/blob/master/deep_rl/component/envs.py
+
+import numpy as np
+import gym
+from gym.core import Wrapper
+import time
+
+
+class TimeLimitMask(gym.Wrapper):
+    def step(self, action):
+        obs, rew, done, info = self.env.step(action)
+        if done and self.env._max_episode_steps == self.env._elapsed_steps:
+            info['bad_transition'] = True
+        return obs, rew, done, info
+
+    def reset(self, **kwargs):
+        return self.env.reset(**kwargs)
+
+
+class MonitorEnv(gym.Wrapper):
+    def __init__(self, env):
+        Wrapper.__init__(self, env=env)
+        self.tstart = time.time()
+        self.rewards = None
+
+    def step(self, action):
+        ob, rew, done, info = self.env.step(action)
+        self.update(ob, rew, done, info)
+        return (ob, rew, done, info)
+
+    def update(self, ob, rew, done, info):
+        self.rewards.append(rew)
+        if done:
+            eprew = sum(self.rewards)
+            eplen = len(self.rewards)
+            epinfo = {
+                "r": round(eprew, 6),
+                "l": eplen,
+                "t": round(time.time() - self.tstart, 6)
+            }
+            assert isinstance(info, dict)
+            info['episode'] = epinfo
+            self.reset()
+
+    def reset(self, **kwargs):
+        self.rewards = []
+        return self.env.reset(**kwargs)
+
+
+class VectorEnv(gym.Wrapper):
+    def step(self, action):
+        ob, rew, done, info = self.env.step(action)
+        ob = np.array(ob)
+        ob = ob[np.newaxis, :]
+        rew = np.array([rew])
+
+        done = np.array([done])
+
+        info = [info]
+        return (ob, rew, done, info)
+
+
+class RunningMeanStd(object):
+    # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
+    def __init__(self, epsilon=1e-4, shape=()):
+        self.mean = np.zeros(shape, 'float64')
+        self.var = np.ones(shape, 'float64')
+        self.count = epsilon
+
+    def update(self, x):
+        batch_mean = np.mean(x, axis=0)
+        batch_var = np.var(x, axis=0)
+        batch_count = x.shape[0]
+        self.update_from_moments(batch_mean, batch_var, batch_count)
+
+    def update_from_moments(self, batch_mean, batch_var, batch_count):
+        self.mean, self.var, self.count = update_mean_var_count_from_moments(
+            self.mean, self.var, self.count, batch_mean, batch_var,
+            batch_count)
+
+
+def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var,
+                                       batch_count):
+    delta = batch_mean - mean
+    tot_count = count + batch_count
+
+    new_mean = mean + delta * batch_count / tot_count
+    m_a = var * count
+    m_b = batch_var * batch_count
+    M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count
+    new_var = M2 / tot_count
+    new_count = tot_count
+
+    return new_mean, new_var, new_count
+
+
+class VecNormalize(gym.Wrapper):
+    def __init__(self,
+                 env,
+                 ob=True,
+                 ret=True,
+                 clipob=10.,
+                 cliprew=10.,
+                 gamma=0.99,
+                 epsilon=1e-8):
+        Wrapper.__init__(self, env=env)
+        observation_space = env.observation_space.shape[0]
+
+        self.ob_rms = RunningMeanStd(shape=observation_space) if ob else None
+        self.ret_rms = RunningMeanStd(shape=()) if ret else None
+
+        self.clipob = clipob
+        self.cliprew = cliprew
+        self.gamma = gamma
+        self.epsilon = epsilon
+        self.ret = np.zeros(1)
+        self.training = True
+
+    def step(self, action):
+        ob, rew, new, info = self.env.step(action)
+        self.ret = self.ret * self.gamma + rew
+        # normalize observation
+        ob = self._obfilt(ob)
+        # normalize reward
+        if self.ret_rms:
+            self.ret_rms.update(self.ret)
+            rew = np.clip(rew / np.sqrt(self.ret_rms.var + self.epsilon),
+                          -self.cliprew, self.cliprew)
+        self.ret[new] = 0.
+        return ob, rew, new, info
+
+    def reset(self):
+        self.ret = np.zeros(1)
+        ob = self.env.reset()
+        return self._obfilt(ob)
+
+    def _obfilt(self, ob, update=True):
+        if self.ob_rms:
+            if self.training and update:
+                self.ob_rms.update(ob)
+            ob = np.clip((ob - self.ob_rms.mean) /
+                         np.sqrt(self.ob_rms.var + self.epsilon), -self.clipob,
+                         self.clipob)
+            return ob
+        else:
+            return ob
+
+    def train(self):
+        self.training = True
+
+    def eval(self):
+        self.training = False
+
+
+def make_env(env_name, seed, gamma):
+    env = gym.make(env_name)
+    env.seed(seed)
+    env = TimeLimitMask(env)
+    env = MonitorEnv(env)
+    env = VectorEnv(env)
+    if gamma is None:
+        env = VecNormalize(env, ret=False)
+    else:
+        env = VecNormalize(env, gamma=gamma)
+
+    return env
diff --git a/parl/algorithms/torch/__init__.py b/parl/algorithms/torch/__init__.py
index dc2aabac3dcdd239eba1b059fbf1283a192b4b10..9de7afbdd57305b1280b024556e0b1730bcbc494 100644
--- a/parl/algorithms/torch/__init__.py
+++ b/parl/algorithms/torch/__init__.py
@@ -16,4 +16,5 @@ from parl.algorithms.torch.ddqn import *
 from parl.algorithms.torch.dqn import *
 from parl.algorithms.torch.a2c import *
 from parl.algorithms.torch.td3 import *
+from parl.algorithms.torch.ppo import *
 from parl.algorithms.torch.policy_gradient import *
diff --git a/parl/algorithms/torch/ppo.py b/parl/algorithms/torch/ppo.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c838e896e26b35fa078d1db1323476fb776993f
--- /dev/null
+++ b/parl/algorithms/torch/ppo.py
@@ -0,0 +1,94 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import parl
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.distributions import Normal
+
+__all__ = ['PPO']
+
+
+class PPO(parl.Algorithm):
+    def __init__(self,
+                 model,
+                 clip_param,
+                 value_loss_coef,
+                 entropy_coef,
+                 initial_lr,
+                 eps=None,
+                 max_grad_norm=None,
+                 use_clipped_value_loss=True):
+        self.model = model
+
+        self.clip_param = clip_param
+
+        self.value_loss_coef = value_loss_coef
+        self.entropy_coef = entropy_coef
+
+        self.max_grad_norm = max_grad_norm
+        self.use_clipped_value_loss = use_clipped_value_loss
+
+        self.optimizer = optim.Adam(model.parameters(), lr=initial_lr, eps=eps)
+
+    def learn(self, obs_batch, actions_batch, value_preds_batch, return_batch,
+              old_action_log_probs_batch, adv_targ):
+        values = self.model.value(obs_batch)
+        mean, log_std = self.model.policy(obs_batch)
+        dist = Normal(mean, log_std.exp())
+
+        action_log_probs = dist.log_prob(actions_batch).sum(-1, keepdim=True)
+        dist_entropy = dist.entropy().sum(-1).mean()
+
+        ratio = torch.exp(action_log_probs - old_action_log_probs_batch)
+        surr1 = ratio * adv_targ
+        surr2 = torch.clamp(ratio, 1.0 - self.clip_param,
+                            1.0 + self.clip_param) * adv_targ
+        action_loss = -torch.min(surr1, surr2).mean()
+
+        if self.use_clipped_value_loss:
+            value_pred_clipped = value_preds_batch + \
+                (values - value_preds_batch).clamp(-self.clip_param, self.clip_param)
+            value_losses = (values - return_batch).pow(2)
+            value_losses_clipped = (value_pred_clipped - return_batch).pow(2)
+            value_loss = 0.5 * torch.max(value_losses,
+                                         value_losses_clipped).mean()
+        else:
+            value_loss = 0.5 * (return_batch - values).pow(2).mean()
+
+        self.optimizer.zero_grad()
+        (value_loss * self.value_loss_coef + action_loss -
+         dist_entropy * self.entropy_coef).backward()
+        nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
+        self.optimizer.step()
+
+        return value_loss.item(), action_loss.item(), dist_entropy.item()
+
+    def sample(self, obs):
+        value = self.model.value(obs)
+        mean, log_std = self.model.policy(obs)
+        dist = Normal(mean, log_std.exp())
+        action = dist.sample()
+        action_log_probs = dist.log_prob(action).sum(-1, keepdim=True)
+
+        return value, action, action_log_probs
+
+    def predict(self, obs):
+        mean, _ = self.model.policy(obs)
+        return mean
+
+    def value(self, obs):
+        return self.model.value(obs)