From f4440650aa42cbc33449736bb5441f72cca88068 Mon Sep 17 00:00:00 2001 From: zhangyinmin Date: Fri, 23 Jul 2021 11:10:22 +0800 Subject: [PATCH] modify the unittest for the gae; format code. --- ding/model/common/head.py | 20 ++++++------- ding/model/template/vac.py | 28 +++++++++---------- ding/policy/ppo.py | 10 +++---- ding/rl_utils/ppo.py | 2 +- ding/rl_utils/tests/test_gae.py | 6 ++-- .../mujoco/config/ant_ddpg_default_config.py | 8 ++---- dizoo/mujoco/config/ant_sac_default_config.py | 6 +--- dizoo/mujoco/config/ant_td3_default_config.py | 8 ++---- .../config/halfcheetah_ddpg_default_config.py | 8 ++---- .../config/halfcheetah_sac_default_config.py | 6 +--- .../config/halfcheetah_td3_default_config.py | 8 ++---- .../config/hopper_ddpg_default_config.py | 8 ++---- .../config/hopper_ppo_default_config.py | 2 +- .../config/hopper_sac_default_config.py | 6 +--- .../config/hopper_td3_default_config.py | 8 ++---- .../config/walker2d_ddpg_default_config.py | 8 ++---- .../config/walker2d_sac_default_config.py | 6 +--- .../config/walker2d_td3_default_config.py | 8 ++---- 18 files changed, 55 insertions(+), 101 deletions(-) diff --git a/ding/model/common/head.py b/ding/model/common/head.py index 498c94f..16da247 100644 --- a/ding/model/common/head.py +++ b/ding/model/common/head.py @@ -603,15 +603,15 @@ class ReparameterizationHead(nn.Module): default_bound_type = ['tanh', None] def __init__( - self, - hidden_size: int, - output_size: int, - layer_num: int = 2, - sigma_type: Optional[str] = None, - fixed_sigma_value: Optional[float] = 1.0, - activation: Optional[nn.Module] = nn.ReLU(), - norm_type: Optional[str] = None, - bound_type: Optional[str] = None, + self, + hidden_size: int, + output_size: int, + layer_num: int = 2, + sigma_type: Optional[str] = None, + fixed_sigma_value: Optional[float] = 1.0, + activation: Optional[nn.Module] = nn.ReLU(), + norm_type: Optional[str] = None, + bound_type: Optional[str] = None, ) -> None: r""" Overview: @@ -672,7 +672,7 @@ class ReparameterizationHead(nn.Module): """ x = self.main(x) mu = self.mu(x) - if self.bound_type=='tanh': + if self.bound_type == 'tanh': mu = torch.tanh(mu) if self.sigma_type == 'fixed': sigma = self.sigma.to(mu.device) + torch.zeros_like(mu) # addition aims to broadcast shape diff --git a/ding/model/template/vac.py b/ding/model/template/vac.py index 6160f7c..a024e50 100644 --- a/ding/model/template/vac.py +++ b/ding/model/template/vac.py @@ -18,20 +18,20 @@ class VAC(nn.Module): mode = ['compute_actor', 'compute_critic', 'compute_actor_critic'] def __init__( - self, - obs_shape: Union[int, SequenceType], - action_shape: Union[int, SequenceType], - share_encoder: bool = True, - continuous: bool = False, - encoder_hidden_size_list: SequenceType = [128, 128, 64], - actor_head_hidden_size: int = 64, - actor_head_layer_num: int = 1, - critic_head_hidden_size: int = 64, - critic_head_layer_num: int = 1, - activation: Optional[nn.Module] = nn.ReLU(), - norm_type: Optional[str] = None, - sigma_type: Optional[str] = 'independent', - bound_type: Optional[str] = None, + self, + obs_shape: Union[int, SequenceType], + action_shape: Union[int, SequenceType], + share_encoder: bool = True, + continuous: bool = False, + encoder_hidden_size_list: SequenceType = [128, 128, 64], + actor_head_hidden_size: int = 64, + actor_head_layer_num: int = 1, + critic_head_hidden_size: int = 64, + critic_head_layer_num: int = 1, + activation: Optional[nn.Module] = nn.ReLU(), + norm_type: Optional[str] = None, + sigma_type: Optional[str] = 'independent', + bound_type: Optional[str] = None, ) -> None: r""" Overview: diff --git a/ding/policy/ppo.py b/ding/policy/ppo.py index 8b4eb2e..efd6338 100644 --- a/ding/policy/ppo.py +++ b/ding/policy/ppo.py @@ -131,7 +131,6 @@ class PPOPolicy(Policy): # Main model self._learn_model.reset() - def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]: r""" Overview: @@ -192,8 +191,8 @@ class PPOPolicy(Policy): # Calculate ppo error if self._continuous: ppo_batch = ppo_data( - output['logit'], batch['logit'], batch['action'], output['value'], batch['value'], - adv, batch['return'], batch['weight'] + output['logit'], batch['logit'], batch['action'], output['value'], batch['value'], adv, + batch['return'], batch['weight'] ) ppo_loss, ppo_info = ppo_error_continuous(ppo_batch, self._clip_ratio) else: @@ -333,8 +332,9 @@ class PPOPolicy(Policy): last_value *= self._running_mean_std.std for i in range(len(data)): data[i]['value'] *= self._running_mean_std.std - data = get_gae(data, to_device(last_value, self._device), gamma=self._gamma, - gae_lambda=self._gae_lambda, cuda=self._cuda) + data = get_gae( + data, to_device(last_value, self._device), gamma=self._gamma, gae_lambda=self._gae_lambda, cuda=self._cuda + ) if self._value_norm: for i in range(len(data)): data[i]['value'] /= self._running_mean_std.std diff --git a/ding/rl_utils/ppo.py b/ding/rl_utils/ppo.py index 1640417..3b1d6cd 100644 --- a/ding/rl_utils/ppo.py +++ b/ding/rl_utils/ppo.py @@ -171,7 +171,7 @@ def ppo_error_continuous( weight = torch.ones_like(adv) dist_new = Independent(Normal(mu_sigma_new[0], mu_sigma_new[1]), 1) - if len(mu_sigma_old[0].shape)==1: + if len(mu_sigma_old[0].shape) == 1: dist_old = Independent(Normal(mu_sigma_old[0].unsqueeze(-1), mu_sigma_old[1].unsqueeze(-1)), 1) else: dist_old = Independent(Normal(mu_sigma_old[0], mu_sigma_old[1]), 1) diff --git a/ding/rl_utils/tests/test_gae.py b/ding/rl_utils/tests/test_gae.py index e7d8ef4..0d1e852 100644 --- a/ding/rl_utils/tests/test_gae.py +++ b/ding/rl_utils/tests/test_gae.py @@ -6,8 +6,10 @@ from ding.rl_utils import gae_data, gae @pytest.mark.unittest def test_gae(): T, B = 32, 4 - value = torch.randn(T + 1, B) + value = torch.randn(T, B) + next_value = torch.randn(T, B) reward = torch.randn(T, B) - data = gae_data(value, reward) + done = torch.zeros((T, B)) + data = gae_data(value, next_value, reward, done) adv = gae(data) assert adv.shape == (T, B) diff --git a/dizoo/mujoco/config/ant_ddpg_default_config.py b/dizoo/mujoco/config/ant_ddpg_default_config.py index e6a97a5..3950e7e 100644 --- a/dizoo/mujoco/config/ant_ddpg_default_config.py +++ b/dizoo/mujoco/config/ant_ddpg_default_config.py @@ -39,11 +39,7 @@ ant_ddpg_default_config = dict( unroll_len=1, noise_sigma=0.1, ), - other=dict( - replay_buffer=dict( - replay_buffer_size=1000000, - ), - ), + other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ), ) ) ant_ddpg_default_config = EasyDict(ant_ddpg_default_config) @@ -59,7 +55,7 @@ ant_ddpg_default_create_config = dict( type='ddpg', import_names=['ding.policy.ddpg'], ), - replay_buffer=dict(type='naive',), + replay_buffer=dict(type='naive', ), ) ant_ddpg_default_create_config = EasyDict(ant_ddpg_default_create_config) create_config = ant_ddpg_default_create_config diff --git a/dizoo/mujoco/config/ant_sac_default_config.py b/dizoo/mujoco/config/ant_sac_default_config.py index 5f0bda2..859148d 100644 --- a/dizoo/mujoco/config/ant_sac_default_config.py +++ b/dizoo/mujoco/config/ant_sac_default_config.py @@ -42,11 +42,7 @@ ant_sac_default_config = dict( ), command=dict(), eval=dict(), - other=dict( - replay_buffer=dict( - replay_buffer_size=1000000, - ), - ), + other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ), ), ) diff --git a/dizoo/mujoco/config/ant_td3_default_config.py b/dizoo/mujoco/config/ant_td3_default_config.py index 6e08369..f8555d7 100644 --- a/dizoo/mujoco/config/ant_td3_default_config.py +++ b/dizoo/mujoco/config/ant_td3_default_config.py @@ -44,11 +44,7 @@ ant_td3_default_config = dict( unroll_len=1, noise_sigma=0.1, ), - other=dict( - replay_buffer=dict( - replay_buffer_size=1000000, - ), - ), + other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ), ) ) @@ -65,7 +61,7 @@ ant_td3_default_create_config = dict( policy_type='td3', import_names=['ding.policy.td3'], ), - replay_buffer=dict(type='naive',), + replay_buffer=dict(type='naive', ), ) ant_td3_default_create_config = EasyDict(ant_td3_default_create_config) create_config = ant_td3_default_create_config diff --git a/dizoo/mujoco/config/halfcheetah_ddpg_default_config.py b/dizoo/mujoco/config/halfcheetah_ddpg_default_config.py index 07e4e8e..174620c 100644 --- a/dizoo/mujoco/config/halfcheetah_ddpg_default_config.py +++ b/dizoo/mujoco/config/halfcheetah_ddpg_default_config.py @@ -39,11 +39,7 @@ halfcheetah_ddpg_default_config = dict( unroll_len=1, noise_sigma=0.1, ), - other=dict( - replay_buffer=dict( - replay_buffer_size=1000000, - ), - ), + other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ), ) ) halfcheetah_ddpg_default_config = EasyDict(halfcheetah_ddpg_default_config) @@ -59,7 +55,7 @@ halfcheetah_ddpg_default_create_config = dict( type='ddpg', import_names=['ding.policy.ddpg'], ), - replay_buffer=dict(type='naive',), + replay_buffer=dict(type='naive', ), ) halfcheetah_ddpg_default_create_config = EasyDict(halfcheetah_ddpg_default_create_config) create_config = halfcheetah_ddpg_default_create_config diff --git a/dizoo/mujoco/config/halfcheetah_sac_default_config.py b/dizoo/mujoco/config/halfcheetah_sac_default_config.py index 31dcb70..e8f8f45 100644 --- a/dizoo/mujoco/config/halfcheetah_sac_default_config.py +++ b/dizoo/mujoco/config/halfcheetah_sac_default_config.py @@ -42,11 +42,7 @@ halfcheetah_sac_default_config = dict( ), command=dict(), eval=dict(), - other=dict( - replay_buffer=dict( - replay_buffer_size=1000000, - ), - ), + other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ), ), ) diff --git a/dizoo/mujoco/config/halfcheetah_td3_default_config.py b/dizoo/mujoco/config/halfcheetah_td3_default_config.py index a784555..a4f41ac 100644 --- a/dizoo/mujoco/config/halfcheetah_td3_default_config.py +++ b/dizoo/mujoco/config/halfcheetah_td3_default_config.py @@ -44,11 +44,7 @@ halfcheetah_td3_default_config = dict( unroll_len=1, noise_sigma=0.1, ), - other=dict( - replay_buffer=dict( - replay_buffer_size=1000000, - ), - ), + other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ), ) ) @@ -65,7 +61,7 @@ halfcheetah_td3_default_create_config = dict( policy_type='td3', import_names=['ding.policy.td3'], ), - replay_buffer=dict(type='naive',), + replay_buffer=dict(type='naive', ), ) halfcheetah_td3_default_create_config = EasyDict(halfcheetah_td3_default_create_config) create_config = halfcheetah_td3_default_create_config diff --git a/dizoo/mujoco/config/hopper_ddpg_default_config.py b/dizoo/mujoco/config/hopper_ddpg_default_config.py index fae00dd..d2868aa 100644 --- a/dizoo/mujoco/config/hopper_ddpg_default_config.py +++ b/dizoo/mujoco/config/hopper_ddpg_default_config.py @@ -39,11 +39,7 @@ hopper_ddpg_default_config = dict( unroll_len=1, noise_sigma=0.1, ), - other=dict( - replay_buffer=dict( - replay_buffer_size=1000000, - ), - ), + other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ), ) ) hopper_ddpg_default_config = EasyDict(hopper_ddpg_default_config) @@ -59,7 +55,7 @@ hopper_ddpg_default_create_config = dict( type='ddpg', import_names=['ding.policy.ddpg'], ), - replay_buffer=dict(type='naive',), + replay_buffer=dict(type='naive', ), ) hopper_ddpg_default_create_config = EasyDict(hopper_ddpg_default_create_config) create_config = hopper_ddpg_default_create_config diff --git a/dizoo/mujoco/config/hopper_ppo_default_config.py b/dizoo/mujoco/config/hopper_ppo_default_config.py index f2dfbbc..c68359f 100644 --- a/dizoo/mujoco/config/hopper_ppo_default_config.py +++ b/dizoo/mujoco/config/hopper_ppo_default_config.py @@ -57,7 +57,7 @@ hopper_ppo_create_default_config = dict( type='ppo', import_names=['ding.policy.ppo'], ), - replay_buffer=dict(type='naive',), + replay_buffer=dict(type='naive', ), ) hopper_ppo_create_default_config = EasyDict(hopper_ppo_create_default_config) create_config = hopper_ppo_create_default_config diff --git a/dizoo/mujoco/config/hopper_sac_default_config.py b/dizoo/mujoco/config/hopper_sac_default_config.py index 6dbea96..c6c97a3 100644 --- a/dizoo/mujoco/config/hopper_sac_default_config.py +++ b/dizoo/mujoco/config/hopper_sac_default_config.py @@ -42,11 +42,7 @@ hopper_sac_default_config = dict( ), command=dict(), eval=dict(), - other=dict( - replay_buffer=dict( - replay_buffer_size=1000000, - ), - ), + other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ), ), ) diff --git a/dizoo/mujoco/config/hopper_td3_default_config.py b/dizoo/mujoco/config/hopper_td3_default_config.py index 9f434df..731c16d 100644 --- a/dizoo/mujoco/config/hopper_td3_default_config.py +++ b/dizoo/mujoco/config/hopper_td3_default_config.py @@ -44,11 +44,7 @@ hopper_td3_default_config = dict( unroll_len=1, noise_sigma=0.1, ), - other=dict( - replay_buffer=dict( - replay_buffer_size=1000000, - ), - ), + other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ), ) ) @@ -65,7 +61,7 @@ hopper_td3_default_create_config = dict( policy_type='td3', import_names=['ding.policy.td3'], ), - replay_buffer=dict(type='naive',), + replay_buffer=dict(type='naive', ), ) hopper_td3_default_create_config = EasyDict(hopper_td3_default_create_config) create_config = hopper_td3_default_create_config diff --git a/dizoo/mujoco/config/walker2d_ddpg_default_config.py b/dizoo/mujoco/config/walker2d_ddpg_default_config.py index 6016ea7..b75b3a9 100644 --- a/dizoo/mujoco/config/walker2d_ddpg_default_config.py +++ b/dizoo/mujoco/config/walker2d_ddpg_default_config.py @@ -39,11 +39,7 @@ walker2d_ddpg_default_config = dict( unroll_len=1, noise_sigma=0.1, ), - other=dict( - replay_buffer=dict( - replay_buffer_size=1000000, - ), - ), + other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ), ) ) walker2d_ddpg_default_config = EasyDict(walker2d_ddpg_default_config) @@ -59,7 +55,7 @@ walker2d_ddpg_default_create_config = dict( type='ddpg', import_names=['ding.policy.ddpg'], ), - replay_buffer=dict(type='naive',), + replay_buffer=dict(type='naive', ), ) walker2d_ddpg_default_create_config = EasyDict(walker2d_ddpg_default_create_config) create_config = walker2d_ddpg_default_create_config diff --git a/dizoo/mujoco/config/walker2d_sac_default_config.py b/dizoo/mujoco/config/walker2d_sac_default_config.py index 2729ab1..4ba4b37 100644 --- a/dizoo/mujoco/config/walker2d_sac_default_config.py +++ b/dizoo/mujoco/config/walker2d_sac_default_config.py @@ -42,11 +42,7 @@ walker2d_sac_default_config = dict( ), command=dict(), eval=dict(), - other=dict( - replay_buffer=dict( - replay_buffer_size=1000000, - ), - ), + other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ), ), ) diff --git a/dizoo/mujoco/config/walker2d_td3_default_config.py b/dizoo/mujoco/config/walker2d_td3_default_config.py index 0c77453..27d1f4f 100644 --- a/dizoo/mujoco/config/walker2d_td3_default_config.py +++ b/dizoo/mujoco/config/walker2d_td3_default_config.py @@ -44,11 +44,7 @@ walker2d_td3_default_config = dict( unroll_len=1, noise_sigma=0.1, ), - other=dict( - replay_buffer=dict( - replay_buffer_size=1000000, - ), - ), + other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ), ) ) @@ -65,7 +61,7 @@ walker2d_td3_default_create_config = dict( policy_type='td3', import_names=['ding.policy.td3'], ), - replay_buffer=dict(type='naive',), + replay_buffer=dict(type='naive', ), ) walker2d_td3_default_create_config = EasyDict(walker2d_td3_default_create_config) create_config = walker2d_td3_default_create_config -- GitLab