Commit e30a3d3c authored by zhangyinmin

Modify the GAE recomputation; add/update PPO config/entry.

Parent 8fffde51
@@ -600,7 +600,7 @@ class RegressionHead(nn.Module):
 class ReparameterizationHead(nn.Module):
     default_sigma_type = ['fixed', 'independent', 'conditioned']
-    default_bound_type = ['tanh']
+    default_bound_type = ['tanh', None]

     def __init__(
             self,
......
@@ -33,7 +33,7 @@ class PPOPolicy(Policy):
         priority=False,
         # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
         priority_IS_weight=False,
-        recompute_adv=False,
+        recompute_adv=True,
         continuous=True,
         learn=dict(
             # (bool) Whether to use multi gpu
@@ -90,10 +90,10 @@ class PPOPolicy(Policy):
             if isinstance(m, torch.nn.Linear):
                 torch.nn.init.orthogonal_(m.weight)
                 torch.nn.init.zeros_(m.bias)
-        # self._model._actor[-1].weight.data.mul_(0.1)
         if self._continuous:
             # init log sigma
-            # torch.nn.init.constant_(self._model.actor_head.log_sigma_param, -0.5)
+            if hasattr(self._model.actor_head, 'log_sigma_param'):
+                torch.nn.init.constant_(self._model.actor_head.log_sigma_param, -0.5)
         for m in list(self._model.critic.modules()) + list(self._model.actor.modules()):
             if isinstance(m, torch.nn.Linear):
                 # orthogonal initialization
@@ -131,8 +131,6 @@ class PPOPolicy(Policy):
         # Main model
         self._learn_model.reset()
-        from torch.optim.lr_scheduler import LambdaLR
-        # self._lr_scheduler = LambdaLR(self._optimizer, lr_lambda=lambda epoch: 1 - epoch / 1500.0)

     def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
         r"""
@@ -163,16 +161,17 @@ class PPOPolicy(Policy):
         for epoch in range(self._cfg.learn.epoch_per_collect):
             if self._recompute_adv:
                 with torch.no_grad():
-                    obs = torch.cat([data['obs'], data['next_obs'][-1:]])
-                    value = self._learn_model.forward(obs, mode='compute_critic')['value']
+                    # obs = torch.cat([data['obs'], data['next_obs'][-1:]])
+                    value = self._learn_model.forward(data['obs'], mode='compute_critic')['value']
+                    next_value = self._learn_model.forward(data['next_obs'], mode='compute_critic')['value']
                     if self._value_norm:
                         value *= self._running_mean_std.std
+                        next_value *= self._running_mean_std.std

-                    gae_data_ = gae_data(value, data['reward'], data['done'])
+                    gae_data_ = gae_data(value, next_value, data['reward'], data['done'])
                     # GAE need (T, B) shape input and return (T, B) output
                     data['adv'] = gae(gae_data_, self._gamma, self._gae_lambda)
-                    value = value[:-1]
+                    # value = value[:-1]
                     unnormalized_returns = value + data['adv']

                     if self._value_norm:
@@ -186,10 +185,8 @@ class PPOPolicy(Policy):
             for batch in split_data_generator(data, self._cfg.learn.batch_size, shuffle=True):
                 output = self._learn_model.forward(batch['obs'], mode='compute_actor_critic')
                 adv = batch['adv']
-                # with torch.no_grad():
-                #     batch['return'] = batch['value'] + adv
                 if self._adv_norm:
-                    # Normalize advantage in a total train_batch
+                    # Normalize advantage in a train_batch
                     adv = (adv - adv.mean()) / (adv.std() + 1e-8)

                 # Calculate ppo error
@@ -231,11 +228,9 @@ class PPOPolicy(Policy):
                     {
                         'mu_mean': output['logit'][0].mean().item(),
                         'sigma_mean': output['logit'][1].mean().item(),
-                        # 'sigma_grad': self._model.actor_head.log_sigma_param.grad.data.mean().item(),
                     }
                 )
             return_infos.append(return_info)
-        # self._lr_scheduler.step()
         return return_infos

     def _state_dict_learn(self) -> Dict[str, Any]:
@@ -284,7 +279,6 @@ class PPOPolicy(Policy):
             if self._continuous:
                 (mu, sigma), value = output['logit'], output['value']
                 dist = Independent(Normal(mu, sigma), 1)
-                # action = torch.clamp(dist.sample(), min=-1, max=1)
                 output['action'] = dist.sample()
         if self._cuda:
             output = to_device(output, 'cpu')
@@ -326,8 +320,7 @@ class PPOPolicy(Policy):
             data = to_device(data, self._device)
         # adder is defined in _init_collect
         if self._cfg.learn.ignore_done:
-            for i in range(len(data)):
-                data[i]['done'] = False
+            data[-1]['done'] = False
         if data[-1]['done']:
             last_value = torch.zeros(1)
@@ -346,7 +339,7 @@ class PPOPolicy(Policy):
             for i in range(len(data)):
                 data[i]['value'] /= self._running_mean_std.std
         # remove next_obs for save memory when not recompute adv
         if not self._recompute_adv:
             for i in range(len(data)):
                 data[i].pop('next_obs')
@@ -383,8 +376,6 @@ class PPOPolicy(Policy):
             output = self._eval_model.forward(data, mode='compute_actor')
             if self._continuous:
                 (mu, sigma) = output['logit']
-                # dist = Independent(Normal(mu, sigma), 1)
-                # action = torch.clamp(dist.sample(), min=-1, max=1)
                 output.update({'action': mu})
         if self._cuda:
             output = to_device(output, 'cpu')
......
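As a reading aid for the recompute path above: at the start of every learn epoch the current critic is queried on both `obs` and `next_obs`, and the advantage and return targets are rebuilt from those fresh values. Below is a minimal, self-contained sketch of that computation; the stub `critic`, the tensor shapes, and the hyperparameters are illustrative and not taken from the repository.

```python
import torch

T, B = 16, 4  # time steps, parallel environments
obs, next_obs = torch.randn(T, B, 3), torch.randn(T, B, 3)
reward, done = torch.randn(T, B), torch.zeros(T, B)
gamma, lambda_ = 0.99, 0.95

def critic(o):
    # stand-in for self._learn_model.forward(o, mode='compute_critic')['value']
    return o.sum(-1)

with torch.no_grad():
    value = critic(obs)            # V(s_t) under the current critic
    next_value = critic(next_obs)  # V(s_{t+1}) under the current critic
    delta = reward + (1 - done) * gamma * next_value - value
    adv, gae_item = torch.zeros_like(reward), torch.zeros(B)
    for t in reversed(range(T)):
        gae_item = delta[t] + gamma * lambda_ * (1 - done[t]) * gae_item
        adv[t] = gae_item
    unnormalized_returns = value + adv  # regression target for the value loss
```

This also explains why `next_obs` is only popped from the collected data when `recompute_adv` is False, as the collector hunk above shows.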
@@ -32,12 +32,13 @@ class Adder(object):
         Returns:
             - data (:obj:`list`): transitions list like input one, but each element owns extra advantage key 'adv'
         """
-        value = torch.stack([d['value'] for d in data] + [last_value])
+        value = torch.stack([d['value'] for d in data])
+        next_value = torch.stack([d['value'] for d in data][1:] + [last_value])
         reward = torch.stack([d['reward'] for d in data])
         if cuda:
             value = value.cuda()
             reward = reward.cuda()
-        adv = gae(gae_data(value, reward, None), gamma, gae_lambda)
+        adv = gae(gae_data(value, next_value, reward, None), gamma, gae_lambda)
         if cuda:
             adv = adv.cpu()
         for i in range(len(data)):
......
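Since the adder no longer appends `last_value` to a single value tensor, the alignment between `value` and `next_value` is worth spelling out. A toy example (the trajectory values below are made up for illustration):

```python
import torch

# Three collected transitions plus the bootstrap value of the state after the last step.
data = [{'value': torch.tensor([0.5]), 'reward': torch.tensor([1.0])},
        {'value': torch.tensor([0.4]), 'reward': torch.tensor([0.0])},
        {'value': torch.tensor([0.3]), 'reward': torch.tensor([1.0])}]
last_value = torch.tensor([0.2])

value = torch.stack([d['value'] for d in data])                          # V(s_0), V(s_1), V(s_2)
next_value = torch.stack([d['value'] for d in data][1:] + [last_value])  # V(s_1), V(s_2), V(s_3)
reward = torch.stack([d['reward'] for d in data])

assert value.shape == next_value.shape == reward.shape == (3, 1)
```

Every row of the three tensors now refers to the same time step, which is what the new `gae_data(value, next_value, reward, done)` signature expects.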
@@ -2,7 +2,7 @@ from collections import namedtuple
 import torch
 from ding.hpc_rl import hpc_wrapper

-gae_data = namedtuple('gae_data', ['value', 'reward', 'done'])
+gae_data = namedtuple('gae_data', ['value', 'next_value', 'reward', 'done'])


 def shape_fn_gae(args, kwargs):
@@ -43,17 +43,16 @@ def gae(data: namedtuple, gamma: float = 0.99, lambda_: float = 0.97) -> torch.F
        value_{T+1} should be 0 if this trajectory reached a terminal state(done=True), otherwise we use value
        function, this operation is implemented in collector for packing trajectory.
     """
-    value, reward, done = data
+    value, next_value, reward, done = data
     if done is None:
-        delta = reward + gamma * value[1:] - value[:-1]
-    else:
-        delta = reward + (1 - done) * gamma * value[1:] - value[:-1]
+        done = torch.zeros_like(reward, device=reward.device)
+    delta = reward + (1 - done) * gamma * next_value - value
     factor = gamma * lambda_
-    adv = torch.zeros_like(reward)
+    adv = torch.zeros_like(reward, device=reward.device)
     gae_item = 0.
-    denom = 0.
     for t in reversed(range(reward.shape[0])):
-        denom = 1 + lambda_ * denom
-        gae_item = denom * delta[t] + factor * gae_item
-        adv[t] += gae_item / denom
+        gae_item = delta[t] + factor * gae_item * (1 - done[t])
+        adv[t] += gae_item
     return adv
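For reference, the rewritten loop computes the standard generalized advantage estimation recursion, with the `done` flag masking both the bootstrap value and the propagation of the accumulator across episode boundaries. This is a sketch of the math the new code implements, not an excerpt from the docstring:

```latex
\delta_t = r_t + \gamma\,(1 - d_t)\,V(s_{t+1}) - V(s_t)
\hat{A}_t = \delta_t + \gamma\lambda\,(1 - d_t)\,\hat{A}_{t+1}, \qquad \hat{A}_{T-1} = \delta_{T-1}
```

The previous `denom`-based weighting variant is dropped in favor of this plain recursion.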
@@ -2,7 +2,7 @@ from easydict import EasyDict
 pendulum_ppo_config = dict(
     env=dict(
-        collector_env_num=16,
+        collector_env_num=1,
         evaluator_env_num=5,
         act_scale=True,
         n_evaluator_episode=5,
@@ -12,7 +12,7 @@ pendulum_ppo_config = dict(
         cuda=False,
         on_policy=True,
         continuous=True,
-        recompute_adv=True,
+        recompute_adv=False,
         model=dict(
             obs_shape=3,
             action_shape=1,
@@ -20,25 +20,25 @@ pendulum_ppo_config = dict(
             continuous=True,
             actor_head_layer_num=0,
             critic_head_layer_num=0,
-            sigma_type='fixed',
+            sigma_type='conditioned',
             bound_type='tanh',
         ),
         learn=dict(
             epoch_per_collect=10,
-            batch_size=128,
-            learning_rate=1e-3,
+            batch_size=32,
+            learning_rate=3e-5,
             value_weight=0.5,
             entropy_weight=0.0,
             clip_ratio=0.2,
-            adv_norm=True,
+            adv_norm=False,
             value_norm=True,
-            ignore_done=True,
+            ignore_done=False,
         ),
         collect=dict(
-            n_sample=3200,
+            n_sample=200,
             unroll_len=1,
-            discount_factor=0.95,
-            gae_lambda=0.95,
+            discount_factor=0.9,
+            gae_lambda=1.,
         ),
         eval=dict(evaluator=dict(eval_freq=200, ))
     ),
......
@@ -5,7 +5,7 @@ hopper_ppo_default_config = dict(
         env_id='Hopper-v3',
         norm_obs=dict(use_norm=False, ),
         norm_reward=dict(use_norm=False, ),
-        collector_env_num=64,
+        collector_env_num=8,
         evaluator_env_num=10,
         use_act_scale=True,
         n_evaluator_episode=10,
@@ -14,6 +14,7 @@ hopper_ppo_default_config = dict(
     policy=dict(
         cuda=True,
         on_policy=True,
+        recompute_adv=True,
         model=dict(
             obs_shape=11,
             action_shape=3,
@@ -28,7 +29,7 @@ hopper_ppo_default_config = dict(
             entropy_weight=0.0,
             clip_ratio=0.2,
             adv_norm=True,
-            recompute_adv=True,
+            value_norm=True,
         ),
         collect=dict(
             n_sample=2048,
......
import os
import gym
from tensorboardX import SummaryWriter
from easydict import EasyDict
from ding.config import compile_config
from ding.worker import BaseLearner, SampleCollector, BaseSerialEvaluator, NaiveReplayBuffer
from ding.envs import BaseEnvManager, DingEnvWrapper
from ding.policy import PPOPolicy
from ding.model import VAC
from ding.utils import set_pkg_seed
from dizoo.classic_control.pendulum.envs import PendulumEnv
from dizoo.mujoco.envs.mujoco_env import MujocoEnv
from dizoo.classic_control.pendulum.config.pendulum_ppo_config import pendulum_ppo_config
from dizoo.mujoco.config.hopper_ppo_default_config import hopper_ppo_default_config
def main(cfg, seed=0, max_iterations=int(1e10)):
cfg = compile_config(
cfg,
BaseEnvManager,
PPOPolicy,
BaseLearner,
SampleCollector,
BaseSerialEvaluator,
NaiveReplayBuffer,
save_cfg=True
)
collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
collector_env = BaseEnvManager(
env_fn=[lambda: MujocoEnv(cfg.env) for _ in range(collector_env_num)], cfg=cfg.env.manager
)
evaluator_env = BaseEnvManager(
env_fn=[lambda: MujocoEnv(cfg.env) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
)
collector_env.seed(seed, dynamic_seed=True)
evaluator_env.seed(seed, dynamic_seed=False)
set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
model = VAC(**cfg.policy.model)
policy = PPOPolicy(cfg.policy, model=model)
tb_logger = SummaryWriter(os.path.join('./log/', 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger)
collector = SampleCollector(cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger)
evaluator = BaseSerialEvaluator(cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger)
for _ in range(max_iterations):
if evaluator.should_eval(learner.train_iter):
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
if stop:
break
new_data = collector.collect(train_iter=learner.train_iter)
learner.train(new_data, collector.envstep)
if __name__ == "__main__":
main(hopper_ppo_default_config)
 import gym
 import numpy as np
-import pybulletgym
 from ding.envs import ObsNormEnv, RewardNormEnv
......