Commits (3)

- 293e5c36 polish(pu): update the current best config (puyuan1996 <2402552459@qq.com>, 2021-12-27T12:31:39+08:00)
  https://gitcode.net/opendilab/DI-engine/-/commit/293e5c36f22928ec565c9546d51f0e40827c2f89
- 938fc921 polish(pu): polish config (puyuan1996 <2402552459@qq.com>, 2021-12-27T14:33:17+08:00)
  https://gitcode.net/opendilab/DI-engine/-/commit/938fc921e9b5f508623eafcd9fca3d237a5b1e93
- 18ca5a85 polish(pu): polish config (puyuan1996 <2402552459@qq.com>, 2021-12-27T21:06:24+08:00)
  https://gitcode.net/opendilab/DI-engine/-/commit/18ca5a851ca84e567eeff3d54f036cb28a097b10
@@ -52,26 +52,25 @@ class VanillaVAE(BaseVAE):
         # Build Encoder
         # action
-        self.action_head = nn.Sequential(nn.Linear(self.action_dim, hidden_dims[0]), nn.ReLU())
+        self.encode_action_head = nn.Sequential(nn.Linear(self.action_dim, hidden_dims[0]), nn.ReLU())
         # obs
-        self.obs_head = nn.Sequential(nn.Linear(self.obs_dim, hidden_dims[0]), nn.ReLU())
+        self.encode_obs_head = nn.Sequential(nn.Linear(self.obs_dim, hidden_dims[0]), nn.ReLU())
-        self.encoder = nn.Sequential(nn.Linear(hidden_dims[0], hidden_dims[0]), nn.ReLU())
+        self.encode_common = nn.Sequential(nn.Linear(hidden_dims[0], hidden_dims[0]), nn.ReLU())
-        self.mu_head = nn.Linear(hidden_dims[0], latent_dim)
+        self.encode_mu_head = nn.Linear(hidden_dims[0], latent_dim)
-        self.var_head = nn.Linear(hidden_dims[0], latent_dim)
+        self.encode_logvar_head = nn.Linear(hidden_dims[0], latent_dim)
         # Build Decoder
         self.condition_obs = nn.Sequential(nn.Linear(self.obs_dim, hidden_dims[0]), nn.ReLU())
-        self.decoder_action = nn.Sequential(nn.Linear(latent_dim, hidden_dims[0]), nn.ReLU())
+        self.decode_action_head = nn.Sequential(nn.Linear(latent_dim, hidden_dims[0]), nn.ReLU())
-        self.decoder_common = nn.Sequential(nn.Linear(hidden_dims[0], hidden_dims[0]), nn.ReLU())
+        self.decode_common = nn.Sequential(nn.Linear(hidden_dims[0], hidden_dims[0]), nn.ReLU())
         # TODO(pu): tanh
-        self.reconstruction_layer = nn.Sequential(nn.Linear(hidden_dims[0], self.action_dim), nn.Tanh())
+        self.decode_reconst_action_head = nn.Sequential(nn.Linear(hidden_dims[0], self.action_dim), nn.Tanh())
-        # self.reconstruction_layer = nn.Linear(hidden_dims[0], self.action_dim)
+        # self.decode_reconst_action_head = nn.Linear(hidden_dims[0], self.action_dim)
         # residual prediction
-        self.prediction_head_1 = nn.Sequential(nn.Linear(hidden_dims[0], hidden_dims[0]), nn.ReLU())
+        self.decode_prediction_head_layer1 = nn.Sequential(nn.Linear(hidden_dims[0], hidden_dims[0]), nn.ReLU())
-        self.prediction_head_2 = nn.Linear(hidden_dims[0], self.obs_dim)
+        self.decode_prediction_head_layer2 = nn.Linear(hidden_dims[0], self.obs_dim)
         self.obs_encoding = None
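Note the TODO(pu): tanh pair above: the active reconstruction head ends in nn.Tanh() so decoded actions stay inside the environment's [-1, 1] action bounds, while the commented-out variant would leave them unbounded. A minimal sketch of the difference (toy dims assumed, not taken from this hunk):

    import torch
    import torch.nn as nn

    head_tanh = nn.Sequential(nn.Linear(256, 2), nn.Tanh())  # active: actions in [-1, 1]
    head_raw = nn.Linear(256, 2)                             # commented-out: unbounded
    x = torch.randn(128, 256)
    assert head_tanh(x).abs().max() <= 1.0                   # never outside the action bounds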
@@ -82,8 +81,8 @@ class VanillaVAE(BaseVAE):
         :param input: (Tensor) Input tensor to encoder [N x C x H x W]
         :return: (Tensor) List of latent codes
         """
-        action_encoding = self.action_head(input['action'])
+        action_encoding = self.encode_action_head(input['action'])
-        obs_encoding = self.obs_head(input['obs'])
+        obs_encoding = self.encode_obs_head(input['obs'])
         # obs_encoding = self.condition_obs(input['obs'])  # TODO(pu): using a different network
         self.obs_encoding = obs_encoding
@@ -91,13 +90,13 @@ class VanillaVAE(BaseVAE):
         # input = obs_encoding + action_encoding  # TODO(pu): what about add, cat?
         input = obs_encoding * action_encoding
-        result = self.encoder(input)
+        result = self.encode_common(input)
         result = torch.flatten(result, start_dim=1)
         # Split the result into mu and var components
         # of the latent Gaussian distribution
-        mu = self.mu_head(result)
+        mu = self.encode_mu_head(result)
-        log_var = self.var_head(result)
+        log_var = self.encode_logvar_head(result)
         return [mu, log_var]
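For orientation, a runnable shape sketch of this encode path, plus the standard reparameterization that consumes [mu, log_var] downstream. The reparameterize step is ordinary VAE practice and is not part of this diff; hidden=256 and latent_dim=6 are guesses read off the exp_name ("lad6") and the "(128, 6)" NOTE, not off this hunk:

    import torch
    import torch.nn as nn

    obs_dim, action_dim, hidden, latent_dim = 8, 2, 256, 6   # assumed toy dims
    encode_action_head = nn.Sequential(nn.Linear(action_dim, hidden), nn.ReLU())
    encode_obs_head = nn.Sequential(nn.Linear(obs_dim, hidden), nn.ReLU())
    encode_common = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU())
    encode_mu_head = nn.Linear(hidden, latent_dim)
    encode_logvar_head = nn.Linear(hidden, latent_dim)

    obs, action = torch.randn(128, obs_dim), torch.randn(128, action_dim)
    h = encode_obs_head(obs) * encode_action_head(action)    # elementwise fusion, (128, 256)
    h = encode_common(h)                                     # flatten is a no-op on 2-D input
    mu, log_var = encode_mu_head(h), encode_logvar_head(h)   # both (128, 6)

    # Standard reparameterization trick (assumed, not shown in this diff);
    # the log-variance head keeps std = exp(0.5 * log_var) strictly positive.
    z = mu + torch.randn_like(mu) * torch.exp(0.5 * log_var)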
@@ -108,14 +107,15 @@ class VanillaVAE(BaseVAE):
         :param z: (Tensor) [B x D]
         :return: (Tensor) [B x C x H x W]
         """
-        action_decoding = self.decoder_action(torch.tanh(z))  # NOTE: tanh, here z is not bounded
+        action_decoding = self.decode_action_head(torch.tanh(z))  # NOTE: tanh, here z is not bounded
-        # action_decoding = self.decoder_action(z)  # NOTE: tanh, here z is not bounded
+        # action_decoding = self.decode_action_head(z)  # NOTE: tanh, here z is not bounded
-        action_obs_decoding = action_decoding * self.obs_encoding
+        action_obs_decoding = action_decoding + self.obs_encoding  # TODO(pu): what about add, cat?
+        # action_obs_decoding = action_decoding * self.obs_encoding
-        action_obs_decoding_tmp = self.decoder_common(action_obs_decoding)
+        action_obs_decoding_tmp = self.decode_common(action_obs_decoding)
-        reconstruction_action = self.reconstruction_layer(action_obs_decoding_tmp)
+        reconstruction_action = self.decode_reconst_action_head(action_obs_decoding_tmp)
-        predition_residual_tmp = self.prediction_head_1(action_obs_decoding_tmp)
+        predition_residual_tmp = self.decode_prediction_head_layer1(action_obs_decoding_tmp)
-        predition_residual = self.prediction_head_2(predition_residual_tmp)
+        predition_residual = self.decode_prediction_head_layer2(predition_residual_tmp)
         return [reconstruction_action, predition_residual]
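This hunk flips the decoder-side fusion from elementwise product to sum, keeping the product as a comment and mirroring the TODO in encode. For reference, the three fusion candidates the TODO names and their shape implications (a sketch with toy shapes):

    import torch

    a = torch.randn(128, 256)               # action_decoding
    o = torch.randn(128, 256)               # self.obs_encoding
    fused_add = a + o                       # what this hunk switches to
    fused_mul = a * o                       # previous behavior, kept as a comment
    fused_cat = torch.cat([a, o], dim=-1)   # (128, 512): would require widening decode_common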
@@ -126,14 +126,15 @@ class VanillaVAE(BaseVAE):
         :param z: (Tensor) [B x D]
         :return: (Tensor) [B x C x H x W]
         """
-        self.obs_encoding = self.obs_head(obs)
+        self.obs_encoding = self.encode_obs_head(obs)
         # TODO(pu): here z is already bounded, z is produced by td3 policy, it has been operated by tanh
-        action_decoding = self.decoder_action(z)
+        action_decoding = self.decode_action_head(z)
+        # action_obs_decoding = action_decoding + self.obs_encoding  # TODO(pu): what about add, cat?
         action_obs_decoding = action_decoding * self.obs_encoding
-        action_obs_decoding_tmp = self.decoder_common(action_obs_decoding)
+        action_obs_decoding_tmp = self.decode_common(action_obs_decoding)
-        reconstruction_action = self.reconstruction_layer(action_obs_decoding_tmp)
+        reconstruction_action = self.decode_reconst_action_head(action_obs_decoding_tmp)
-        predition_residual_tmp = self.prediction_head_1(action_obs_decoding_tmp)
+        predition_residual_tmp = self.decode_prediction_head_layer1(action_obs_decoding_tmp)
-        predition_residual = self.prediction_head_2(predition_residual_tmp)
+        predition_residual = self.decode_prediction_head_layer2(predition_residual_tmp)
         return [reconstruction_action, predition_residual]
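Note the asymmetry with the previous hunk: here z comes from the TD3 actor and is already tanh-bounded, so no extra tanh is applied, and the fusion stays multiplicative (the additive variant is only added as a comment). A usage sketch; the `actor` stand-in and the `decode_with_obs` method name are placeholders for illustration, since the diff shows only the method body:

    import torch
    import torch.nn as nn

    actor = nn.Linear(8, 6)        # stand-in TD3 actor: obs -> latent action
    obs = torch.randn(128, 8)
    z = torch.tanh(actor(obs))     # already bounded, so decode applies no extra tanh
    # reconstruction_action, prediction_residual = vae.decode_with_obs(z, obs)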
......
@@ -317,7 +317,7 @@ class TD3VAEPolicy(DDPGPolicy):
                 # latent space constraint (LSC)
                 # NOTE: using tanh is important, update latent_action using z, shape (128, 6)
-                data['latent_action'] = torch.tanh(result['z'].clone().detach())
+                data['latent_action'] = torch.tanh(result['z'].clone().detach())  # NOTE: tanh
                 # data['latent_action'] = result['z'].clone().detach()
                 self.c_percentage_bound_lower = data['latent_action'].sort(dim=0)[0][int(
                     result['recons_action'].shape[0] * 0.02
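The latent space constraint computes per-dimension percentile bounds over the batch of tanh-squashed latent actions; the hunk shows the 2% lower bound, and a matching 98% upper bound is assumed for the elided lines. A sketch with toy tensors:

    import torch

    latent = torch.tanh(torch.randn(128, 6))   # data['latent_action'], shape (128, 6)
    sorted_latent = latent.sort(dim=0)[0]      # sort each latent dim over the batch
    lower = sorted_latent[int(128 * 0.02)]     # per-dim 2nd-percentile bound, shape (6,)
    upper = sorted_latent[int(128 * 0.98)]     # assumed symmetric 98th-percentile bound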
@@ -373,9 +373,11 @@ class TD3VAEPolicy(DDPGPolicy):
                     if F.mse_loss(result['prediction_residual'][i],
                                   true_residual[i]).item() > 4 * self._running_mean_std_predict_loss.mean:
                         # NOTE: using tanh is important, update latent_action using z
-                        data['latent_action'][i] = torch.tanh(result['z'][i].clone().detach())
+                        data['latent_action'][i] = torch.tanh(result['z'][i].clone().detach())  # NOTE: tanh
                         # data['latent_action'][i] = result['z'][i].clone().detach()
+                # update all latent action
+                # data['latent_action'] = torch.tanh(result['z'].clone().detach())
                 if self._reward_batch_norm:
                     reward = (reward - reward.mean()) / (reward.std() + 1e-8)
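This hunk relabels a stored latent action only when that sample's residual prediction error is an outlier, i.e. more than 4x the tracked running mean of the prediction loss; the new commented lines record the rejected alternative of relabeling the whole batch at once. A self-contained sketch of the selective rule, with toy tensors standing in for result['prediction_residual'], true_residual, result['z'], data['latent_action'], and a constant 0.05 standing in for self._running_mean_std_predict_loss.mean:

    import torch
    import torch.nn.functional as F

    prediction_residual = torch.randn(128, 8)
    true_residual = torch.randn(128, 8)
    z = torch.randn(128, 6)
    latent_action = torch.tanh(z)
    running_mean = 0.05
    for i in range(prediction_residual.shape[0]):
        # relabel only outlier samples: per-sample MSE > 4x the running mean loss
        if F.mse_loss(prediction_residual[i], true_residual[i]).item() > 4 * running_mean:
            latent_action[i] = torch.tanh(z[i].clone().detach())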
......
@@ -2,7 +2,7 @@ from easydict import EasyDict
 from ding.entry import serial_pipeline_td3_vae

 lunarlander_td3vae_config = dict(
-    exp_name='lunarlander_cont_td3_vae_lad6_rcs1e4_wu1e4_ns256_bs128_auf2_targetnoise_collectoriginalnoise_rbs1e5_rsc_lsc_rvuc3_upcr256_upcv100_kw0.01_pw0.01_dot_tanh',
+    exp_name='lunarlander_cont_td3_vae_lad6_rcs1e4_wu1e4_ns256_bs128_auf2_targetnoise_collectoriginalnoise_rbs1e5_rsc_lsc_rvuc1_upcr256_upcv10_kw0.01_pw0.01_dot_tanh',
     env=dict(
         env_id='LunarLanderContinuous-v2',
         collector_env_num=8,
@@ -25,9 +25,9 @@ lunarlander_td3vae_config = dict(
         ),
         learn=dict(
             warm_up_update=int(1e4),
-            rl_vae_update_circle=3,  # train rl 3 iter, vae 1 iter
+            rl_vae_update_circle=1,  # train rl 1 iter, vae 1 iter
             update_per_collect_rl=256,
-            update_per_collect_vae=100,
+            update_per_collect_vae=10,
             batch_size=128,
             learning_rate_actor=3e-4,
             learning_rate_critic=3e-4,
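These two changes match the exp_name update in the earlier hunk (rvuc3 to rvuc1, upcv100 to upcv10): the RL/VAE alternation now runs one circle per collect instead of three, and each circle takes 10 VAE updates instead of 100. A sketch of the schedule this implies; the loop structure is inferred from the inline comment, not shown in this diff, and the two helpers are hypothetical:

    rl_vae_update_circle = 1       # train rl 1 iter, vae 1 iter
    update_per_collect_rl = 256    # RL gradient steps per circle
    update_per_collect_vae = 10    # VAE gradient steps per circle (was 100)

    def train_rl_step():           # placeholder for one TD3 update
        pass

    def train_vae_step():          # placeholder for one VAE update
        pass

    for _ in range(rl_vae_update_circle):   # once per collected batch of transitions
        for _ in range(update_per_collect_rl):
            train_rl_step()
        for _ in range(update_per_collect_vae):
            train_vae_step()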
......