Commit f4440650 authored by: Z zhangyinmin

Modify the unit test for GAE; format code.

Parent e30a3d3c
......@@ -603,15 +603,15 @@ class ReparameterizationHead(nn.Module):
default_bound_type = ['tanh', None]
def __init__(
self,
hidden_size: int,
output_size: int,
layer_num: int = 2,
sigma_type: Optional[str] = None,
fixed_sigma_value: Optional[float] = 1.0,
activation: Optional[nn.Module] = nn.ReLU(),
norm_type: Optional[str] = None,
bound_type: Optional[str] = None,
self,
hidden_size: int,
output_size: int,
layer_num: int = 2,
sigma_type: Optional[str] = None,
fixed_sigma_value: Optional[float] = 1.0,
activation: Optional[nn.Module] = nn.ReLU(),
norm_type: Optional[str] = None,
bound_type: Optional[str] = None,
) -> None:
r"""
Overview:
......@@ -672,7 +672,7 @@ class ReparameterizationHead(nn.Module):
"""
x = self.main(x)
mu = self.mu(x)
if self.bound_type=='tanh':
if self.bound_type == 'tanh':
mu = torch.tanh(mu)
if self.sigma_type == 'fixed':
sigma = self.sigma.to(mu.device) + torch.zeros_like(mu) # addition aims to broadcast shape
......
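For context on the hunk above, here is a minimal standalone sketch of the two idioms it touches: bounding the mean with tanh and broadcasting a fixed sigma to the mean's shape and device. The tensor shapes and the sigma value below are illustrative assumptions, not values from the repository.

```python
import torch

# Illustrative shapes (assumed): batch of 8, action dim of 3.
mu = torch.randn(8, 3)

# bound_type == 'tanh': squash the mean into (-1, 1).
mu = torch.tanh(mu)

# sigma_type == 'fixed': a single fixed sigma parameter is expanded to mu's
# shape and device; adding zeros_like(mu) performs the broadcast.
sigma = torch.tensor([0.3])
sigma = sigma.to(mu.device) + torch.zeros_like(mu)
assert sigma.shape == mu.shape
```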
......@@ -18,20 +18,20 @@ class VAC(nn.Module):
mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']
def __init__(
self,
obs_shape: Union[int, SequenceType],
action_shape: Union[int, SequenceType],
share_encoder: bool = True,
continuous: bool = False,
encoder_hidden_size_list: SequenceType = [128, 128, 64],
actor_head_hidden_size: int = 64,
actor_head_layer_num: int = 1,
critic_head_hidden_size: int = 64,
critic_head_layer_num: int = 1,
activation: Optional[nn.Module] = nn.ReLU(),
norm_type: Optional[str] = None,
sigma_type: Optional[str] = 'independent',
bound_type: Optional[str] = None,
self,
obs_shape: Union[int, SequenceType],
action_shape: Union[int, SequenceType],
share_encoder: bool = True,
continuous: bool = False,
encoder_hidden_size_list: SequenceType = [128, 128, 64],
actor_head_hidden_size: int = 64,
actor_head_layer_num: int = 1,
critic_head_hidden_size: int = 64,
critic_head_layer_num: int = 1,
activation: Optional[nn.Module] = nn.ReLU(),
norm_type: Optional[str] = None,
sigma_type: Optional[str] = 'independent',
bound_type: Optional[str] = None,
) -> None:
r"""
Overview:
......
......@@ -131,7 +131,6 @@ class PPOPolicy(Policy):
# Main model
self._learn_model.reset()
def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
r"""
Overview:
......@@ -192,8 +191,8 @@ class PPOPolicy(Policy):
# Calculate ppo error
if self._continuous:
ppo_batch = ppo_data(
output['logit'], batch['logit'], batch['action'], output['value'], batch['value'],
adv, batch['return'], batch['weight']
output['logit'], batch['logit'], batch['action'], output['value'], batch['value'], adv,
batch['return'], batch['weight']
)
ppo_loss, ppo_info = ppo_error_continuous(ppo_batch, self._clip_ratio)
else:
......@@ -333,8 +332,9 @@ class PPOPolicy(Policy):
last_value *= self._running_mean_std.std
for i in range(len(data)):
data[i]['value'] *= self._running_mean_std.std
data = get_gae(data, to_device(last_value, self._device), gamma=self._gamma,
gae_lambda=self._gae_lambda, cuda=self._cuda)
data = get_gae(
data, to_device(last_value, self._device), gamma=self._gamma, gae_lambda=self._gae_lambda, cuda=self._cuda
)
if self._value_norm:
for i in range(len(data)):
data[i]['value'] /= self._running_mean_std.std
......
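The hunk above only re-wraps the get_gae call, but the surrounding value-norm logic is worth spelling out: normalized value predictions are scaled back by the running std before advantage estimation and re-normalized afterwards. Below is a minimal standalone sketch of that bracketing; the std value, shapes, and random data are placeholders, and the GAE step is elided rather than reproduced from the library.

```python
import torch

running_std = torch.tensor(2.5)     # placeholder for self._running_mean_std.std
T, B = 32, 4
values = torch.randn(T, B)          # critic outputs in normalized space (assumed shapes)
last_value = torch.randn(B)

# De-normalize before GAE so advantages are computed on the reward scale.
raw_values = values * running_std
raw_last_value = last_value * running_std

# ... get_gae(...) would run here on the de-normalized values ...

# Re-normalize the stored values afterwards, mirroring the value_norm branch above.
values = raw_values / running_std
```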
......@@ -171,7 +171,7 @@ def ppo_error_continuous(
weight = torch.ones_like(adv)
dist_new = Independent(Normal(mu_sigma_new[0], mu_sigma_new[1]), 1)
if len(mu_sigma_old[0].shape)==1:
if len(mu_sigma_old[0].shape) == 1:
dist_old = Independent(Normal(mu_sigma_old[0].unsqueeze(-1), mu_sigma_old[1].unsqueeze(-1)), 1)
else:
dist_old = Independent(Normal(mu_sigma_old[0], mu_sigma_old[1]), 1)
......
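A small standalone sketch of why the unsqueeze branch above exists: Independent(..., 1) reinterprets the last dimension as the event dimension, so 1-D (batch-only) mu/sigma tensors need an explicit trailing action dimension of size 1 before being wrapped. The shapes below are assumptions for illustration.

```python
import torch
from torch.distributions import Independent, Normal

mu_old, sigma_old = torch.randn(8), torch.rand(8) + 0.1   # 1-D case: (batch,) only

# unsqueeze(-1) makes the shapes (batch, 1), so Independent(..., 1) treats the
# size-1 action dim as the event and log_prob returns one value per batch element.
dist_old = Independent(Normal(mu_old.unsqueeze(-1), sigma_old.unsqueeze(-1)), 1)
action = torch.randn(8, 1)
assert dist_old.log_prob(action).shape == (8,)
```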
......@@ -6,8 +6,10 @@ from ding.rl_utils import gae_data, gae
@pytest.mark.unittest
def test_gae():
T, B = 32, 4
value = torch.randn(T + 1, B)
value = torch.randn(T, B)
next_value = torch.randn(T, B)
reward = torch.randn(T, B)
data = gae_data(value, reward)
done = torch.zeros((T, B))
data = gae_data(value, next_value, reward, done)
adv = gae(data)
assert adv.shape == (T, B)
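The unit test above now builds gae_data from separate value, next_value, reward, and done tensors. As a reference point, here is a minimal sketch of a GAE computation over tensors with those shapes; the function name, default gamma/lambda, and done-masking convention are illustrative assumptions, not the ding.rl_utils.gae implementation.

```python
import torch

def gae_sketch(value, next_value, reward, done, gamma=0.99, lambda_=0.97):
    # All inputs are (T, B); done masks bootstrapping across episode boundaries.
    delta = reward + gamma * next_value * (1 - done) - value   # one-step TD error
    adv = torch.zeros_like(value)
    running = torch.zeros_like(value[0])
    for t in reversed(range(value.shape[0])):
        running = delta[t] + gamma * lambda_ * (1 - done[t]) * running
        adv[t] = running
    return adv

T, B = 32, 4
adv = gae_sketch(torch.randn(T, B), torch.randn(T, B), torch.randn(T, B), torch.zeros(T, B))
assert adv.shape == (T, B)
```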
......@@ -39,11 +39,7 @@ ant_ddpg_default_config = dict(
unroll_len=1,
noise_sigma=0.1,
),
other=dict(
replay_buffer=dict(
replay_buffer_size=1000000,
),
),
other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
)
)
ant_ddpg_default_config = EasyDict(ant_ddpg_default_config)
......@@ -59,7 +55,7 @@ ant_ddpg_default_create_config = dict(
type='ddpg',
import_names=['ding.policy.ddpg'],
),
replay_buffer=dict(type='naive',),
replay_buffer=dict(type='naive', ),
)
ant_ddpg_default_create_config = EasyDict(ant_ddpg_default_create_config)
create_config = ant_ddpg_default_create_config
......@@ -42,11 +42,7 @@ ant_sac_default_config = dict(
),
command=dict(),
eval=dict(),
other=dict(
replay_buffer=dict(
replay_buffer_size=1000000,
),
),
other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
),
)
......
......@@ -44,11 +44,7 @@ ant_td3_default_config = dict(
unroll_len=1,
noise_sigma=0.1,
),
other=dict(
replay_buffer=dict(
replay_buffer_size=1000000,
),
),
other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
)
)
......@@ -65,7 +61,7 @@ ant_td3_default_create_config = dict(
policy_type='td3',
import_names=['ding.policy.td3'],
),
replay_buffer=dict(type='naive',),
replay_buffer=dict(type='naive', ),
)
ant_td3_default_create_config = EasyDict(ant_td3_default_create_config)
create_config = ant_td3_default_create_config
......@@ -39,11 +39,7 @@ halfcheetah_ddpg_default_config = dict(
unroll_len=1,
noise_sigma=0.1,
),
other=dict(
replay_buffer=dict(
replay_buffer_size=1000000,
),
),
other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
)
)
halfcheetah_ddpg_default_config = EasyDict(halfcheetah_ddpg_default_config)
......@@ -59,7 +55,7 @@ halfcheetah_ddpg_default_create_config = dict(
type='ddpg',
import_names=['ding.policy.ddpg'],
),
replay_buffer=dict(type='naive',),
replay_buffer=dict(type='naive', ),
)
halfcheetah_ddpg_default_create_config = EasyDict(halfcheetah_ddpg_default_create_config)
create_config = halfcheetah_ddpg_default_create_config
......@@ -42,11 +42,7 @@ halfcheetah_sac_default_config = dict(
),
command=dict(),
eval=dict(),
other=dict(
replay_buffer=dict(
replay_buffer_size=1000000,
),
),
other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
),
)
......
......@@ -44,11 +44,7 @@ halfcheetah_td3_default_config = dict(
unroll_len=1,
noise_sigma=0.1,
),
other=dict(
replay_buffer=dict(
replay_buffer_size=1000000,
),
),
other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
)
)
......@@ -65,7 +61,7 @@ halfcheetah_td3_default_create_config = dict(
policy_type='td3',
import_names=['ding.policy.td3'],
),
replay_buffer=dict(type='naive',),
replay_buffer=dict(type='naive', ),
)
halfcheetah_td3_default_create_config = EasyDict(halfcheetah_td3_default_create_config)
create_config = halfcheetah_td3_default_create_config
......@@ -39,11 +39,7 @@ hopper_ddpg_default_config = dict(
unroll_len=1,
noise_sigma=0.1,
),
other=dict(
replay_buffer=dict(
replay_buffer_size=1000000,
),
),
other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
)
)
hopper_ddpg_default_config = EasyDict(hopper_ddpg_default_config)
......@@ -59,7 +55,7 @@ hopper_ddpg_default_create_config = dict(
type='ddpg',
import_names=['ding.policy.ddpg'],
),
replay_buffer=dict(type='naive',),
replay_buffer=dict(type='naive', ),
)
hopper_ddpg_default_create_config = EasyDict(hopper_ddpg_default_create_config)
create_config = hopper_ddpg_default_create_config
......@@ -57,7 +57,7 @@ hopper_ppo_create_default_config = dict(
type='ppo',
import_names=['ding.policy.ppo'],
),
replay_buffer=dict(type='naive',),
replay_buffer=dict(type='naive', ),
)
hopper_ppo_create_default_config = EasyDict(hopper_ppo_create_default_config)
create_config = hopper_ppo_create_default_config
......@@ -42,11 +42,7 @@ hopper_sac_default_config = dict(
),
command=dict(),
eval=dict(),
other=dict(
replay_buffer=dict(
replay_buffer_size=1000000,
),
),
other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
),
)
......
......@@ -44,11 +44,7 @@ hopper_td3_default_config = dict(
unroll_len=1,
noise_sigma=0.1,
),
other=dict(
replay_buffer=dict(
replay_buffer_size=1000000,
),
),
other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
)
)
......@@ -65,7 +61,7 @@ hopper_td3_default_create_config = dict(
policy_type='td3',
import_names=['ding.policy.td3'],
),
replay_buffer=dict(type='naive',),
replay_buffer=dict(type='naive', ),
)
hopper_td3_default_create_config = EasyDict(hopper_td3_default_create_config)
create_config = hopper_td3_default_create_config
......@@ -39,11 +39,7 @@ walker2d_ddpg_default_config = dict(
unroll_len=1,
noise_sigma=0.1,
),
other=dict(
replay_buffer=dict(
replay_buffer_size=1000000,
),
),
other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
)
)
walker2d_ddpg_default_config = EasyDict(walker2d_ddpg_default_config)
......@@ -59,7 +55,7 @@ walker2d_ddpg_default_create_config = dict(
type='ddpg',
import_names=['ding.policy.ddpg'],
),
replay_buffer=dict(type='naive',),
replay_buffer=dict(type='naive', ),
)
walker2d_ddpg_default_create_config = EasyDict(walker2d_ddpg_default_create_config)
create_config = walker2d_ddpg_default_create_config
......@@ -42,11 +42,7 @@ walker2d_sac_default_config = dict(
),
command=dict(),
eval=dict(),
other=dict(
replay_buffer=dict(
replay_buffer_size=1000000,
),
),
other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
),
)
......
......@@ -44,11 +44,7 @@ walker2d_td3_default_config = dict(
unroll_len=1,
noise_sigma=0.1,
),
other=dict(
replay_buffer=dict(
replay_buffer_size=1000000,
),
),
other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
)
)
......@@ -65,7 +61,7 @@ walker2d_td3_default_create_config = dict(
policy_type='td3',
import_names=['ding.policy.td3'],
),
replay_buffer=dict(type='naive',),
replay_buffer=dict(type='naive', ),
)
walker2d_td3_default_create_config = EasyDict(walker2d_td3_default_create_config)
create_config = walker2d_td3_default_create_config