提交 b93a380b 编写于 作者: P puyuan1996

polish(pu): polish config

上级 ec3a3618
......@@ -204,7 +204,8 @@ class VanillaVAE(BaseVAE):
def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
    """Run one encode -> reparameterize -> decode pass over ``input``.

    Args:
        input: Value handed straight to ``self.encode``; it is also returned
            unchanged as the third element of the result.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        ``[decoded[0], decoded[1], input, mu, log_var, z]`` where ``decoded``
        is the result of ``self.decode(z)``. Per the original inline note the
        first two entries are the reconstructed action and the prediction
        residual.
    """
    mu, log_var = self.encode(input)
    z = self.reparameterize(mu, log_var)
    # Decode exactly once and reuse the result: calling ``decode`` separately
    # for each element doubles the work and, if ``decode`` is not
    # deterministic, would let the two outputs come from different passes.
    decoded = self.decode(z)
    return [decoded[0], decoded[1], input, mu, log_var, z]  # recons_action, prediction_residual
def loss_function(self,
*args,
......
......@@ -231,6 +231,8 @@ class TD3VAEPolicy(DDPGPolicy):
lr=self._cfg.learn.learning_rate_vae,
)
self._running_mean_std_predict_loss = RunningMeanStd(epsilon=1e-4)
self.c_percentage_bound_lower = -1*torch.ones([6])
self.c_percentage_bound_upper = torch.ones([6])
def _forward_learn(self, data: dict) -> Dict[str, Any]:
r"""
......@@ -314,8 +316,6 @@ class TD3VAEPolicy(DDPGPolicy):
ignore_done=self._cfg.learn.ignore_done,
use_nstep=False
)
# if (self._forward_learn_cnt + 1) % self._cfg.learn.vae_rl_update_circle in range(10,15):
# if (self._forward_learn_cnt + 1) % self._cfg.learn.vae_update_freq == 0:
if data['vae_phase'][0].item() is True:
# for i in range(self._cfg.learn.vae_train_times_per_update):
if self._cuda:
......@@ -335,6 +335,11 @@ class TD3VAEPolicy(DDPGPolicy):
true_residual = data['next_obs'] - data['obs']
result = result + [true_residual]
# latent space constraint (LSC)
# data['latent_action'] = torch.tanh(result[5].detach()) # TODO(pu): update latent_action z, shape (128,6)
# self.c_percentage_bound_lower = data['latent_action'].sort(dim=0)[0][int(result[0].shape[0] * 0.02), :] # values, indices
# self.c_percentage_bound_upper = data['latent_action'].sort(dim=0)[0][int(result[0].shape[0] * 0.98), :]
vae_loss = self._vae_model.loss_function(*result, kld_weight=0.5, predict_weight=10) # TODO(pu):weight
# recons = args[0]
# prediction_residual = args[1]
......@@ -364,8 +369,6 @@ class TD3VAEPolicy(DDPGPolicy):
**q_value_dict,
}
# if (self._forward_learn_cnt + 1) % self._cfg.learn.vae_rl_update_circle in range(0,10):
# if data[0]['rl_phase'] is True:
else:
# ====================
# critic learn forward
......@@ -384,6 +387,8 @@ class TD3VAEPolicy(DDPGPolicy):
{'action': data['action'],
'obs': data['obs']}) # [self.decode(z)[0], self.decode(z)[1], input, mu, log_var, z]
true_residual = data['next_obs'] - data['obs']
# Representation shift correction (RSC)
for i in range(result[1].shape[0]):
if F.mse_loss(result[1][i], true_residual[i]).item() > 4 * self._running_mean_std_predict_loss.mean:
data['latent_action'][i] = torch.tanh(result[5][i].detach()) # TODO(pu): update latent_action z tanh
......@@ -435,20 +440,9 @@ class TD3VAEPolicy(DDPGPolicy):
# ===============================
# actor updates every ``self._actor_update_freq`` iters
if (self._forward_learn_cnt + 1) % self._actor_update_freq == 0:
# latent space constraint (LSC)
# result = self._vae_model(
# {'action': data['action'],
# 'obs': data['obs']}) # [self.decode(z)[0], self.decode(z)[1], input, mu, log_var, z]
# data['latent_action'] = torch.tanh(result[5].detach()) # TODO(pu): update latent_action z
# c_percentage_bound_low = data['latent_action'].sort(dim=0)[0][int(128 * 0.02), :]
# c_percentage_bound_upper = data['latent_action'].sort(dim=0)[0][int(128 * 0.98), :]
actor_data = self._learn_model.forward(data['obs'], mode='compute_actor') # latent action
# latent space constraint (LSC)
# for i in range(actor_data['action'].shape[-1]):
# # actor_data['action'][:, i] = copy.deepcopy(actor_data['action'][:, i].clamp(c_percentage_bound_low[i].item(), c_percentage_bound_upper[i].item()))
# actor_data['action'][:, i].clamp(c_percentage_bound_low[i].item(),
# c_percentage_bound_upper[i].item())
actor_data['obs'] = data['obs']
if self._twin_critic:
......@@ -540,6 +534,12 @@ class TD3VAEPolicy(DDPGPolicy):
with torch.no_grad():
output = self._collect_model.forward(data, mode='compute_actor', **kwargs)
output['latent_action'] = output['action']
# latent space constraint (LSC)
# for i in range(output['action'].shape[-1]):
# output['action'][:, i].clamp_(self.c_percentage_bound_lower[i].item(),
# self.c_percentage_bound_upper[i].item())
# TODO(pu): decode into original hybrid actions, here data is obs
# this is very important to generate self.obs_encoding using in decode phase
output['action'] = self._vae_model.decode_with_obs(output['action'], data)[0]
......@@ -636,6 +636,12 @@ class TD3VAEPolicy(DDPGPolicy):
with torch.no_grad():
output = self._eval_model.forward(data, mode='compute_actor')
output['latent_action'] = output['action']
# latent space constraint (LSC)
# for i in range(output['action'].shape[-1]):
# output['action'][:, i].clamp_(self.c_percentage_bound_lower[i].item(),
# self.c_percentage_bound_upper[i].item())
# TODO(pu): decode into original hybrid actions, here data is obs
# this is very important to generate self.obs_encoding using in decode phase
output['action'] = self._vae_model.decode_with_obs(output['action'], data)[0]
......
......@@ -2,50 +2,9 @@ from easydict import EasyDict
from ding.entry import serial_pipeline_td3_vae
lunarlander_td3vae_config = dict(
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_zrelabel_eins1280_rvuc10_upcr20_upcv100_noisefalse_rbs1e5', # TODO(pu)
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_murelabel_eins1280_rvuc10_upcr20_upcv100_noisefalse_rbs1e5', # TODO(pu)
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_murelabel_eins48_rvuc1_upcr2_upcv2_noisetrue_rbs2e4', # TODO(pu) lr 3e-4 loss explode 45000iters
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_murelabel_eins48_rvuc20_upcr2_upcv200_noisetrue_rbs2e4', # TODO(pu) lr 3e-4 loss explode 10000iters
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_murelabel_eins48_rvuc1000_upcr2_upcv1000_noisetrue_rbs1e5', # TODO(pu) loss explode 10000iters
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_murelabel_eins48_rvuc1000_upcr2_upcv0_noisetrue_rbs1e5', # TODO(pu) loss explode 3000iters
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_zrelabel_eins48_rvuc1000_upcr2_upcv0_noisetrue_rbs1e5', # TODO(pu) 80000iters eval rew_mean -278
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_norelabel_eins48_rvuc1000_upcr2_upcv0_noisetrue_rbs1e5', # TODO(pu)
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_norelabel_eins48_rvuc1000_upcr2_upcv0_noisetrue_rbs2e4', # TODO(pu)
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_norelabel_eins48_rvuc100_upcr2_upcv0_notargetnoise_nocollectnoise_rbs2e4', # TODO(pu) debug 2m collect rew_max 200
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_norelabel_eins48_rvuc100_upcr2_upcv0_targetnoise_nocollectnoise_rbs2e4', # TODO(pu) 2m collect rew_mean -120 不变
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_norelabel_vaeupdatez_eins48_rvuc100_upcr2_upcv100_noisetrue_rbs2e4', # TODO(pu) 90000iters eval rew_mean -139
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_relabelz_novaeupdatez_eins48_rvuc100_upcr2_upcv100_notargetnoise_nocollectnoise_rbs2e4', # TODO(pu) 2m eval rew_mean -210 best
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_relabelz_novaeupdatez_eins48_rvuc100_upcr20_upcv100_notargetnoise_nocollectnoise_rbs2e4', # TODO(pu) 3m collect rew_mean -254 loss explode
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_relabelz_novaeupdatez_eins48_rvuc100_upcr50_upcv100_notargetnoise_nocollectnoise_rbs2e4', # TODO(pu) 1m collect rew_mean -277 loss explode
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_relabelz_novaeupdatez_eins48_rvuc1000_upcr2_upcv1000_notargetnoise_nocollectnoise_rbs2e4', # TODO(pu) 0.5m eval rew_mean -43 best
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_relabelz_novaeupdatez_eins48_rvuc1000_upcr20_upcv1000_notargetnoise_nocollectnoise_rbs2e4', # TODO(pu) 3m collect rew_mean -256 loss explode
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_relabelz_novaeupdatez_eins48_rvuc1000_upcr20_upcv1000_notargetnoise_nocollectnoise_rbs1e5', # TODO(pu) 3m collect rew_mean -259
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_relabelz_novaeupdatez_ns48_rvuc10000_upcr2_upcv10000_notargetnoise_collectoriginalnoise_rbs5e5_rsc',
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc100_upcr2_upcv100_notargetnoise_collectoriginalnoise_rbs2e4_rsc', #
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc1000_upcr2_upcv1000_notargetnoise_collectoriginalnoise_rbs2e4_rsc', # TODO(pu) run3 1.5m collect rew_max eval rew_mean
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_relabelz_novaeupdatez_ns48_rvuc1_upcr2_upcv1_notargetnoise_collectoriginalnoise_rbs2e4_rsc', #
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu0_relabelz_novaeupdatez_ns48_rvuc1_upcr2_upcv0_notargetnoise_collectoriginalnoise_rbs2e4_rsc', #
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc100_upcr20_upcv100_notargetnoise_collectoriginalnoise_rbs2e4_rsc', #
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc100_upcr50_upcv100_notargetnoise_collectoriginalnoise_rbs2e4_rsc', #
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc100_upcr50_upcv100_notargetnoise_collectoriginalnoise_rbs1e5_rsc',# TODO(pu) run6 best
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns1280_rvuc100_upcr50_upcv100_notargetnoise_collectoriginalnoise_rbs1e5_rsc', #
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc1000_upcr20_upcv1000_notargetnoise_collectoriginalnoise_rbs2e4_rsc',
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc1000_upcr50_upcv1000_notargetnoise_collectoriginalnoise_rbs1e5_rsc', # run4
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc100_upc2_upcv0_notargetnoise_collectoriginalnoise_rbs2e4_rsc',# TODO(pu) debug
exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc100_upc2_upcv100_notargetnoise_collectoriginalnoise_rbs2e4_rsc',# TODO(pu) run3
exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc100_upcr2_upcv100_notargetnoise_collectoriginalnoise_rbs2e4_rsc',# TODO(pu) debug
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc100_upcr2_upcv100_notargetnoise_collectoriginalnoise_rbs2e4_rsc_lsc',# TODO(pu)
env=dict(
env_id='LunarLanderContinuous-v2',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册