提交 293e5c36 编写于 作者: P puyuan1996

polish(pu): update the current best config

上级 70328aab
......@@ -56,18 +56,17 @@ class VanillaVAE(BaseVAE):
# obs
self.obs_head = nn.Sequential(nn.Linear(self.obs_dim, hidden_dims[0]), nn.ReLU())
self.encoder = nn.Sequential(nn.Linear(hidden_dims[0], hidden_dims[0]), nn.ReLU())
self.encoder_common = nn.Sequential(nn.Linear(hidden_dims[0], hidden_dims[0]), nn.ReLU())
self.mu_head = nn.Linear(hidden_dims[0], latent_dim)
self.var_head = nn.Linear(hidden_dims[0], latent_dim)
self.logvar_head = nn.Linear(hidden_dims[0], latent_dim)
# Build Decoder
self.condition_obs = nn.Sequential(nn.Linear(self.obs_dim, hidden_dims[0]), nn.ReLU())
self.decoder_action = nn.Sequential(nn.Linear(latent_dim, hidden_dims[0]), nn.ReLU())
self.decoder_common = nn.Sequential(nn.Linear(hidden_dims[0], hidden_dims[0]), nn.ReLU())
# TODO(pu): tanh
self.reconstruction_layer = nn.Sequential(nn.Linear(hidden_dims[0], self.action_dim), nn.Tanh())
# self.reconstruction_layer = nn.Linear(hidden_dims[0], self.action_dim)
self.reconstruction_head = nn.Sequential(nn.Linear(hidden_dims[0], self.action_dim), nn.Tanh())
# self.reconstruction_head = nn.Linear(hidden_dims[0], self.action_dim)
# residual prediction
self.prediction_head_1 = nn.Sequential(nn.Linear(hidden_dims[0], hidden_dims[0]), nn.ReLU())
......@@ -91,13 +90,13 @@ class VanillaVAE(BaseVAE):
# input = obs_encoding + action_encoding # TODO(pu): what about add, cat?
input = obs_encoding * action_encoding
result = self.encoder(input)
result = self.encoder_common(input)
result = torch.flatten(result, start_dim=1)
# Split the result into mu and var components
# of the latent Gaussian distribution
mu = self.mu_head(result)
log_var = self.var_head(result)
log_var = self.logvar_head(result)
return [mu, log_var]
......
......@@ -2,10 +2,10 @@ from easydict import EasyDict
from ding.entry import serial_pipeline_td3_vae
lunarlander_td3vae_config = dict(
exp_name='lunarlander_cont_td3_vae_lad6_rcs1e4_wu1e4_ns256_bs128_auf2_targetnoise_collectoriginalnoise_rbs1e5_rsc_lsc_rvuc3_upcr256_upcv100_kw0.01_pw0.01_dot_tanh',
exp_name='lunarlander_cont_td3_vae_lad6_rcs1e4_wu1e4_ns256_bs128_auf2_targetnoise_collectoriginalnoise_rbs1e5_rsc_lsc_rvuc1_upcr256_upcv10_kw0.01_pw0.01_dot_tanh',
env=dict(
env_id='LunarLanderContinuous-v2',
collector_env_num=8,
collector_env_num=1,
evaluator_env_num=5,
# (bool) Scale output action into legal range.
act_scale=True,
......@@ -25,9 +25,10 @@ lunarlander_td3vae_config = dict(
),
learn=dict(
warm_up_update=int(1e4),
rl_vae_update_circle=3, # train rl 3 iter, vae 1 iter
rl_vae_update_circle=1, # train rl 1 iter, vae 1 iter
update_per_collect_rl=256,
update_per_collect_vae=100,
update_per_collect_vae=10,
batch_size=128,
learning_rate_actor=3e-4,
learning_rate_critic=3e-4,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册