Commit 18f86f26 authored by P puyuan1996, committed by niuyazhe

test(pu): delete noise and change the data for updating vae

Parent 5112584b
......@@ -131,10 +131,10 @@ def serial_pipeline_td3_vae(
replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
# rl phase
# if iter % cfg.policy.learn.rl_vae_update_circle in range(0,10):
if iter % cfg.policy.learn.rl_vae_update_circle in range(0, cfg.policy.learn.rl_vae_update_circle-1):
# if iter % cfg.policy.learn.rl_vae_update_circle in range(0,20):
if iter % cfg.policy.learn.rl_vae_update_circle in range(0, cfg.policy.learn.rl_vae_update_circle):
# Learn policy from collected data
for i in range(cfg.policy.learn.update_per_collect_rl):
for i in range(cfg.policy.learn.update_per_collect_rl):  # 2 -> 12
# Learner will train ``update_per_collect`` times in one iteration.
train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
for item in train_data:
......@@ -151,11 +151,14 @@ def serial_pipeline_td3_vae(
if learner.policy.get_attribute('priority'):
replay_buffer.update(learner.priority_info)
# vae phase
# if iter % cfg.policy.learn.rl_vae_update_circle in range(10, 11):
if iter % cfg.policy.learn.rl_vae_update_circle in range(cfg.policy.learn.rl_vae_update_circle - 1, cfg.policy.learn.rl_vae_update_circle):
for i in range(cfg.policy.learn.update_per_collect_vae):
# if iter % cfg.policy.learn.rl_vae_update_circle in range(19, 20):
# if iter % cfg.policy.learn.rl_vae_update_circle in range(cfg.policy.learn.rl_vae_update_circle - 1, cfg.policy.learn.rl_vae_update_circle):
if iter % cfg.policy.learn.rl_vae_update_circle in range(cfg.policy.learn.rl_vae_update_circle - 1,
cfg.policy.learn.rl_vae_update_circle):
for i in range(cfg.policy.learn.update_per_collect_vae):  # 40
# Learner will train ``update_per_collect`` times in one iteration.
train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
train_data = train_data + new_data  # TODO(pu)
for item in train_data:
item['rl_phase'] = False
item['vae_phase'] = True
......
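As a reading aid for the two conditions above, here is a minimal standalone sketch (hypothetical, not part of the commit) of the schedule they produce with the new config value rl_vae_update_circle = 20: the rl phase now fires on every iteration, while the vae phase fires only on the last iteration of each cycle.

rl_vae_update_circle = 20  # assumed value, matching the config change below

for it in range(40):
    # rl condition: iter % circle in range(0, circle) covers every residue, so it is always True
    rl_phase = it % rl_vae_update_circle in range(0, rl_vae_update_circle)
    # vae condition: iter % circle in range(circle - 1, circle) is True only when the residue is circle - 1
    vae_phase = it % rl_vae_update_circle in range(rl_vae_update_circle - 1, rl_vae_update_circle)
    print(it, rl_phase, vae_phase)  # vae_phase is True only at it = 19 and it = 39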
......@@ -219,7 +219,10 @@ class TD3VAEPolicy(DDPGPolicy):
self._forward_learn_cnt = 0 # count iterations
# action_shape, obs_shape, action_latent_dim, hidden_size_list
# self._vae_model = VanillaVAE(self._cfg.original_action_shape, self._cfg.model.obs_shape, self._cfg.model.action_shape, [256, 256, 256])
# self._vae_model = VanillaVAE(2, 8, 6, [256, 256, 256])
self._vae_model = VanillaVAE(2, 8, 6, [256, 256, 256])
# self._vae_model = VanillaVAE(2, 8, 2, [256, 256, 256])
self._optimizer_vae = Adam(
self._vae_model.parameters(),
lr=self._cfg.learn.learning_rate_vae,
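A hedged reading of the hard-coded constructor arguments above, using the positional order stated in the preceding comment (action_shape, obs_shape, action_latent_dim, hidden_size_list); the keyword names below are taken from that comment, not verified against the VanillaVAE signature.

vae_args = dict(
    action_shape=2,                    # real action dim of LunarLanderContinuous-v2
    obs_shape=8,                       # observation dim of LunarLanderContinuous-v2
    action_latent_dim=6,               # latent action dim, the 'lad6' tag in the exp_name below
    hidden_size_list=[256, 256, 256],  # hidden layer sizes
)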
......@@ -254,7 +257,8 @@ class TD3VAEPolicy(DDPGPolicy):
{'action': data['action'],
'obs': data['obs']}) # [self.decode(z)[0], self.decode(z)[1], input, mu, log_var, z]
data['latent_action'] = result[5].detach() # TODO(pu): update latent_action
data['latent_action'] = result[5].detach()  # TODO(pu): update latent_action z
# data['latent_action'] = result[3].detach() # TODO(pu): update latent_action mu
result.pop(-1) # remove z
result[2] = data['action']
true_residual = data['next_obs'] - data['obs']
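An index map for result, as documented by the inline comment above ([self.decode(z)[0], self.decode(z)[1], input, mu, log_var, z]); a reading aid only, not code from the repo.

RECON_A, RECON_B, INPUT, MU, LOG_VAR, Z = range(6)
# result[Z]  (result[5]) -> sampled latent z, detached and stored as data['latent_action']
# result[MU] (result[3]) -> posterior mean, the commented-out alternative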
......@@ -307,7 +311,7 @@ class TD3VAEPolicy(DDPGPolicy):
)
# if (self._forward_learn_cnt + 1) % self._cfg.learn.vae_rl_update_circle in range(10,15):
# if (self._forward_learn_cnt + 1) % self._cfg.learn.vae_update_freq == 0:
if data['vae_phase'][0] is True:
if data['vae_phase'][0].item() is True:
# for i in range(self._cfg.learn.vae_train_times_per_update):
if self._cuda:
data = to_device(data, self._device)
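A minimal illustration (assuming data['vae_phase'] is a torch bool tensor, which the indexing above suggests) of why the .item() call added in this hunk matters: indexing a bool tensor yields a 0-dim tensor, which is never the Python singleton True, so the old identity check always evaluated to False.

import torch

vae_phase = torch.tensor([True, True])
print(vae_phase[0] is True)         # False: a 0-dim tensor, not the Python bool
print(vae_phase[0].item() is True)  # True: .item() converts it to a Python bool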
......@@ -319,7 +323,8 @@ class TD3VAEPolicy(DDPGPolicy):
{'action': data['action'],
'obs': data['obs']}) # [self.decode(z)[0], self.decode(z)[1], input, mu, log_var, z]
data['latent_action'] = result[5].detach() # TODO(pu): update latent_action
data['latent_action'] = result[5].detach() # TODO(pu): update latent_action z
# data['latent_action'] = result[3].detach() # TODO(pu): update latent_action mu
result.pop(-1) # remove z
result[2] = data['action']
true_residual = data['next_obs'] - data['obs']
......@@ -336,6 +341,7 @@ class TD3VAEPolicy(DDPGPolicy):
loss_dict['vae_loss'] = vae_loss['loss']
loss_dict['reconstruction_loss'] = vae_loss['reconstruction_loss']
loss_dict['kld_loss'] = vae_loss['kld_loss']
loss_dict['predict_loss'] = vae_loss['predict_loss']
# vae update
self._optimizer_vae.zero_grad()
......@@ -597,7 +603,7 @@ class TD3VAEPolicy(DDPGPolicy):
"""
ret = [
'cur_lr_actor', 'cur_lr_critic', 'critic_loss', 'actor_loss', 'total_loss', 'q_value', 'q_value_twin',
'action', 'td_error', 'vae_loss', 'reconstruction_loss', 'kld_loss'
'action', 'td_error', 'vae_loss', 'reconstruction_loss', 'kld_loss', 'predict_loss'
]
if self._twin_critic:
ret += ['critic_twin_loss']
......
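The two hunks above go together: predict_loss is written into loss_dict, and the same key is appended to the monitor variable list so it shows up in the learner logs. A small sanity-check sketch (the pairing convention is my reading of the diff, not verified against the DI-engine learner):

logged_keys = ['vae_loss', 'reconstruction_loss', 'kld_loss', 'predict_loss']
monitor_vars = [
    'cur_lr_actor', 'cur_lr_critic', 'critic_loss', 'actor_loss', 'total_loss', 'q_value',
    'q_value_twin', 'action', 'td_error', 'vae_loss', 'reconstruction_loss', 'kld_loss',
    'predict_loss',  # newly added, matching the new loss_dict entry above
]
assert all(k in monitor_vars for k in logged_keys)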
......@@ -2,10 +2,18 @@ from easydict import EasyDict
from ding.entry import serial_pipeline_td3_vae
lunarlander_td3vae_config = dict(
# exp_name='lunarlander_cont_td3_vae_lad6_wu0_rvuc11_upcv10',
exp_name='lunarlander_cont_td3_vae_lad6_wu0_rvuc11_upcv20',
# exp_name='lunarlander_cont_td3_vae_lad6_wu0_rvuc21_upcv40',
# exp_name='lunarlander_cont_td3_vae_lad6_wu0_rvuc3_upcv4',
# exp_name='lunarlander_cont_td3_vae_lad6_wu0_rvuc2_upcv4', # worse
# exp_name='lunarlander_cont_td3_vae_lad6_wu0_rvuc10_upcv20', # worse
# exp_name='lunarlander_cont_td3_vae_lad6_wu0_rvuc20_upcv40', # TODO(pu)
# exp_name='lunarlander_cont_td3_vae_lad6_wu0_rvuc30_upcv60', # worse
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_mu_rvuc20_upcv40', # TODO(pu)
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_mu_rvuc20_upcv150',  # TODO(pu) eval reward_mean stuck at -132.63
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_mu_rvuc20_upcr12_upcv40', # TODO(pu)
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_z_rvuc20_upcr12_upcv40', # TODO(pu)
exp_name='lunarlander_cont_td3_vae_lad6_wu1000_z_rvuc20_upcr12_upcv40_noisefalse', # TODO(pu)
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_z_rvuc20_upcr12_upcv150', # TODO(pu)
env=dict(
env_id='LunarLanderContinuous-v2',
......@@ -21,8 +29,8 @@ lunarlander_td3vae_config = dict(
policy=dict(
cuda=False,
priority=False,
# random_collect_size=1280,
random_collect_size=0,
random_collect_size=12800,
# random_collect_size=0,
original_action_shape=2,
model=dict(
obs_shape=8,
......@@ -31,12 +39,15 @@ lunarlander_td3vae_config = dict(
actor_head_type='regression',
),
learn=dict(
warm_up_update=0,
# warm_up_update=100,
rl_vae_update_circle=11, # train rl 10 iter, vae 1 iter
# warm_up_update=0,
warm_up_update=1000,
# vae_train_times_per_update=1, # TODO(pu)
update_per_collect_rl=2,
update_per_collect_vae=20,
rl_vae_update_circle=20, # train rl 20 iter, vae 1 iter
# update_per_collect_rl=2,
update_per_collect_rl=12,
update_per_collect_vae=40,
# update_per_collect_vae=150,
batch_size=128,
learning_rate_actor=3e-4,
......@@ -44,7 +55,8 @@ lunarlander_td3vae_config = dict(
learning_rate_vae=3e-4,
ignore_done=False, # TODO(pu)
actor_update_freq=2,
noise=True,
# noise=True,
noise=False, # TODO(pu)
noise_sigma=0.1,
noise_range=dict(
min=-0.5,
......
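For convenience, the key knobs this commit changes, collected in a hypothetical standalone dict that mirrors the diff above (old values in comments):

changed_learn_cfg = dict(
    warm_up_update=1000,        # was 0 (the 'wu1000' tag in the exp_name)
    rl_vae_update_circle=20,    # was 11: rl every iteration, vae once per 20-iteration cycle
    update_per_collect_rl=12,   # was 2
    update_per_collect_vae=40,  # was 20
    noise=False,                # was True: noise removed, per the commit message
)
random_collect_size = 12800     # was 0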