diff --git a/ding/model/common/head.py b/ding/model/common/head.py
index 02af5c55ad7228aceb32ff1401c5b2efdd1298b8..498c94f7a5cc3be9b4d8af47d4225e43c92ce8b8 100644
--- a/ding/model/common/head.py
+++ b/ding/model/common/head.py
@@ -600,7 +600,7 @@ class RegressionHead(nn.Module):
 
 class ReparameterizationHead(nn.Module):
     default_sigma_type = ['fixed', 'independent', 'conditioned']
-    default_bound_type = ['tanh']
+    default_bound_type = ['tanh', None]
 
     def __init__(
             self,
diff --git a/ding/policy/ppo.py b/ding/policy/ppo.py
index 948c2672d153f7c48e3ffec25ba22aae88a50a97..8b4eb2e941b35f5689301b35887af99f07a81719 100644
--- a/ding/policy/ppo.py
+++ b/ding/policy/ppo.py
@@ -33,7 +33,7 @@ class PPOPolicy(Policy):
         priority=False,
         # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
         priority_IS_weight=False,
-        recompute_adv=False,
+        recompute_adv=True,
         continuous=True,
         learn=dict(
             # (bool) Whether to use multi gpu
@@ -90,10 +90,10 @@ class PPOPolicy(Policy):
             if isinstance(m, torch.nn.Linear):
                 torch.nn.init.orthogonal_(m.weight)
                 torch.nn.init.zeros_(m.bias)
-        # self._model._actor[-1].weight.data.mul_(0.1)
         if self._continuous:
             # init log sigma
-            # torch.nn.init.constant_(self._model.actor_head.log_sigma_param, -0.5)
+            if hasattr(self._model.actor_head, 'log_sigma_param'):
+                torch.nn.init.constant_(self._model.actor_head.log_sigma_param, -0.5)
         for m in list(self._model.critic.modules()) + list(self._model.actor.modules()):
             if isinstance(m, torch.nn.Linear):
                 # orthogonal initialization
@@ -131,8 +131,6 @@ class PPOPolicy(Policy):
 
         # Main model
         self._learn_model.reset()
-        from torch.optim.lr_scheduler import LambdaLR
-        # self._lr_scheduler = LambdaLR(self._optimizer, lr_lambda=lambda epoch: 1 - epoch / 1500.0)
 
     def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
         r"""
@@ -163,16 +161,17 @@ class PPOPolicy(Policy):
         for epoch in range(self._cfg.learn.epoch_per_collect):
             if self._recompute_adv:
                 with torch.no_grad():
-                    obs = torch.cat([data['obs'], data['next_obs'][-1:]])
-                    value = self._learn_model.forward(obs, mode='compute_critic')['value']
-
+                    # obs = torch.cat([data['obs'], data['next_obs'][-1:]])
+                    value = self._learn_model.forward(data['obs'], mode='compute_critic')['value']
+                    next_value = self._learn_model.forward(data['next_obs'], mode='compute_critic')['value']
                     if self._value_norm:
                         value *= self._running_mean_std.std
+                        next_value *= self._running_mean_std.std
 
-                    gae_data_ = gae_data(value, data['reward'], data['done'])
+                    gae_data_ = gae_data(value, next_value, data['reward'], data['done'])
                     # GAE need (T, B) shape input and return (T, B) output
                     data['adv'] = gae(gae_data_, self._gamma, self._gae_lambda)
-                    value = value[:-1]
+                    # value = value[:-1]
 
                     unnormalized_returns = value + data['adv']
                     if self._value_norm:
@@ -186,10 +185,8 @@ class PPOPolicy(Policy):
             for batch in split_data_generator(data, self._cfg.learn.batch_size, shuffle=True):
                 output = self._learn_model.forward(batch['obs'], mode='compute_actor_critic')
                 adv = batch['adv']
-                # with torch.no_grad():
-                #     batch['return'] = batch['value'] + adv
                 if self._adv_norm:
-                    # Normalize advantage in a total train_batch
+                    # Normalize advantage in a train_batch
                     adv = (adv - adv.mean()) / (adv.std() + 1e-8)
 
                 # Calculate ppo error
@@ -231,11 +228,9 @@ class PPOPolicy(Policy):
                     {
                         'mu_mean': output['logit'][0].mean().item(),
                         'sigma_mean': output['logit'][1].mean().item(),
-                        # 'sigma_grad': self._model.actor_head.log_sigma_param.grad.data.mean().item(),
                     }
                 )
             return_infos.append(return_info)
-        # self._lr_scheduler.step()
         return return_infos
 
     def _state_dict_learn(self) -> Dict[str, Any]:
@@ -284,7 +279,6 @@ class PPOPolicy(Policy):
         if self._continuous:
             (mu, sigma), value = output['logit'], output['value']
             dist = Independent(Normal(mu, sigma), 1)
-            # action = torch.clamp(dist.sample(), min=-1, max=1)
             output['action'] = dist.sample()
         if self._cuda:
             output = to_device(output, 'cpu')
@@ -326,8 +320,7 @@ class PPOPolicy(Policy):
             data = to_device(data, self._device)
         # adder is defined in _init_collect
         if self._cfg.learn.ignore_done:
-            for i in range(len(data)):
-                data[i]['done'] = False
+            data[-1]['done'] = False
 
         if data[-1]['done']:
             last_value = torch.zeros(1)
@@ -346,7 +339,7 @@ class PPOPolicy(Policy):
             for i in range(len(data)):
                 data[i]['value'] /= self._running_mean_std.std
 
-            # remove next_obs for save memory when not recompute adv
+        # remove next_obs for save memory when not recompute adv
         if not self._recompute_adv:
             for i in range(len(data)):
                 data[i].pop('next_obs')
@@ -383,8 +376,6 @@ class PPOPolicy(Policy):
             output = self._eval_model.forward(data, mode='compute_actor')
         if self._continuous:
             (mu, sigma) = output['logit']
-            # dist = Independent(Normal(mu, sigma), 1)
-            # action = torch.clamp(dist.sample(), min=-1, max=1)
             output.update({'action': mu})
         if self._cuda:
             output = to_device(output, 'cpu')
diff --git a/ding/rl_utils/adder.py b/ding/rl_utils/adder.py
index 75c0f4f78ac9f15c995f0ec2c07f1f8bd6413646..d0c12b528e66458079eb867affe2fd340b40dc2d 100644
--- a/ding/rl_utils/adder.py
+++ b/ding/rl_utils/adder.py
@@ -32,12 +32,13 @@ class Adder(object):
         Returns:
             - data (:obj:`list`): transitions list like input one, but each element owns extra advantage key 'adv'
         """
-        value = torch.stack([d['value'] for d in data] + [last_value])
+        value = torch.stack([d['value'] for d in data])
+        next_value = torch.stack([d['value'] for d in data][1:] + [last_value])
         reward = torch.stack([d['reward'] for d in data])
         if cuda:
             value = value.cuda()
             reward = reward.cuda()
-        adv = gae(gae_data(value, reward, None), gamma, gae_lambda)
+        adv = gae(gae_data(value, next_value, reward, None), gamma, gae_lambda)
         if cuda:
             adv = adv.cpu()
         for i in range(len(data)):
diff --git a/ding/rl_utils/gae.py b/ding/rl_utils/gae.py
index 79d946da99218b56aed1b8f91dd35878928c5bfc..f597249e533fcbdf34761a1c0b779a63e69e2f56 100644
--- a/ding/rl_utils/gae.py
+++ b/ding/rl_utils/gae.py
@@ -2,7 +2,7 @@ from collections import namedtuple
 import torch
 from ding.hpc_rl import hpc_wrapper
 
-gae_data = namedtuple('gae_data', ['value', 'reward', 'done'])
+gae_data = namedtuple('gae_data', ['value', 'next_value', 'reward', 'done'])
 
 
 def shape_fn_gae(args, kwargs):
@@ -43,17 +43,16 @@ def gae(data: namedtuple, gamma: float = 0.99, lambda_: float = 0.97) -> torch.F
         value_{T+1} should be 0 if this trajectory reached a terminal state(done=True), otherwise we use value
         function, this operation is implemented in collector for packing trajectory.
     """
-    value, reward, done = data
+    value, next_value, reward, done = data
     if done is None:
-        delta = reward + gamma * value[1:] - value[:-1]
-    else:
-        delta = reward + (1 - done) * gamma * value[1:] - value[:-1]
+        done = torch.zeros_like(reward, device=reward.device)
+
+    delta = reward + (1 - done) * gamma * next_value - value
     factor = gamma * lambda_
-    adv = torch.zeros_like(reward)
+    adv = torch.zeros_like(reward, device=reward.device)
     gae_item = 0.
-    denom = 0.
+
     for t in reversed(range(reward.shape[0])):
-        denom = 1 + lambda_ * denom
-        gae_item = denom * delta[t] + factor * gae_item
-        adv[t] += gae_item / denom
+        gae_item = delta[t] + factor * gae_item * (1 - done[t])
+        adv[t] += gae_item
     return adv
diff --git a/dizoo/classic_control/pendulum/config/pendulum_ppo_config.py b/dizoo/classic_control/pendulum/config/pendulum_ppo_config.py
index 65b53b36e71cb40ebbe9b34c21e684728539d16f..7f07972f1a7db6058bdc700501f613494f484a12 100644
--- a/dizoo/classic_control/pendulum/config/pendulum_ppo_config.py
+++ b/dizoo/classic_control/pendulum/config/pendulum_ppo_config.py
@@ -2,7 +2,7 @@ from easydict import EasyDict
 
 pendulum_ppo_config = dict(
     env=dict(
-        collector_env_num=16,
+        collector_env_num=1,
         evaluator_env_num=5,
         act_scale=True,
         n_evaluator_episode=5,
@@ -12,7 +12,7 @@ pendulum_ppo_config = dict(
         cuda=False,
         on_policy=True,
         continuous=True,
-        recompute_adv=True,
+        recompute_adv=False,
         model=dict(
             obs_shape=3,
             action_shape=1,
@@ -20,25 +20,25 @@ pendulum_ppo_config = dict(
             continuous=True,
            actor_head_layer_num=0,
            critic_head_layer_num=0,
-            sigma_type='fixed',
+            sigma_type='conditioned',
             bound_type='tanh',
         ),
         learn=dict(
             epoch_per_collect=10,
-            batch_size=128,
-            learning_rate=1e-3,
+            batch_size=32,
+            learning_rate=3e-5,
             value_weight=0.5,
             entropy_weight=0.0,
             clip_ratio=0.2,
-            adv_norm=True,
+            adv_norm=False,
             value_norm=True,
-            ignore_done=True,
+            ignore_done=False,
         ),
         collect=dict(
-            n_sample=3200,
+            n_sample=200,
             unroll_len=1,
-            discount_factor=0.95,
-            gae_lambda=0.95,
+            discount_factor=0.9,
+            gae_lambda=1.,
         ),
         eval=dict(evaluator=dict(eval_freq=200, ))
     ),
diff --git a/dizoo/mujoco/config/hopper_ppo_default_config.py b/dizoo/mujoco/config/hopper_ppo_default_config.py
index 65e2ddc515091edec9ef46c8d3a42de38edaa066..f2dfbbc39e17ae3c8b267e9672c4f67ec15e2243 100644
--- a/dizoo/mujoco/config/hopper_ppo_default_config.py
+++ b/dizoo/mujoco/config/hopper_ppo_default_config.py
@@ -5,7 +5,7 @@ hopper_ppo_default_config = dict(
         env_id='Hopper-v3',
         norm_obs=dict(use_norm=False, ),
         norm_reward=dict(use_norm=False, ),
-        collector_env_num=64,
+        collector_env_num=8,
         evaluator_env_num=10,
         use_act_scale=True,
         n_evaluator_episode=10,
@@ -14,6 +14,7 @@ hopper_ppo_default_config = dict(
     policy=dict(
         cuda=True,
         on_policy=True,
+        recompute_adv=True,
         model=dict(
             obs_shape=11,
             action_shape=3,
@@ -28,7 +29,7 @@ hopper_ppo_default_config = dict(
             entropy_weight=0.0,
             clip_ratio=0.2,
             adv_norm=True,
-            recompute_adv=True,
+            value_norm=True,
         ),
         collect=dict(
             n_sample=2048,
diff --git a/dizoo/mujoco/entry/__init__.py b/dizoo/mujoco/entry/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/dizoo/mujoco/entry/mujoco_ppo_main.py b/dizoo/mujoco/entry/mujoco_ppo_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..357171e5ebe05b322dc7da313d493a731a486967
--- /dev/null
+++ b/dizoo/mujoco/entry/mujoco_ppo_main.py
@@ -0,0 +1,58 @@
+import os
+import gym
+from tensorboardX import SummaryWriter
+from easydict import EasyDict
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleCollector, BaseSerialEvaluator, NaiveReplayBuffer
+from ding.envs import BaseEnvManager, DingEnvWrapper
+from ding.policy import PPOPolicy
+from ding.model import VAC
+from ding.utils import set_pkg_seed
+from dizoo.classic_control.pendulum.envs import PendulumEnv
+from dizoo.mujoco.envs.mujoco_env import MujocoEnv
+from dizoo.classic_control.pendulum.config.pendulum_ppo_config import pendulum_ppo_config
+from dizoo.mujoco.config.hopper_ppo_default_config import hopper_ppo_default_config
+
+
+def main(cfg, seed=0, max_iterations=int(1e10)):
+    cfg = compile_config(
+        cfg,
+        BaseEnvManager,
+        PPOPolicy,
+        BaseLearner,
+        SampleCollector,
+        BaseSerialEvaluator,
+        NaiveReplayBuffer,
+        save_cfg=True
+    )
+    collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
+    collector_env = BaseEnvManager(
+        env_fn=[lambda: MujocoEnv(cfg.env) for _ in range(collector_env_num)], cfg=cfg.env.manager
+    )
+    evaluator_env = BaseEnvManager(
+        env_fn=[lambda: MujocoEnv(cfg.env) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
+    )
+
+    collector_env.seed(seed, dynamic_seed=True)
+    evaluator_env.seed(seed, dynamic_seed=False)
+    set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+    model = VAC(**cfg.policy.model)
+    policy = PPOPolicy(cfg.policy, model=model)
+    tb_logger = SummaryWriter(os.path.join('./log/', 'serial'))
+    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger)
+    collector = SampleCollector(cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger)
+    evaluator = BaseSerialEvaluator(cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger)
+
+    for _ in range(max_iterations):
+        if evaluator.should_eval(learner.train_iter):
+            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+            if stop:
+                break
+        new_data = collector.collect(train_iter=learner.train_iter)
+        learner.train(new_data, collector.envstep)
+
+
+if __name__ == "__main__":
+    main(hopper_ppo_default_config)
diff --git a/dizoo/mujoco/envs/mujoco_wrappers.py b/dizoo/mujoco/envs/mujoco_wrappers.py
index c7454923f9912c6b34896fc5639594552a54b970..0a8b41939c01c925f6917d386d6f32fb230785ea 100644
--- a/dizoo/mujoco/envs/mujoco_wrappers.py
+++ b/dizoo/mujoco/envs/mujoco_wrappers.py
@@ -1,6 +1,5 @@
 import gym
 import numpy as np
-import pybulletgym
 
 from ding.envs import ObsNormEnv, RewardNormEnv
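
Note (illustration only, not part of the patch above): after this change, gae_data carries an explicit next_value tensor instead of a single (T+1)-length value tensor, and gae() implements the standard GAE recursion A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}, with delta_t = r_t + gamma * (1 - done_t) * V(s_{t+1}) - V(s_t). A minimal usage sketch, assuming the (T, B) tensor convention noted in ppo.py; the shapes and random tensors below are made up for illustration:

# illustrative sketch of the new gae_data/gae interface (toy data, not from the patch)
import torch
from ding.rl_utils.gae import gae, gae_data

T, B = 5, 2                      # trajectory length, batch size (arbitrary toy values)
value = torch.randn(T, B)        # V(s_t) for t = 0..T-1
next_value = torch.randn(T, B)   # V(s_{t+1}); the last entry is the bootstrap value (0 if the episode ended)
reward = torch.randn(T, B)
done = torch.zeros(T, B)         # may also be None, which gae() treats as all non-terminal

adv = gae(gae_data(value, next_value, reward, done), 0.99, 0.97)  # gamma, lambda_
assert adv.shape == (T, B)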