Commit 4fca50f4 authored by puyuan1996, committed by niuyazhe

polish(pu): vae and rl update alternately

Parent 4d240686
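This commit makes the learner alternate between VAE training and RL (TD3) training within each update cycle. A minimal sketch of the schedule implied by the code below, assuming vae_update_freq == rl_update_freq == 10 as in the LunarLander config (the helper update_phase is hypothetical, not part of the diff):

# Hypothetical helper illustrating the alternation implemented in _forward_learn.
# Assumes vae_update_freq == rl_update_freq == 10, matching the config in this commit.
def update_phase(forward_learn_cnt: int, freq: int = 10) -> str:
    """Return which component is trained at a given learner iteration."""
    slot = (forward_learn_cnt + 1) % freq
    # slots 0-4 train the VAE; slots 5-9 run the TD3 critic/actor update
    return 'vae' if slot in (0, 1, 2, 3, 4) else 'rl'

print([update_phase(cnt) for cnt in range(10)])
# ['vae', 'vae', 'vae', 'vae', 'rl', 'rl', 'rl', 'rl', 'rl', 'vae']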
......@@ -46,9 +46,9 @@ class VanillaVAE(BaseVAE):
**kwargs) -> None:
super(VanillaVAE, self).__init__()
self.latent_dim = latent_dim
self.action_dim = in_channels_1
self.obs_dim = in_channels_2
self.hidden_dims = hidden_dims
modules = []
......
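For reference, a hedged instantiation of the renamed constructor (argument values taken from the LunarLander config later in this commit; this call is illustrative, not part of the diff):

# action_dim (in_channels_1) = 2, obs_dim (in_channels_2) = 8, latent_dim = 2
vae = VanillaVAE(2, 8, 2, [256, 256, 256])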
......@@ -217,14 +217,14 @@ class TD3VAEPolicy(DDPGPolicy):
self._target_model.reset()
self._forward_learn_cnt = 0 # count iterations
# previous: self._vae_model = VanillaVAE(2, 8, 64, [256, 256, 256])  # action_shape, latent_dim, hidden_size_list
# args: action_shape, obs_shape, action_latent_dim, hidden_size_list
self._vae_model = VanillaVAE(self._cfg.original_action_shape, self._cfg.model.obs_shape, self._cfg.model.action_shape, [256, 256, 256])
# self._vae_model = VanillaVAE(2, 8, 2, [256, 256, 256])
self._optimizer_vae = Adam(
self._vae_model.parameters(),
lr=self._cfg.learn.learning_rate_vae,
)
# self.vae_model = VanillaVAE(self._cfg.original_action_shape, self._cfg.obs_shape, self._cfg.model.action_shape, [256, 256])
# args: action_shape, obs_dim, latent_dim, hidden_size_list
def _forward_learn(self, data: dict) -> Dict[str, Any]:
r"""
......@@ -235,6 +235,7 @@ class TD3VAEPolicy(DDPGPolicy):
Returns:
- info_dict (:obj:`Dict[str, Any]`): Including at least actor and critic lr, different losses.
"""
# warmup phase
if 'warm_up' in data[0].keys() and data[0]['warm_up'] is True:
loss_dict = {}
data = default_preprocess_learn(
......@@ -295,7 +296,9 @@ class TD3VAEPolicy(DDPGPolicy):
**q_value_dict,
}
else:
self._forward_learn_cnt += 1  # count iterations here; used below to alternate VAE and RL phases
loss_dict = {}
q_value_dict = {}
data = default_preprocess_learn(
data,
use_priority=self._cfg.priority,
......@@ -303,124 +306,139 @@ class TD3VAEPolicy(DDPGPolicy):
ignore_done=self._cfg.learn.ignore_done,
use_nstep=False
)
# alternate phases: iterations whose (cnt + 1) % vae_update_freq falls in [0, 4] train the VAE
if (self._forward_learn_cnt + 1) % self._cfg.learn.vae_update_freq in [0, 1, 2, 3, 4]:
for i in range(self._cfg.learn.train_vae_times_per_update):
if self._cuda:
data = to_device(data, self._device)
# ====================
# train vae
# ====================
result = self._vae_model(
{'action': data['action'],
'obs': data['obs']}) # [self.decode(z)[0], self.decode(z)[1], input, mu, log_var, z]
data['latent_action'] = result[5].detach() # TODO(pu): update latent_action
result.pop(-1) # remove z
result[2] = data['action']
true_residual = data['next_obs'] - data['obs']
result = result + [true_residual]
vae_loss = self._vae_model.loss_function(*result, kld_weight=0.5, predict_weight=10)  # TODO(pu): weight
# recons = args[0]
# prediction_residual = args[1]
# input_action = args[2]
# mu = args[3]
# log_var = args[4]
# true_residual = args[5]
# print(vae_loss)
loss_dict['vae_loss'] = vae_loss['loss']
loss_dict['reconstruction_loss'] = vae_loss['reconstruction_loss']
loss_dict['kld_loss'] = vae_loss['kld_loss']
# vae update
self._optimizer_vae.zero_grad()
vae_loss['loss'].backward()
self._optimizer_vae.step()
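# Note on the loss composition (assumed semantics, inferred from the argument
# list above): loss = reconstruction_loss(recons, input_action)
#   + kld_weight * kld_loss(mu, log_var)
#   + predict_weight * prediction_loss(prediction_residual, true_residual),
# where true_residual = next_obs - obs is the one-step observation change.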
return {
'cur_lr_actor': self._optimizer_actor.defaults['lr'],
'cur_lr_critic': self._optimizer_critic.defaults['lr'],
# 'q_value': np.array(q_value).mean(),
'action': torch.Tensor([0]).item(),
'priority': torch.Tensor([0]).item(),
'td_error': torch.Tensor([0]).item(),
**loss_dict,
**q_value_dict,
}
# iterations whose (cnt + 1) % rl_update_freq falls in [5, 9] run the TD3 update
if (self._forward_learn_cnt + 1) % self._cfg.learn.rl_update_freq in [5, 6, 7, 8, 9]:
# move data to device here as well, since the VAE branch above returns early
if self._cuda:
data = to_device(data, self._device)
# ====================
# critic learn forward
# ====================
self._learn_model.train()
self._target_model.train()
next_obs = data['next_obs']
reward = data['reward']
if self._reward_batch_norm:
reward = (reward - reward.mean()) / (reward.std() + 1e-8)
# current q value
q_value = self._learn_model.forward({'obs': data['obs'], 'action': data['latent_action']}, mode='compute_critic')['q_value']
q_value_dict = {}
if self._twin_critic:
q_value_dict['q_value'] = q_value[0].mean()
q_value_dict['q_value_twin'] = q_value[1].mean()
else:
q_value_dict['q_value'] = q_value.mean()
# target q value.
with torch.no_grad():
next_actor_data = self._target_model.forward(next_obs, mode='compute_actor') # latent action
next_actor_data['obs'] = next_obs
target_q_value = self._target_model.forward(next_actor_data, mode='compute_critic')['q_value']
if self._twin_critic:
# TD3: two critic networks
target_q_value = torch.min(target_q_value[0], target_q_value[1]) # find min one as target q value
# critic network1
td_data = v_1step_td_data(q_value[0], target_q_value, reward, data['done'], data['weight'])
critic_loss, td_error_per_sample1 = v_1step_td_error(td_data, self._gamma)
loss_dict['critic_loss'] = critic_loss
# critic network2(twin network)
td_data_twin = v_1step_td_data(q_value[1], target_q_value, reward, data['done'], data['weight'])
critic_twin_loss, td_error_per_sample2 = v_1step_td_error(td_data_twin, self._gamma)
loss_dict['critic_twin_loss'] = critic_twin_loss
td_error_per_sample = (td_error_per_sample1 + td_error_per_sample2) / 2
else:
# DDPG: single critic network
td_data = v_1step_td_data(q_value, target_q_value, reward, data['done'], data['weight'])
critic_loss, td_error_per_sample = v_1step_td_error(td_data, self._gamma)
loss_dict['critic_loss'] = critic_loss
# ================
# critic update
# ================
self._optimizer_critic.zero_grad()
for k in loss_dict:
if 'critic' in k:
loss_dict[k].backward()
self._optimizer_critic.step()
# ===============================
# actor learn forward and update
# ===============================
# actor updates every ``self._actor_update_freq`` iters
if (self._forward_learn_cnt + 1) % self._actor_update_freq == 0:
actor_data = self._learn_model.forward(data['obs'], mode='compute_actor') # latent action
actor_data['obs'] = data['obs']
if self._twin_critic:
actor_loss = -self._learn_model.forward(actor_data, mode='compute_critic')['q_value'][0].mean()
else:
actor_loss = -self._learn_model.forward(actor_data, mode='compute_critic')['q_value'].mean()
loss_dict['actor_loss'] = actor_loss
# actor update
self._optimizer_actor.zero_grad()
actor_loss.backward()
self._optimizer_actor.step()
# =============
# after update
# =============
loss_dict['total_loss'] = sum(loss_dict.values())
# self._forward_learn_cnt += 1  # counting moved to the top of the non-warm-up branch
self._target_model.update(self._learn_model.state_dict())
if self._cfg.action_space == 'hybrid':
action_log_value = -1. # TODO(nyz) better way to viz hybrid action
else:
action_log_value = data['action'].mean()
return {
'cur_lr_actor': self._optimizer_actor.defaults['lr'],
'cur_lr_critic': self._optimizer_critic.defaults['lr'],
# 'q_value': np.array(q_value).mean(),
'action': action_log_value,
'priority': td_error_per_sample.abs().tolist(),
'td_error': td_error_per_sample.abs().mean(),
**loss_dict,
**q_value_dict,
}
def _state_dict_learn(self) -> Dict[str, Any]:
return {
......
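To summarize the latent-action flow in _forward_learn above: the VAE encodes (action, obs) pairs into a latent z, and the critic consumes z in place of the raw action. A hedged sketch, with vae and learn_model standing in for self._vae_model and self._learn_model (tensor shapes follow the LunarLander config below):

import torch

obs = torch.randn(128, 8)      # batch of observations (obs_shape=8)
action = torch.randn(128, 2)   # raw continuous actions (original_action_shape=2)
# VanillaVAE.forward returns [recons, prediction_residual, input, mu, log_var, z]
result = vae({'action': action, 'obs': obs})
latent_action = result[5].detach()  # z, detached before feeding the critic
q_value = learn_model.forward(
    {'obs': obs, 'action': latent_action}, mode='compute_critic'
)['q_value']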
......@@ -2,7 +2,8 @@ from easydict import EasyDict
from ding.entry import serial_pipeline_td3_vae
lunarlander_td3vae_config = dict(
# exp_name='lunarlander_cont_td3_vae',
# exp_name='lunarlander_cont_td3_vae_wu0_vae5rl5_tvtpc1',
exp_name='lunarlander_cont_td3_vae_wu0_vae5rl5_tvtpc5',
env=dict(
env_id='LunarLanderContinuous-v2',
# collector_env_num=8,
......@@ -17,18 +18,24 @@ lunarlander_td3vae_config = dict(
policy=dict(
cuda=False,
priority=False,
# random_collect_size=12800,
# random_collect_size=1280,
random_collect_size=0,
original_action_shape=2,
model=dict(
obs_shape=8,
# action_shape=64,  # latent_action_shape (previous value)
action_shape=2,  # action_latent_shape (was 64)
twin_critic=True,
actor_head_type='regression',
),
learn=dict(
# warm_up_update=1,
# warm_up_update=1000,
# warm_up_update=100,
warm_up_update=0,
# update_per_collect=2,
vae_update_freq=10, # TODO(pu)
rl_update_freq=10,
train_vae_times_per_update=5, # TODO(pu)
update_per_collect=10, # train vae 5 times, rl 5 times
batch_size=128,
learning_rate_actor=0.001,
learning_rate_critic=0.001,
......@@ -44,7 +51,9 @@ lunarlander_td3vae_config = dict(
),
collect=dict(
# n_sample=48,
# each_iter_n_sample=48,
# each_iter_n_sample=128,
each_iter_n_sample=256,
noise_sigma=0.1,
collector=dict(collect_print_freq=1000, ),
),
......
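Putting the schedule parameters together: with update_per_collect=10, each collect cycle runs 5 VAE iterations (each looping train_vae_times_per_update=5 gradient steps) and 5 RL iterations. A hedged accounting sketch:

update_per_collect = 10          # learner iterations per collect
train_vae_times_per_update = 5   # inner VAE gradient steps per VAE iteration
vae_iterations = 5               # slots with (cnt + 1) % vae_update_freq in {0..4}
rl_iterations = 5                # slots with (cnt + 1) % rl_update_freq in {5..9}
print(vae_iterations * train_vae_times_per_update, rl_iterations)  # 25 5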