提交 9e6de548 编写于 作者: P puyuan1996

polish(pu):polish td3_vae config

上级 b65eb2d4
......@@ -269,7 +269,7 @@ class TD3VAEPolicy(DDPGPolicy):
result['original_action'] = data['action']
result['true_residual'] = data['next_obs'] - data['obs']
vae_loss = self._vae_model.loss_function(result, kld_weight=0.5, predict_weight=10) # TODO(pu):weight
vae_loss = self._vae_model.loss_function(result, kld_weight=0.5, predict_weight=1) # TODO(pu):weight
# recons = args[0]
# prediction_residual = args[1]
# input_action = args[2]
......@@ -338,7 +338,7 @@ class TD3VAEPolicy(DDPGPolicy):
# self.c_percentage_bound_lower = data['latent_action'].sort(dim=0)[0][int(result['recons_action'].shape[0] * 0.02), :] # values, indices
# self.c_percentage_bound_upper = data['latent_action'].sort(dim=0)[0][int(result['recons_action'].shape[0] * 0.98), :]
vae_loss = self._vae_model.loss_function(result, kld_weight=0.5, predict_weight=10) # TODO(pu):weight
vae_loss = self._vae_model.loss_function(result, kld_weight=0.5, predict_weight=1) # TODO(pu):weight
loss_dict['vae_loss'] = vae_loss['loss']
loss_dict['reconstruction_loss'] = vae_loss['reconstruction_loss']
......
......@@ -3,7 +3,7 @@ from ding.entry import serial_pipeline_td3_vae
lunarlander_td3vae_config = dict(
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc1000_upcr20_upcv1000_notargetnoise_collectoriginalnoise_rbs2e4_rsc',
exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc100_upcr20_upcv100_notargetnoise_collectoriginalnoise_rbs2e4_rsc',# TODO(pu) debug
exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc100_upcr2_upcv0_notargetnoise_collectoriginalnoise_rbs2e4_rsc',# TODO(pu) debug
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc100_upcr2_upcv100_notargetnoise_collectoriginalnoise_rbs2e4_rsc_lsc',# TODO(pu)
env=dict(
......@@ -40,13 +40,13 @@ lunarlander_td3vae_config = dict(
# rl_vae_update_circle=1,
# update_per_collect_rl=50,
update_per_collect_rl=20,
# update_per_collect_rl=2,
# update_per_collect_rl=20,
update_per_collect_rl=2,
update_per_collect_vae=100,
# update_per_collect_vae=100,
# update_per_collect_vae=20,
# update_per_collect_vae=1,
# update_per_collect_vae=0,
update_per_collect_vae=0,
batch_size=128,
learning_rate_actor=1e-3,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册