Unverified · Commit 2deefa8f authored by LI Yunxiang, committed by GitHub

add torch ppo (#213)

* add ppo

* fix bugs

* yapf
Parent 2c7340f7
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import torch
def get_args():
parser = argparse.ArgumentParser(description='RL')
parser.add_argument(
'--lr', type=float, default=3e-4, help='learning rate (default: 3e-4)')
parser.add_argument(
'--eps',
type=float,
default=1e-5,
        help='Adam optimizer epsilon (default: 1e-5)')
parser.add_argument(
'--gamma',
type=float,
default=0.99,
help='discount factor for rewards (default: 0.99)')
parser.add_argument(
'--gae-lambda',
type=float,
default=0.95,
help='gae lambda parameter (default: 0.95)')
parser.add_argument(
'--entropy-coef',
type=float,
default=0.,
help='entropy term coefficient (default: 0.)')
parser.add_argument(
'--value-loss-coef',
type=float,
default=0.5,
help='value loss coefficient (default: 0.5)')
parser.add_argument(
'--max-grad-norm',
type=float,
default=0.5,
help='max norm of gradients (default: 0.5)')
parser.add_argument(
'--seed', type=int, default=1, help='random seed (default: 1)')
parser.add_argument(
'--num-steps',
type=int,
default=2048,
        help='number of forward steps per update in ppo (default: 2048)')
parser.add_argument(
'--ppo-epoch',
type=int,
default=10,
help='number of ppo epochs (default: 10)')
parser.add_argument(
'--num-mini-batch',
type=int,
default=32,
help='number of batches for ppo (default: 32)')
parser.add_argument(
'--clip-param',
type=float,
default=0.2,
help='ppo clip parameter (default: 0.2)')
parser.add_argument(
'--log-interval',
type=int,
default=1,
help='log interval, one log per n updates (default: 1)')
parser.add_argument(
'--eval-interval',
type=int,
default=10,
help='eval interval, one eval per n updates (default: 10)')
parser.add_argument(
'--num-env-steps',
type=int,
default=10e5,
help='number of environment steps to train (default: 10e5)')
parser.add_argument(
'--env-name',
default='Hopper-v2',
help='environment to train on (default: Hopper-v2)')
parser.add_argument(
'--use-linear-lr-decay',
action='store_true',
default=False,
help='use a linear schedule on the learning rate')
args = parser.parse_args()
args.cuda = torch.cuda.is_available()
return args
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
import utils
from wrapper import make_env
def evaluate(agent, ob_rms, env_name, seed, device):
    if seed is not None:
seed += 1
eval_envs = make_env(env_name, seed, None)
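    # reuse the observation-normalization statistics from training and freeze them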
vec_norm = utils.get_vec_normalize(eval_envs)
if vec_norm is not None:
vec_norm.eval()
vec_norm.ob_rms = ob_rms
eval_episode_rewards = []
obs = eval_envs.reset()
eval_masks = torch.zeros(1, 1, device=device)
while len(eval_episode_rewards) < 10:
with torch.no_grad():
action = agent.predict(obs)
        # Observe reward and next obs
obs, _, done, infos = eval_envs.step(action)
eval_masks = torch.tensor(
[[0.0] if done_ else [1.0] for done_ in done],
dtype=torch.float32,
device=device)
for info in infos:
if 'episode' in info.keys():
eval_episode_rewards.append(info['episode']['r'])
eval_envs.close()
print(" Evaluation using {} episodes: mean reward {:.5f}\n".format(
len(eval_episode_rewards), np.mean(eval_episode_rewards)))
return np.mean(eval_episode_rewards)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import parl
import torch
class MujocoAgent(parl.Agent):
def __init__(self, algorithm, device):
self.alg = algorithm
self.device = device
def predict(self, obs):
obs = torch.from_numpy(obs).float().to(self.device)
action = self.alg.predict(obs)
return action.cpu().numpy()
def sample(self, obs):
obs = torch.from_numpy(obs).to(self.device)
value, action, action_log_probs = self.alg.sample(obs)
return value.cpu().numpy(), action.cpu().numpy(), \
action_log_probs.cpu().numpy()
def learn(self, next_value, gamma, gae_lambda, ppo_epoch, num_mini_batch,
rollouts):
value_loss_epoch = 0
action_loss_epoch = 0
dist_entropy_epoch = 0
for e in range(ppo_epoch):
data_generator = rollouts.sample_batch(next_value, gamma,
gae_lambda, num_mini_batch)
for sample in data_generator:
obs_batch, actions_batch, \
value_preds_batch, return_batch, old_action_log_probs_batch, \
adv_targ = sample
                obs_batch = torch.from_numpy(obs_batch).to(self.device)
                actions_batch = torch.from_numpy(actions_batch).to(self.device)
                value_preds_batch = torch.from_numpy(value_preds_batch).to(
                    self.device)
                return_batch = torch.from_numpy(return_batch).to(self.device)
                old_action_log_probs_batch = torch.from_numpy(
                    old_action_log_probs_batch).to(self.device)
                adv_targ = torch.from_numpy(adv_targ).to(self.device)
value_loss, action_loss, dist_entropy = self.alg.learn(
obs_batch, actions_batch, value_preds_batch, return_batch,
old_action_log_probs_batch, adv_targ)
value_loss_epoch += value_loss
action_loss_epoch += action_loss
dist_entropy_epoch += dist_entropy
num_updates = ppo_epoch * num_mini_batch
value_loss_epoch /= num_updates
action_loss_epoch /= num_updates
dist_entropy_epoch /= num_updates
return value_loss_epoch, action_loss_epoch, dist_entropy_epoch
def value(self, obs):
obs = torch.from_numpy(obs).to(self.device)
return self.alg.value(obs).cpu().numpy()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import parl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
class MujocoModel(parl.Model):
def __init__(self, obs_dim, act_dim):
super(MujocoModel, self).__init__()
self.actor = Actor(obs_dim, act_dim)
self.critic = Critic(obs_dim)
def policy(self, obs):
return self.actor(obs)
def value(self, obs):
return self.critic(obs)
class Actor(parl.Model):
def __init__(self, obs_dim, act_dim):
super(Actor, self).__init__()
self.fc1 = nn.Linear(obs_dim, 64)
self.fc2 = nn.Linear(64, 64)
self.fc_mean = nn.Linear(64, act_dim)
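        # state-independent log standard deviation, learned as a free parameter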
self.log_std = nn.Parameter(torch.zeros(act_dim))
def forward(self, obs):
x = torch.tanh(self.fc1(obs))
x = torch.tanh(self.fc2(x))
mean = self.fc_mean(x)
return mean, self.log_std
class Critic(parl.Model):
def __init__(self, obs_dim):
super(Critic, self).__init__()
self.fc1 = nn.Linear(obs_dim, 64)
self.fc2 = nn.Linear(64, 64)
self.fc3 = nn.Linear(64, 1)
def forward(self, obs):
x = torch.tanh(self.fc1(obs))
x = torch.tanh(self.fc2(x))
value = self.fc3(x)
return value
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
class RolloutStorage(object):
def __init__(self, num_steps, obs_dim, act_dim):
self.num_steps = num_steps
self.obs_dim = obs_dim
self.act_dim = act_dim
self.obs = np.zeros((num_steps + 1, obs_dim), dtype='float32')
self.actions = np.zeros((num_steps, act_dim), dtype='float32')
self.value_preds = np.zeros((num_steps + 1, ), dtype='float32')
self.returns = np.zeros((num_steps + 1, ), dtype='float32')
self.action_log_probs = np.zeros((num_steps, ), dtype='float32')
self.rewards = np.zeros((num_steps, ), dtype='float32')
self.masks = np.ones((num_steps + 1, ), dtype='bool')
self.bad_masks = np.ones((num_steps + 1, ), dtype='bool')
self.step = 0
def append(self, obs, actions, action_log_probs, value_preds, rewards,
masks, bad_masks):
"""
print("obs")
print(obs)
print("masks")
print(masks)
print("rewards")
print(rewards)
exit()
"""
self.obs[self.step + 1] = obs
self.actions[self.step] = actions
self.rewards[self.step] = rewards
self.action_log_probs[self.step] = action_log_probs
self.value_preds[self.step] = value_preds
self.masks[self.step + 1] = masks
self.bad_masks[self.step + 1] = bad_masks
self.step = (self.step + 1) % self.num_steps
def sample_batch(self,
next_value,
gamma,
gae_lambda,
num_mini_batch,
mini_batch_size=None):
# calculate return and advantage first
self.compute_returns(next_value, gamma, gae_lambda)
advantages = self.returns[:-1] - self.value_preds[:-1]
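        # normalize advantages to zero mean and unit std for stable updates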
advantages = (advantages - advantages.mean()) / (
advantages.std() + 1e-5)
# generate sample batch
mini_batch_size = self.num_steps // num_mini_batch
sampler = BatchSampler(
SubsetRandomSampler(range(self.num_steps)),
mini_batch_size,
drop_last=True)
for indices in sampler:
obs_batch = self.obs[:-1][indices]
actions_batch = self.actions[indices]
value_preds_batch = self.value_preds[:-1][indices]
returns_batch = self.returns[:-1][indices]
old_action_log_probs_batch = self.action_log_probs[indices]
value_preds_batch = value_preds_batch.reshape(-1, 1)
returns_batch = returns_batch.reshape(-1, 1)
old_action_log_probs_batch = old_action_log_probs_batch.reshape(
-1, 1)
adv_targ = advantages[indices]
adv_targ = adv_targ.reshape(-1, 1)
yield obs_batch, actions_batch, value_preds_batch, returns_batch, old_action_log_probs_batch, adv_targ
def after_update(self):
self.obs[0] = np.copy(self.obs[-1])
self.masks[0] = np.copy(self.masks[-1])
self.bad_masks[0] = np.copy(self.bad_masks[-1])
def compute_returns(self, next_value, gamma, gae_lambda):
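        # Generalized Advantage Estimation (GAE):
        #   delta_t = r_t + gamma * V(s_{t+1}) * mask_{t+1} - V(s_t)
        #   A_t     = delta_t + gamma * lambda * mask_{t+1} * A_{t+1}
        # bad_masks reset the accumulated advantage at time-limit truncations;
        # returns_t = A_t + V(s_t)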
self.value_preds[-1] = next_value
gae = 0
for step in reversed(range(self.rewards.size)):
delta = self.rewards[step] + gamma * self.value_preds[
step + 1] * self.masks[step + 1] - self.value_preds[step]
gae = delta + gamma * gae_lambda * self.masks[step + 1] * gae
gae = gae * self.bad_masks[step + 1]
self.returns[step] = gae + self.value_preds[step]
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# modified from https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail
import copy
import os
from collections import deque
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import utils
from arguments import get_args
from wrapper import make_env
from mujoco_model import MujocoModel
from parl.algorithms import PPO
from mujoco_agent import MujocoAgent
from storage import RolloutStorage
from evaluation import evaluate
def main():
args = get_args()
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
torch.set_num_threads(1)
device = torch.device("cuda:0" if args.cuda else "cpu")
envs = make_env(args.env_name, args.seed, args.gamma)
model = MujocoModel(envs.observation_space.shape[0],
envs.action_space.shape[0])
model.to(device)
algorithm = PPO(
model,
args.clip_param,
args.value_loss_coef,
args.entropy_coef,
initial_lr=args.lr,
eps=args.eps,
max_grad_norm=args.max_grad_norm)
agent = MujocoAgent(algorithm, device)
rollouts = RolloutStorage(args.num_steps, envs.observation_space.shape[0],
envs.action_space.shape[0])
obs = envs.reset()
rollouts.obs[0] = np.copy(obs)
episode_rewards = deque(maxlen=10)
num_updates = int(args.num_env_steps) // args.num_steps
for j in range(num_updates):
if args.use_linear_lr_decay:
# decrease learning rate linearly
utils.update_linear_schedule(algorithm.optimizer, j, num_updates,
args.lr)
for step in range(args.num_steps):
# Sample actions
with torch.no_grad():
                value, action, action_log_prob = agent.sample(
                    rollouts.obs[step])  # stored obs for this step (already normalized by the wrapper)
            # Observe reward and next obs
obs, reward, done, infos = envs.step(action)
for info in infos:
if 'episode' in info.keys():
episode_rewards.append(info['episode']['r'])
# If done then clean the history of observations.
masks = torch.FloatTensor(
[[0.0] if done_ else [1.0] for done_ in done])
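            # bad_masks mark time-limit truncations (info['bad_transition'] set by
            # TimeLimitMask) so GAE does not treat them as true terminal states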
bad_masks = torch.FloatTensor(
[[0.0] if 'bad_transition' in info.keys() else [1.0]
for info in infos])
rollouts.append(obs, action, action_log_prob, value, reward, masks,
bad_masks)
with torch.no_grad():
next_value = agent.value(rollouts.obs[-1])
value_loss, action_loss, dist_entropy = agent.learn(
next_value, args.gamma, args.gae_lambda, args.ppo_epoch,
args.num_mini_batch, rollouts)
rollouts.after_update()
if j % args.log_interval == 0 and len(episode_rewards) > 1:
total_num_steps = (j + 1) * args.num_steps
            print(
                "Updates {}, num timesteps {},\n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}, entropy {:.3f}, value loss {:.3f}, action loss {:.3f}\n"
                .format(j, total_num_steps, len(episode_rewards),
                        np.mean(episode_rewards), np.median(episode_rewards),
                        np.min(episode_rewards), np.max(episode_rewards),
                        dist_entropy, value_loss, action_loss))
if (args.eval_interval is not None and len(episode_rewards) > 1
and j % args.eval_interval == 0):
ob_rms = utils.get_vec_normalize(envs).ob_rms
eval_mean_reward = evaluate(agent, ob_rms, args.env_name,
args.seed, device)
if __name__ == "__main__":
main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import os
import torch
import torch.nn as nn
from wrapper import VecNormalize
def get_vec_normalize(venv):
if isinstance(venv, VecNormalize):
return venv
elif hasattr(venv, 'venv'):
return get_vec_normalize(venv.venv)
return None
def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr):
"""Decreases the learning rate linearly"""
lr = initial_lr - (initial_lr * (epoch / float(total_num_epochs)))
for param_group in optimizer.param_groups:
param_group['lr'] = lr
def init(module, weight_init, bias_init, gain=1):
weight_init(module.weight.data, gain=gain)
bias_init(module.bias.data)
return module
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Simplified version of https://github.com/ShangtongZhang/DeepRL/blob/master/deep_rl/component/envs.py
import numpy as np
import gym
from gym.core import Wrapper
import time
class TimeLimitMask(gym.Wrapper):
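    # Flags steps where the episode ended only because the Gym TimeLimit expired,
    # by setting info['bad_transition'] = True.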
def step(self, action):
obs, rew, done, info = self.env.step(action)
if done and self.env._max_episode_steps == self.env._elapsed_steps:
info['bad_transition'] = True
return obs, rew, done, info
def reset(self, **kwargs):
return self.env.reset(**kwargs)
class MonitorEnv(gym.Wrapper):
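    # Records episode return 'r', length 'l', and wall time 't' into
    # info['episode'] when an episode finishes.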
def __init__(self, env):
Wrapper.__init__(self, env=env)
self.tstart = time.time()
self.rewards = None
def step(self, action):
ob, rew, done, info = self.env.step(action)
self.update(ob, rew, done, info)
return (ob, rew, done, info)
def update(self, ob, rew, done, info):
self.rewards.append(rew)
if done:
eprew = sum(self.rewards)
eplen = len(self.rewards)
epinfo = {
"r": round(eprew, 6),
"l": eplen,
"t": round(time.time() - self.tstart, 6)
}
assert isinstance(info, dict)
info['episode'] = epinfo
self.reset()
def reset(self, **kwargs):
self.rewards = []
return self.env.reset(**kwargs)
class VectorEnv(gym.Wrapper):
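    # Adds a leading batch dimension of size 1 so a single env mimics a
    # vectorized-env interface.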
def step(self, action):
ob, rew, done, info = self.env.step(action)
ob = np.array(ob)
ob = ob[np.newaxis, :]
rew = np.array([rew])
done = np.array([done])
info = [info]
return (ob, rew, done, info)
class RunningMeanStd(object):
# https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
def __init__(self, epsilon=1e-4, shape=()):
self.mean = np.zeros(shape, 'float64')
self.var = np.ones(shape, 'float64')
self.count = epsilon
def update(self, x):
batch_mean = np.mean(x, axis=0)
batch_var = np.var(x, axis=0)
batch_count = x.shape[0]
self.update_from_moments(batch_mean, batch_var, batch_count)
def update_from_moments(self, batch_mean, batch_var, batch_count):
self.mean, self.var, self.count = update_mean_var_count_from_moments(
self.mean, self.var, self.count, batch_mean, batch_var,
batch_count)
def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var,
batch_count):
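    # Chan et al. parallel algorithm for merging the mean/variance of two sample sets.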
delta = batch_mean - mean
tot_count = count + batch_count
new_mean = mean + delta * batch_count / tot_count
m_a = var * count
m_b = batch_var * batch_count
M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count
new_var = M2 / tot_count
new_count = tot_count
return new_mean, new_var, new_count
class VecNormalize(gym.Wrapper):
def __init__(self,
env,
ob=True,
ret=True,
clipob=10.,
cliprew=10.,
gamma=0.99,
epsilon=1e-8):
Wrapper.__init__(self, env=env)
observation_space = env.observation_space.shape[0]
self.ob_rms = RunningMeanStd(shape=observation_space) if ob else None
self.ret_rms = RunningMeanStd(shape=()) if ret else None
self.clipob = clipob
self.cliprew = cliprew
self.gamma = gamma
self.epsilon = epsilon
self.ret = np.zeros(1)
self.training = True
def step(self, action):
ob, rew, new, info = self.env.step(action)
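        # accumulate the discounted return; its running variance is used to rescale rewards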
self.ret = self.ret * self.gamma + rew
# normalize observation
ob = self._obfilt(ob)
# normalize reward
if self.ret_rms:
self.ret_rms.update(self.ret)
rew = np.clip(rew / np.sqrt(self.ret_rms.var + self.epsilon),
-self.cliprew, self.cliprew)
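        # reset the accumulated return where an episode has just ended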
self.ret[new] = 0.
return ob, rew, new, info
def reset(self):
self.ret = np.zeros(1)
ob = self.env.reset()
return self._obfilt(ob)
def _obfilt(self, ob, update=True):
if self.ob_rms:
if self.training and update:
self.ob_rms.update(ob)
ob = np.clip((ob - self.ob_rms.mean) /
np.sqrt(self.ob_rms.var + self.epsilon), -self.clipob,
self.clipob)
return ob
else:
return ob
def train(self):
self.training = True
def eval(self):
        self.training = False
def make_env(env_name, seed, gamma):
env = gym.make(env_name)
env.seed(seed)
env = TimeLimitMask(env)
env = MonitorEnv(env)
env = VectorEnv(env)
if gamma is None:
env = VecNormalize(env, ret=False)
else:
env = VecNormalize(env, gamma=gamma)
return env
@@ -16,4 +16,5 @@ from parl.algorithms.torch.ddqn import *
from parl.algorithms.torch.dqn import *
from parl.algorithms.torch.a2c import *
from parl.algorithms.torch.td3 import *
from parl.algorithms.torch.ppo import *
from parl.algorithms.torch.policy_gradient import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import parl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal
__all__ = ['PPO']
class PPO(parl.Algorithm):
def __init__(self,
model,
clip_param,
value_loss_coef,
entropy_coef,
initial_lr,
eps=None,
max_grad_norm=None,
use_clipped_value_loss=True):
self.model = model
self.clip_param = clip_param
self.value_loss_coef = value_loss_coef
self.entropy_coef = entropy_coef
self.max_grad_norm = max_grad_norm
self.use_clipped_value_loss = use_clipped_value_loss
self.optimizer = optim.Adam(model.parameters(), lr=initial_lr, eps=eps)
def learn(self, obs_batch, actions_batch, value_preds_batch, return_batch,
old_action_log_probs_batch, adv_targ):
values = self.model.value(obs_batch)
mean, log_std = self.model.policy(obs_batch)
dist = Normal(mean, log_std.exp())
action_log_probs = dist.log_prob(actions_batch).sum(-1, keepdim=True)
dist_entropy = dist.entropy().sum(-1).mean()
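        # PPO clipped surrogate objective:
        #   ratio = pi_theta(a|s) / pi_theta_old(a|s)
        #   action_loss = -E[min(ratio * A, clip(ratio, 1 - clip_param, 1 + clip_param) * A)]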
ratio = torch.exp(action_log_probs - old_action_log_probs_batch)
surr1 = ratio * adv_targ
surr2 = torch.clamp(ratio, 1.0 - self.clip_param,
1.0 + self.clip_param) * adv_targ
action_loss = -torch.min(surr1, surr2).mean()
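        # optionally clip the value prediction around the old estimate, mirroring the
        # policy clipping, and take the element-wise maximum (pessimistic) loss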
if self.use_clipped_value_loss:
value_pred_clipped = value_preds_batch + \
(values - value_preds_batch).clamp(-self.clip_param, self.clip_param)
value_losses = (values - return_batch).pow(2)
value_losses_clipped = (value_pred_clipped - return_batch).pow(2)
value_loss = 0.5 * torch.max(value_losses,
value_losses_clipped).mean()
else:
value_loss = 0.5 * (return_batch - values).pow(2).mean()
self.optimizer.zero_grad()
(value_loss * self.value_loss_coef + action_loss -
dist_entropy * self.entropy_coef).backward()
nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
self.optimizer.step()
return value_loss.item(), action_loss.item(), dist_entropy.item()
def sample(self, obs):
value = self.model.value(obs)
mean, log_std = self.model.policy(obs)
dist = Normal(mean, log_std.exp())
action = dist.sample()
action_log_probs = dist.log_prob(action).sum(-1, keepdim=True)
return value, action, action_log_probs
def predict(self, obs):
mean, _ = self.model.policy(obs)
return mean
def value(self, obs):
return self.model.value(obs)