...
 
Commits (3)
    https://gitcode.net/opendilab/DI-engine/-/commit/293e5c36f22928ec565c9546d51f0e40827c2f89 polish(pu): update the current best config 2021-12-27T12:31:39+08:00 puyuan1996 2402552459@qq.com https://gitcode.net/opendilab/DI-engine/-/commit/938fc921e9b5f508623eafcd9fca3d237a5b1e93 polish(pu): polish config 2021-12-27T14:33:17+08:00 puyuan1996 2402552459@qq.com https://gitcode.net/opendilab/DI-engine/-/commit/18ca5a851ca84e567eeff3d54f036cb28a097b10 polish(pu): polish config 2021-12-27T21:06:24+08:00 puyuan1996 2402552459@qq.com
......@@ -52,26 +52,25 @@ class VanillaVAE(BaseVAE):
# Build Encoder
# action
self.action_head = nn.Sequential(nn.Linear(self.action_dim, hidden_dims[0]), nn.ReLU())
self.encode_action_head = nn.Sequential(nn.Linear(self.action_dim, hidden_dims[0]), nn.ReLU())
# obs
self.obs_head = nn.Sequential(nn.Linear(self.obs_dim, hidden_dims[0]), nn.ReLU())
self.encode_obs_head = nn.Sequential(nn.Linear(self.obs_dim, hidden_dims[0]), nn.ReLU())
self.encoder = nn.Sequential(nn.Linear(hidden_dims[0], hidden_dims[0]), nn.ReLU())
self.mu_head = nn.Linear(hidden_dims[0], latent_dim)
self.var_head = nn.Linear(hidden_dims[0], latent_dim)
self.encode_common = nn.Sequential(nn.Linear(hidden_dims[0], hidden_dims[0]), nn.ReLU())
self.encode_mu_head = nn.Linear(hidden_dims[0], latent_dim)
self.encode_logvar_head = nn.Linear(hidden_dims[0], latent_dim)
# Build Decoder
self.condition_obs = nn.Sequential(nn.Linear(self.obs_dim, hidden_dims[0]), nn.ReLU())
self.decoder_action = nn.Sequential(nn.Linear(latent_dim, hidden_dims[0]), nn.ReLU())
self.decoder_common = nn.Sequential(nn.Linear(hidden_dims[0], hidden_dims[0]), nn.ReLU())
self.decode_action_head = nn.Sequential(nn.Linear(latent_dim, hidden_dims[0]), nn.ReLU())
self.decode_common = nn.Sequential(nn.Linear(hidden_dims[0], hidden_dims[0]), nn.ReLU())
# TODO(pu): tanh
self.reconstruction_layer = nn.Sequential(nn.Linear(hidden_dims[0], self.action_dim), nn.Tanh())
# self.reconstruction_layer = nn.Linear(hidden_dims[0], self.action_dim)
self.decode_reconst_action_head = nn.Sequential(nn.Linear(hidden_dims[0], self.action_dim), nn.Tanh())
# self.decode_reconst_action_head = nn.Linear(hidden_dims[0], self.action_dim)
# residual prediction
self.prediction_head_1 = nn.Sequential(nn.Linear(hidden_dims[0], hidden_dims[0]), nn.ReLU())
self.prediction_head_2 = nn.Linear(hidden_dims[0], self.obs_dim)
self.decode_prediction_head_layer1 = nn.Sequential(nn.Linear(hidden_dims[0], hidden_dims[0]), nn.ReLU())
self.decode_prediction_head_layer2 = nn.Linear(hidden_dims[0], self.obs_dim)
self.obs_encoding = None
......@@ -82,8 +81,8 @@ class VanillaVAE(BaseVAE):
:param input: (dict) Input to the encoder, holding the 'action' and 'obs' tensors
:return: (List[Tensor]) Latent codes [mu, log_var]
"""
action_encoding = self.action_head(input['action'])
obs_encoding = self.obs_head(input['obs'])
action_encoding = self.encode_action_head(input['action'])
obs_encoding = self.encode_obs_head(input['obs'])
# obs_encoding = self.condition_obs(input['obs']) # TODO(pu): using a different network
self.obs_encoding = obs_encoding
......@@ -91,13 +90,13 @@ class VanillaVAE(BaseVAE):
# input = obs_encoding + action_encoding # TODO(pu): what about add, cat?
input = obs_encoding * action_encoding
result = self.encoder(input)
result = self.encode_common(input)
result = torch.flatten(result, start_dim=1)
# Split the result into the mu and log-variance components
# of the latent Gaussian distribution
mu = self.mu_head(result)
log_var = self.var_head(result)
mu = self.encode_mu_head(result)
log_var = self.encode_logvar_head(result)
return [mu, log_var]
......@@ -108,14 +107,15 @@ class VanillaVAE(BaseVAE):
:param z: (Tensor) [B x D]
:return: (List[Tensor]) [reconstructed action, predicted residual]
"""
action_decoding = self.decoder_action(torch.tanh(z)) # NOTE: tanh, here z is not bounded
# action_decoding = self.decoder_action(z) # NOTE: tanh, here z is not bounded
action_obs_decoding = action_decoding * self.obs_encoding
action_obs_decoding_tmp = self.decoder_common(action_obs_decoding)
action_decoding = self.decode_action_head(torch.tanh(z)) # NOTE: tanh, here z is not bounded
# action_decoding = self.decode_action_head(z) # NOTE: tanh, here z is not bounded
action_obs_decoding = action_decoding + self.obs_encoding # TODO(pu): what about add, cat?
# action_obs_decoding = action_decoding * self.obs_encoding
action_obs_decoding_tmp = self.decode_common(action_obs_decoding)
reconstruction_action = self.reconstruction_layer(action_obs_decoding_tmp)
predition_residual_tmp = self.prediction_head_1(action_obs_decoding_tmp)
predition_residual = self.prediction_head_2(predition_residual_tmp)
reconstruction_action = self.decode_reconst_action_head (action_obs_decoding_tmp)
predition_residual_tmp = self.decode_prediction_head_layer1(action_obs_decoding_tmp)
predition_residual = self.decode_prediction_head_layer2(predition_residual_tmp)
return [reconstruction_action, predition_residual]
......@@ -126,14 +126,15 @@ class VanillaVAE(BaseVAE):
:param z: (Tensor) [B x D]
:return: (List[Tensor]) [reconstructed action, predicted residual]
"""
self.obs_encoding = self.obs_head(obs)
self.obs_encoding = self.encode_obs_head(obs)
# TODO(pu): here z is already bounded: it is produced by the td3 policy, which has already applied tanh to it
action_decoding = self.decoder_action(z)
action_decoding = self.decode_action_head(z)
# action_obs_decoding = action_decoding + self.obs_encoding # TODO(pu): what about add, cat?
action_obs_decoding = action_decoding * self.obs_encoding
action_obs_decoding_tmp = self.decoder_common(action_obs_decoding)
reconstruction_action = self.reconstruction_layer(action_obs_decoding_tmp)
predition_residual_tmp = self.prediction_head_1(action_obs_decoding_tmp)
predition_residual = self.prediction_head_2(predition_residual_tmp)
action_obs_decoding_tmp = self.decode_common(action_obs_decoding)
reconstruction_action = self.decode_reconst_action_head (action_obs_decoding_tmp)
predition_residual_tmp = self.decode_prediction_head_layer1(action_obs_decoding_tmp)
predition_residual = self.decode_prediction_head_layer2(predition_residual_tmp)
return [reconstruction_action, predition_residual]
......
......@@ -317,7 +317,7 @@ class TD3VAEPolicy(DDPGPolicy):
# latent space constraint (LSC)
# NOTE: applying tanh is important here; update latent_action from z, shape (128, 6)
data['latent_action'] = torch.tanh(result['z'].clone().detach())
data['latent_action'] = torch.tanh(result['z'].clone().detach()) # NOTE: tanh
# data['latent_action'] = result['z'].clone().detach()
self.c_percentage_bound_lower = data['latent_action'].sort(dim=0)[0][int(
result['recons_action'].shape[0] * 0.02
......@@ -373,9 +373,11 @@ class TD3VAEPolicy(DDPGPolicy):
if F.mse_loss(result['prediction_residual'][i],
true_residual[i]).item() > 4 * self._running_mean_std_predict_loss.mean:
# NOTE: applying tanh is important here; update latent_action from z
data['latent_action'][i] = torch.tanh(result['z'][i].clone().detach())
data['latent_action'][i] = torch.tanh(result['z'][i].clone().detach()) # NOTE: tanh
# data['latent_action'][i] = result['z'][i].clone().detach()
# update all latent actions
# data['latent_action'] = torch.tanh(result['z'].clone().detach())
if self._reward_batch_norm:
reward = (reward - reward.mean()) / (reward.std() + 1e-8)
......
......@@ -2,7 +2,7 @@ from easydict import EasyDict
from ding.entry import serial_pipeline_td3_vae
lunarlander_td3vae_config = dict(
exp_name='lunarlander_cont_td3_vae_lad6_rcs1e4_wu1e4_ns256_bs128_auf2_targetnoise_collectoriginalnoise_rbs1e5_rsc_lsc_rvuc3_upcr256_upcv100_kw0.01_pw0.01_dot_tanh',
exp_name='lunarlander_cont_td3_vae_lad6_rcs1e4_wu1e4_ns256_bs128_auf2_targetnoise_collectoriginalnoise_rbs1e5_rsc_lsc_rvuc1_upcr256_upcv10_kw0.01_pw0.01_dot_tanh',
env=dict(
env_id='LunarLanderContinuous-v2',
collector_env_num=8,
......@@ -25,9 +25,9 @@ lunarlander_td3vae_config = dict(
),
learn=dict(
warm_up_update=int(1e4),
rl_vae_update_circle=3, # train rl 3 iter, vae 1 iter
rl_vae_update_circle=1, # train rl 1 iter, vae 1 iter
update_per_collect_rl=256,
update_per_collect_vae=100,
update_per_collect_vae=10,
batch_size=128,
learning_rate_actor=3e-4,
learning_rate_critic=3e-4,
......