...
 
Commits (7)
    polish(pu):polish td3_vae config
        puyuan1996 <2402552459@qq.com>, 2021-12-22T10:38:06+08:00
        https://gitcode.net/opendilab/DI-engine/-/commit/9e6de54883c90d77d84f86fc07eebb77200ee54b
    polish(pu): polish vae structure, use add not concat between the embeddings of obs and action, use tanh after sample z and after the reconstruction_action head
        puyuan1996 <2402552459@qq.com>, 2021-12-23T20:36:37+08:00
        https://gitcode.net/opendilab/DI-engine/-/commit/9dd84dd319720a6fdce53d4f43429aca32cae5ab
    polish(pu):polish kl weight and prediction weight
        puyuan1996 <2402552459@qq.com>, 2021-12-24T18:11:04+08:00
        https://gitcode.net/opendilab/DI-engine/-/commit/3f7e2130e1e0ff38cb213f92f31cb36a9aa01098
    polish(pu):polish td3_vae using the best setting
        puyuan1996 <2402552459@qq.com>, 2021-12-26T15:25:57+08:00
        https://gitcode.net/opendilab/DI-engine/-/commit/c7d85c97bd5daba2a4d232a97983017be8857767
    style(pu): yapf format
        puyuan1996 <2402552459@qq.com>, 2021-12-26T15:59:17+08:00
        https://gitcode.net/opendilab/DI-engine/-/commit/96ea36240859d9104d69808be205b9c11b2802c4
    polish(pu):polish config
        puyuan1996 <2402552459@qq.com>, 2021-12-26T16:01:43+08:00
        https://gitcode.net/opendilab/DI-engine/-/commit/6ca776427043da8f4e8ef59c6bc98c4fe572a0fd
    fix(pu): fix bug when collector_env_num>1, the self._traj_buffer is not empty and will leave over the data in random collect phase
        puyuan1996 <2402552459@qq.com>, 2021-12-26T19:09:44+08:00
        https://gitcode.net/opendilab/DI-engine/-/commit/70328aabcd696d3c7691fb8beb3394198a062420
......@@ -13,6 +13,7 @@ from ding.policy import create_policy, PolicyFactory
from ding.utils import set_pkg_seed
import copy
def serial_pipeline_td3_vae(
input_cfg: Union[str, Tuple[dict, dict]],
seed: int = 0,
......@@ -92,13 +93,13 @@ def serial_pipeline_td3_vae(
new_data = collector.collect(n_sample=cfg.policy.random_collect_size, policy_kwargs=collect_kwargs)
for item in new_data:
item['warm_up'] = True
replay_buffer_recent.push(new_data, cur_collector_envstep=0)
replay_buffer.push(new_data, cur_collector_envstep=0)
collector.reset_policy(policy.collect_mode)
### warm_up ###
# warm_up
# Learn policy from collected data
for i in range(cfg.policy.learn.warm_up_update):
# Learner will train ``update_per_collect`` times in one iteration.
train_data = replay_buffer_recent.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
if train_data is None:
# It is possible that replay buffer's data count is too few to train ``update_per_collect`` times
logging.warning(
......@@ -109,8 +110,13 @@ def serial_pipeline_td3_vae(
learner.train(train_data, collector.envstep)
if learner.policy.get_attribute('priority'):
replay_buffer_recent.update(learner.priority_info)
replay_buffer_recent.clear() # TODO(pu)
replay_buffer.update(learner.priority_info)
replay_buffer.clear() # TODO(pu): NOTE
# NOTE: when collector_env_num > 1, self._traj_buffer[env_id] may still hold data after the random collect phase,
# because a per-env trajectory buffer is only cleared once "timestep.done or len(self._traj_buffer[env_id]) == self._traj_len" holds.
# For our algorithm, that leftover data (collected with latent_action=False) cannot be used in the rl_vae phase, so the collector is reset here.
collector.reset(policy.collect_mode)
for iter in range(max_iterations):
collect_kwargs = commander.step()
......@@ -120,7 +126,7 @@ def serial_pipeline_td3_vae(
if stop:
break
# Collect data by default config n_sample/n_episode
if hasattr(cfg.policy.collect, "each_iter_n_sample"): # TODO(pu)
if hasattr(cfg.policy.collect, "each_iter_n_sample"):
new_data = collector.collect(
n_sample=cfg.policy.collect.each_iter_n_sample,
train_iter=learner.train_iter,
......@@ -134,11 +140,9 @@ def serial_pipeline_td3_vae(
replay_buffer_recent.push(copy.deepcopy(new_data), cur_collector_envstep=collector.envstep)
# rl phase
# if iter % cfg.policy.learn.rl_vae_update_circle in range(0,20):
if iter % cfg.policy.learn.rl_vae_update_circle in range(0, cfg.policy.learn.rl_vae_update_circle):
# Learn policy from collected data
for i in range(cfg.policy.learn.update_per_collect_rl):
# print('update_per_collect_rl')
# Learner will train ``update_per_collect`` times in one iteration.
train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
if train_data is not None:
......@@ -157,15 +161,17 @@ def serial_pipeline_td3_vae(
replay_buffer.update(learner.priority_info)
# vae phase
# if iter % cfg.policy.learn.rl_vae_update_circle in range(19, 20):
if iter % cfg.policy.learn.rl_vae_update_circle in range(cfg.policy.learn.rl_vae_update_circle - 1, cfg.policy.learn.rl_vae_update_circle):
if iter % cfg.policy.learn.rl_vae_update_circle in range(cfg.policy.learn.rl_vae_update_circle - 1,
cfg.policy.learn.rl_vae_update_circle):
for i in range(cfg.policy.learn.update_per_collect_vae):
# print('update_per_collect_vae')
# Learner will train ``update_per_collect`` times in one iteration.
train_data_history = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
# train_data_recent = replay_buffer_recent.sample(learner.policy.get_attribute('batch_size'), learner.train_iter) # TODO(pu)
# train_data = train_data_history + train_data_recent # TODO(pu)
train_data = train_data_history # TODO(pu)
train_data_history = replay_buffer.sample(
int(learner.policy.get_attribute('batch_size') / 2), learner.train_iter
)
train_data_recent = replay_buffer_recent.sample(
int(learner.policy.get_attribute('batch_size') / 2), learner.train_iter
)
train_data = train_data_history + train_data_recent # TODO(pu):
if train_data is not None:
for item in train_data:
......@@ -181,7 +187,7 @@ def serial_pipeline_td3_vae(
learner.train(train_data, collector.envstep)
# if learner.policy.get_attribute('priority'):
# replay_buffer.update(learner.priority_info)
# replay_buffer_recent.clear() # TODO(pu)
replay_buffer_recent.clear() # TODO(pu)
# Learner's after_run hook.
learner.call_hook('after_run')
......
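A quick, illustrative sketch (not part of the diff) of the rl/vae alternation used above: the rl-phase condition `iter % rl_vae_update_circle in range(0, rl_vae_update_circle)` holds on every iteration, while the vae-phase condition only holds on the last iteration of each cycle. With the config value rl_vae_update_circle=3 used further below:

# Illustrative scheduling sketch for serial_pipeline_td3_vae (assumes rl_vae_update_circle=3).
rl_vae_update_circle = 3  # config comment: "train rl 3 iter, vae 1 iter"
for iteration in range(6):
    rl_phase = iteration % rl_vae_update_circle in range(0, rl_vae_update_circle)  # always True
    vae_phase = iteration % rl_vae_update_circle in range(rl_vae_update_circle - 1, rl_vae_update_circle)
    print(iteration, rl_phase, vae_phase)
# output:
# 0 True False
# 1 True False
# 2 True True
# 3 True False
# 4 True False
# 5 True True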
......@@ -38,12 +38,7 @@ class BaseVAE(nn.Module):
class VanillaVAE(BaseVAE):
def __init__(self,
action_dim: int,
obs_dim: int,
latent_dim: int,
hidden_dims: List = None,
**kwargs) -> None:
def __init__(self, action_dim: int, obs_dim: int, latent_dim: int, hidden_dims: List = None, **kwargs) -> None:
super(VanillaVAE, self).__init__()
self.action_dim = action_dim
......@@ -53,83 +48,30 @@ class VanillaVAE(BaseVAE):
modules = []
if hidden_dims is None:
hidden_dims = [32, 64, 128, 256, 512]
hidden_dims = [256]
# Build Encoder
# action
self.action_head = nn.Sequential(
nn.Linear(self.action_dim, hidden_dims[0]),
nn.ReLU())
self.action_head = nn.Sequential(nn.Linear(self.action_dim, hidden_dims[0]), nn.ReLU())
# obs
self.obs_head = nn.Sequential(
nn.Linear(self.obs_dim, hidden_dims[0]),
nn.ReLU())
in_dim = hidden_dims[0] + hidden_dims[0]
for h_dim in hidden_dims[1:-1]:
# modules.append(
# nn.Sequential(
# nn.Conv2d(in_channels, out_channels=h_dim,
# kernel_size=3, stride=2, padding=1),
# nn.BatchNorm2d(h_dim),
# nn.LeakyReLU())
# )
# in_channels = h_dim
modules.append(
nn.Sequential(
nn.Linear(in_dim, h_dim),
nn.ReLU())
)
in_dim = h_dim
self.encoder = nn.Sequential(*modules)
self.fc_mu = nn.Linear(hidden_dims[-1], latent_dim)
self.fc_var = nn.Linear(hidden_dims[-1], latent_dim)
self.obs_head = nn.Sequential(nn.Linear(self.obs_dim, hidden_dims[0]), nn.ReLU())
self.encoder = nn.Sequential(nn.Linear(hidden_dims[0], hidden_dims[0]), nn.ReLU())
self.mu_head = nn.Linear(hidden_dims[0], latent_dim)
self.var_head = nn.Linear(hidden_dims[0], latent_dim)
# Build Decoder
modules = []
hidden_dims.reverse()
# for i in range(len(hidden_dims) - 1):
# modules.append(
# nn.Sequential(
# nn.ConvTranspose2d(hidden_dims[i],
# hidden_dims[i + 1],
# kernel_size=3,
# stride=2,
# padding=1,
# output_padding=1),
# nn.BatchNorm2d(hidden_dims[i + 1]),
# nn.LeakyReLU())
# )
in_dim = self.latent_dim + hidden_dims[0]
for h_dim in hidden_dims[1:-1]:
modules.append(
nn.Sequential(
nn.Linear(in_dim, h_dim),
nn.ReLU())
)
in_dim = h_dim
self.decoder = nn.Sequential(*modules)
# self.reconstruction_layer = nn.Linear(hidden_dims[-1], self.action_dim) # TODO(pu)
self.reconstruction_layer = nn.Sequential(nn.Linear(hidden_dims[-1], self.action_dim), nn.Tanh())
self.condition_obs = nn.Sequential(nn.Linear(self.obs_dim, hidden_dims[0]), nn.ReLU())
self.decoder_action = nn.Sequential(nn.Linear(latent_dim, hidden_dims[0]), nn.ReLU())
self.decoder_common = nn.Sequential(nn.Linear(hidden_dims[0], hidden_dims[0]), nn.ReLU())
# TODO(pu): tanh
self.reconstruction_layer = nn.Sequential(nn.Linear(hidden_dims[0], self.action_dim), nn.Tanh())
# self.reconstruction_layer = nn.Linear(hidden_dims[0], self.action_dim)
# residual prediction
self.prediction_layer_1 = nn.Sequential(nn.Linear(hidden_dims[-1], hidden_dims[-1]), nn.ReLU())
self.prediction_layer_2 = nn.Linear(hidden_dims[-1], self.obs_dim)
# self.final_layer = nn.Sequential(
# nn.ConvTranspose2d(hidden_dims[-1],
# hidden_dims[-1],
# kernel_size=3,
# stride=2,
# padding=1,
# output_padding=1),
# nn.BatchNorm2d(hidden_dims[-1]),
# nn.LeakyReLU(),
# nn.Conv2d(hidden_dims[-1], out_channels=3,
# kernel_size=3, padding=1),
# nn.Tanh())
self.prediction_head_1 = nn.Sequential(nn.Linear(hidden_dims[0], hidden_dims[0]), nn.ReLU())
self.prediction_head_2 = nn.Linear(hidden_dims[0], self.obs_dim)
self.obs_encoding = None
......@@ -142,15 +84,20 @@ class VanillaVAE(BaseVAE):
"""
action_encoding = self.action_head(input['action'])
obs_encoding = self.obs_head(input['obs'])
# obs_encoding = self.condition_obs(input['obs']) # TODO(pu): using a different network
self.obs_encoding = obs_encoding
input = torch.cat([obs_encoding, action_encoding], dim=-1)
# input = torch.cat([obs_encoding, action_encoding], dim=-1)
# input = obs_encoding + action_encoding # TODO(pu): what about add, cat?
input = obs_encoding * action_encoding
result = self.encoder(input)
result = torch.flatten(result, start_dim=1)
# Split the result into mu and var components
# of the latent Gaussian distribution
mu = self.fc_mu(result)
log_var = self.fc_var(result)
mu = self.mu_head(result)
log_var = self.var_head(result)
return [mu, log_var]
......@@ -161,15 +108,14 @@ class VanillaVAE(BaseVAE):
:param z: (Tensor) [B x D]
:return: (List[Tensor]) [reconstruction_action (B x action_dim), prediction_residual (B x obs_dim)]
"""
# for compatibility with collect and eval in the first iteration
if self.obs_encoding is None or z.shape[:-1] != self.obs_encoding.shape[:-1]:
self.obs_encoding = torch.zeros(list(z.shape[:-1]) + [self.hidden_dims[0]])  # hidden_dims[0] matches the obs_head output width
action_decoding = self.decoder_action(torch.tanh(z))  # NOTE: apply tanh because the z sampled here is not bounded
# action_decoding = self.decoder_action(z)  # alternative without tanh
action_obs_decoding = action_decoding * self.obs_encoding
action_obs_decoding_tmp = self.decoder_common(action_obs_decoding)
input = torch.cat([self.obs_encoding, torch.tanh(z)], dim=-1) # TODO(pu): here z is not bounded
decoding_tmp = self.decoder(input)
reconstruction_action = self.reconstruction_layer(decoding_tmp)
predition_residual_tmp = self.prediction_layer_1(decoding_tmp)
predition_residual = self.prediction_layer_2(predition_residual_tmp)
reconstruction_action = self.reconstruction_layer(action_obs_decoding_tmp)
predition_residual_tmp = self.prediction_head_1(action_obs_decoding_tmp)
predition_residual = self.prediction_head_2(predition_residual_tmp)
return [reconstruction_action, predition_residual]
......@@ -181,11 +127,13 @@ class VanillaVAE(BaseVAE):
:return: (List[Tensor]) [reconstruction_action (B x action_dim), prediction_residual (B x obs_dim)]
"""
self.obs_encoding = self.obs_head(obs)
input = torch.cat([self.obs_encoding, z], dim=-1) # TODO(pu): here z is already bounded
decoding_tmp = self.decoder(input)
reconstruction_action = self.reconstruction_layer(decoding_tmp)
predition_residual_tmp = self.prediction_layer_1(decoding_tmp)
predition_residual = self.prediction_layer_2(predition_residual_tmp)
# TODO(pu): here z is already bounded: it is produced by the TD3 policy, which has already applied tanh
action_decoding = self.decoder_action(z)
action_obs_decoding = action_decoding * self.obs_encoding
action_obs_decoding_tmp = self.decoder_common(action_obs_decoding)
reconstruction_action = self.reconstruction_layer(action_obs_decoding_tmp)
predition_residual_tmp = self.prediction_head_1(action_obs_decoding_tmp)
predition_residual = self.prediction_head_2(predition_residual_tmp)
return [reconstruction_action, predition_residual]
......@@ -204,11 +152,16 @@ class VanillaVAE(BaseVAE):
def forward(self, input: Tensor, **kwargs) -> dict:
mu, log_var = self.encode(input)
z = self.reparameterize(mu, log_var)
return {'recons_action': self.decode(z)[0], 'prediction_residual': self.decode(z)[1], 'input': input, 'mu': mu, 'log_var': log_var, 'z': z} # recons_action, prediction_residual
def loss_function(self,
args,
**kwargs) -> dict:
return {
'recons_action': self.decode(z)[0],
'prediction_residual': self.decode(z)[1],
'input': input,
'mu': mu,
'log_var': log_var,
'z': z
}
def loss_function(self, args, **kwargs) -> dict:
"""
Computes the VAE loss function.
KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
......@@ -227,16 +180,13 @@ class VanillaVAE(BaseVAE):
predict_weight = kwargs['predict_weight']
recons_loss = F.mse_loss(recons_action, original_action)
kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
predict_loss = F.mse_loss(prediction_residual, true_residual)
loss = recons_loss + kld_weight * kld_loss + predict_weight * predict_loss
return {'loss': loss, 'reconstruction_loss': recons_loss, 'kld_loss': kld_loss, 'predict_loss': predict_loss}
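For reference, the kld_loss line above is the batched closed form of the docstring's per-dimension KL: with \log\sigma^2 = log_var, summing over the d latent dimensions and averaging over a batch of size B gives

KL\big(\mathcal{N}(\mu, \mathrm{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, I)\big) = -\frac{1}{2}\sum_{j=1}^{d}\big(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\big), \qquad \text{kld\_loss} = \frac{1}{B}\sum_{b=1}^{B} KL_b,

which is exactly torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0).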
def sample(self,
num_samples: int,
current_device: int, **kwargs) -> Tensor:
def sample(self, num_samples: int, current_device: int, **kwargs) -> Tensor:
"""
Samples from the latent space and returns the corresponding
decoded output (reconstruction_action, prediction_residual).
......@@ -244,11 +194,8 @@ class VanillaVAE(BaseVAE):
:param current_device: (Int) Device to run the model
:return: (Tensor)
"""
z = torch.randn(num_samples,
self.latent_dim)
z = torch.randn(num_samples, self.latent_dim)
z = z.to(current_device)
samples = self.decode(z)
return samples
......
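A minimal usage sketch (illustrative, not part of the diff) of the refactored VanillaVAE above, with the LunarLander shapes used elsewhere in this changeset (obs_dim=8, original action_dim=2, latent_dim=6, hidden_dims=[256]); the import path is the one used by the td3_vae policy below:

import torch
from ding.model.template.vae import VanillaVAE  # same import as in the td3_vae policy

vae = VanillaVAE(action_dim=2, obs_dim=8, latent_dim=6, hidden_dims=[256])
obs = torch.randn(128, 8)
next_obs = torch.randn(128, 8)
action = torch.rand(128, 2) * 2 - 1  # original actions, assumed to lie in [-1, 1]

out = vae({'obs': obs, 'action': action})  # keys: recons_action, prediction_residual, input, mu, log_var, z
assert out['recons_action'].shape == (128, 2)  # bounded by the tanh reconstruction head
assert out['prediction_residual'].shape == (128, 8)  # predicted next_obs - obs

# attach targets the same way the policy's _forward_learn does, then compute the loss
out['original_action'] = action
out['true_residual'] = next_obs - obs
losses = vae.loss_function(out, kld_weight=0.01, predict_weight=0.01)
# losses contains: loss, reconstruction_loss, kld_loss, predict_loss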
......@@ -10,7 +10,6 @@ from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
from .base_policy import Policy
from .common_utils import default_preprocess_learn
from ding.utils import POLICY_REGISTRY
from .ddpg import DDPGPolicy
from ding.model.template.vae import VanillaVAE
from ding.utils import RunningMeanStd
......@@ -220,18 +219,18 @@ class TD3VAEPolicy(DDPGPolicy):
self._target_model.reset()
self._forward_learn_cnt = 0 # count iterations
# action_shape, obs_shape, action_latent_dim, hidden_size_list
# self._vae_model = VanillaVAE(self._cfg.original_action_shape, self._cfg.model.obs_shape, self._cfg.model.action_shape, [256, 256, 256])
# action_shape, obs_shape, latent_action_dim, hidden_size_list
self._vae_model = VanillaVAE(
self._cfg.original_action_shape, self._cfg.model.obs_shape, self._cfg.model.action_shape, [256]
)
# self._vae_model = VanillaVAE(2, 8, 6, [256, 256, 256])
self._vae_model = VanillaVAE(2, 8, 6, [256, 256, 256])
# self._vae_model = VanillaVAE(2, 8, 2, [256, 256, 256])
self._optimizer_vae = Adam(
self._vae_model.parameters(),
lr=self._cfg.learn.learning_rate_vae,
)
self._running_mean_std_predict_loss = RunningMeanStd(epsilon=1e-4)
self.c_percentage_bound_lower = -1*torch.ones([6])
self.c_percentage_bound_lower = -1 * torch.ones([6])
self.c_percentage_bound_upper = torch.ones([6])
def _forward_learn(self, data: dict) -> Dict[str, Any]:
......@@ -259,24 +258,13 @@ class TD3VAEPolicy(DDPGPolicy):
# ====================
# train vae
# ====================
result = self._vae_model(
{'action': data['action'],
'obs': data['obs']}) # [self.decode(z)[0], self.decode(z)[1], input, mu, log_var, z]
# data['latent_action'] = result[5].detach() # TODO(pu): update latent_action mu
# data['latent_action'] = result[3].detach() # TODO(pu): update latent_action mu
result = self._vae_model({'action': data['action'], 'obs': data['obs']})
result['original_action'] = data['action']
result['true_residual'] = data['next_obs'] - data['obs']
vae_loss = self._vae_model.loss_function(result, kld_weight=0.5, predict_weight=10) # TODO(pu):weight
# recons = args[0]
# prediction_residual = args[1]
# input_action = args[2]
# mu = args[3]
# log_var = args[4]
# true_residual = args[5]
# print(vae_loss)
vae_loss = self._vae_model.loss_function(result, kld_weight=0.01, predict_weight=0.01) # TODO(pu): weight
loss_dict['vae_loss'] = vae_loss['loss'].item()
loss_dict['reconstruction_loss'] = vae_loss['reconstruction_loss'].item()
loss_dict['kld_loss'] = vae_loss['kld_loss'].item()
......@@ -316,29 +304,31 @@ class TD3VAEPolicy(DDPGPolicy):
use_nstep=False
)
if data['vae_phase'][0].item() is True:
# for i in range(self._cfg.learn.vae_train_times_per_update):
if self._cuda:
data = to_device(data, self._device)
# ====================
# train vae
# ====================
result = self._vae_model(
{'action': data['action'],
'obs': data['obs']}) # [self.decode(z)[0], self.decode(z)[1], input, mu, log_var, z]
# data['latent_action'] = result['z'].detach() # TODO(pu): update latent_action z
# data['latent_action'] = result['mu'].detach() # TODO(pu): update latent_action mu
result = self._vae_model({'action': data['action'], 'obs': data['obs']})
result['original_action'] = data['action']
result['true_residual'] = data['next_obs'] - data['obs']
# latent space constraint (LSC)
# data['latent_action'] = torch.tanh(result['z'].detach()) # TODO(pu): update latent_action z, shape (128,6)
# self.c_percentage_bound_lower = data['latent_action'].sort(dim=0)[0][int(result['recons_action'].shape[0] * 0.02), :] # values, indices
# self.c_percentage_bound_upper = data['latent_action'].sort(dim=0)[0][int(result['recons_action'].shape[0] * 0.98), :]
vae_loss = self._vae_model.loss_function(result, kld_weight=0.5, predict_weight=10) # TODO(pu):weight
# NOTE: applying tanh here is important; relabel latent_action with tanh(z), shape (batch_size, latent_dim), e.g. (128, 6)
data['latent_action'] = torch.tanh(result['z'].clone().detach())
# data['latent_action'] = result['z'].clone().detach()
self.c_percentage_bound_lower = data['latent_action'].sort(dim=0)[0][int(
result['recons_action'].shape[0] * 0.02
), :] # values, indices
self.c_percentage_bound_upper = data['latent_action'].sort(
dim=0
)[0][int(result['recons_action'].shape[0] * 0.98), :]
vae_loss = self._vae_model.loss_function(
result, kld_weight=0.01, predict_weight=0.01
) # TODO(pu): weight
loss_dict['vae_loss'] = vae_loss['loss']
loss_dict['reconstruction_loss'] = vae_loss['reconstruction_loss']
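The latent space constraint (LSC) above keeps the actor's latent actions inside the region the VAE actually populates: after each VAE update, per-dimension 2% / 98% percentile bounds of tanh(z) are stored, and _forward_collect / _forward_eval (further below) clamp the actor output to those bounds before decoding. A standalone sketch of both steps (illustrative tensors; shapes follow batch_size=128, latent_dim=6):

import torch

latent_action = torch.tanh(torch.randn(128, 6))               # tanh(z) batch from the vae phase
sorted_latent = latent_action.sort(dim=0)[0]                  # sort each latent dimension independently (values only)
lower = sorted_latent[int(latent_action.shape[0] * 0.02), :]  # per-dimension 2% bound
upper = sorted_latent[int(latent_action.shape[0] * 0.98), :]  # per-dimension 98% bound

actor_output = torch.randn(4, 6)                              # hypothetical latent action from the TD3 actor
for i in range(actor_output.shape[-1]):
    actor_output[:, i].clamp_(lower[i].item(), upper[i].item())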
......@@ -375,22 +365,28 @@ class TD3VAEPolicy(DDPGPolicy):
# ====================
if self._cuda:
data = to_device(data, self._device)
result = self._vae_model(
{'action': data['action'],
'obs': data['obs']})
result = self._vae_model({'action': data['action'], 'obs': data['obs']})
true_residual = data['next_obs'] - data['obs']
# Representation shift correction (RSC)
for i in range(result['recons_action'].shape[0]):
if F.mse_loss(result['prediction_residual'][i], true_residual[i]).item() > 4 * self._running_mean_std_predict_loss.mean:
data['latent_action'][i] = torch.tanh(result['z'][i].detach()) # TODO(pu): update latent_action z tanh
# data['latent_action'] = result['mu'].detach() # TODO(pu): update latent_action mu
if F.mse_loss(result['prediction_residual'][i],
true_residual[i]).item() > 4 * self._running_mean_std_predict_loss.mean:
# NOTE: applying tanh here is important; relabel latent_action with tanh(z)
data['latent_action'][i] = torch.tanh(result['z'][i].clone().detach())
# data['latent_action'][i] = result['z'][i].clone().detach()
if self._reward_batch_norm:
reward = (reward - reward.mean()) / (reward.std() + 1e-8)
# current q value
q_value = self._learn_model.forward({'obs': data['obs'], 'action': data['latent_action']}, mode='compute_critic')['q_value']
q_value = self._learn_model.forward(
{
'obs': data['obs'],
'action': data['latent_action']
}, mode='compute_critic'
)['q_value']
q_value_dict = {}
if self._twin_critic:
q_value_dict['q_value'] = q_value[0].mean()
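The representation shift correction (RSC) loop above relabels stale latent actions in the sampled RL batch: whenever the VAE's residual-prediction error for a transition exceeds 4x the running mean of the prediction loss, that sample's latent_action is replaced with tanh(z) from the current encoder. A standalone sketch with placeholder tensors:

import torch
import torch.nn.functional as F

prediction_residual = torch.randn(128, 8)  # VAE-predicted next_obs - obs for the sampled batch
true_residual = torch.randn(128, 8)        # observed next_obs - obs
z = torch.randn(128, 6)                    # current encoder sample for the same transitions
latent_action = torch.zeros(128, 6)        # latent actions as stored in the replay data
running_mean = 1.0                         # stands in for self._running_mean_std_predict_loss.mean

for i in range(prediction_residual.shape[0]):
    if F.mse_loss(prediction_residual[i], true_residual[i]).item() > 4 * running_mean:
        latent_action[i] = torch.tanh(z[i].detach())  # relabel with the up-to-date representation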
......@@ -399,7 +395,8 @@ class TD3VAEPolicy(DDPGPolicy):
q_value_dict['q_value'] = q_value.mean()
# target q value.
with torch.no_grad():
next_actor_data = self._target_model.forward(next_obs, mode='compute_actor') # latent action
# NOTE: here next_actor_data['action'] is latent action
next_actor_data = self._target_model.forward(next_obs, mode='compute_actor')
next_actor_data['obs'] = next_obs
target_q_value = self._target_model.forward(next_actor_data, mode='compute_critic')['q_value']
if self._twin_critic:
......@@ -432,9 +429,8 @@ class TD3VAEPolicy(DDPGPolicy):
# ===============================
# actor updates every ``self._actor_update_freq`` iters
if (self._forward_learn_cnt + 1) % self._actor_update_freq == 0:
actor_data = self._learn_model.forward(data['obs'], mode='compute_actor') # latent action
# NOTE: actor_data['action'] is the latent action
actor_data = self._learn_model.forward(data['obs'], mode='compute_actor')
actor_data['obs'] = data['obs']
if self._twin_critic:
actor_loss = -self._learn_model.forward(actor_data, mode='compute_critic')['q_value'][0].mean()
......@@ -460,7 +456,6 @@ class TD3VAEPolicy(DDPGPolicy):
return {
'cur_lr_actor': self._optimizer_actor.defaults['lr'],
'cur_lr_critic': self._optimizer_critic.defaults['lr'],
# 'q_value': np.array(q_value).mean(),
'action': action_log_value,
'priority': td_error_per_sample.abs().tolist(),
'td_error': td_error_per_sample.abs().mean(),
......@@ -527,27 +522,24 @@ class TD3VAEPolicy(DDPGPolicy):
output['latent_action'] = output['action']
# latent space constraint (LSC)
# for i in range(output['action'].shape[-1]):
# output['action'][:, i].clamp_(self.c_percentage_bound_lower[i].item(),
# self.c_percentage_bound_upper[i].item())
for i in range(output['action'].shape[-1]):
output['action'][:, i].clamp_(
self.c_percentage_bound_lower[i].item(), self.c_percentage_bound_upper[i].item()
)
# TODO(pu): decode the latent action back into the original hybrid action; ``data`` here is the obs
# this call is essential: it produces self.obs_encoding, which is used in the decode phase
output['action'] = self._vae_model.decode_with_obs(output['action'], data)[0]
# add noise in the original actions
# NOTE: add noise to the decoded original actions
from ding.rl_utils.exploration import GaussianNoise
action = output['action']
gaussian_noise = GaussianNoise(mu=0.0, sigma=0.1)
# gaussian_noise = GaussianNoise(mu=0.0, sigma=0.5)
noise = gaussian_noise( output['action'].shape, output['action'].device)
noise = gaussian_noise(output['action'].shape, output['action'].device)
if self._cfg.learn.noise_range is not None:
noise = noise.clamp(self._cfg.learn.noise_range['min'], self._cfg.learn.noise_range['max'])
action += noise
self.action_range = {
'min': -1,
'max': 1
}
self.action_range = {'min': -1, 'max': 1}
if self.action_range is not None:
action = action.clamp(self.action_range['min'], self.action_range['max'])
output['action'] = action
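The collect-mode block above injects exploration noise only after decoding back to the original action space, which is why the td3_vae config below sets the collector's noise_sigma to 0. A small sketch of the noise-and-clamp step; the noise_range maximum is assumed to mirror the min=-0.5 shown in the configs:

import torch
from ding.rl_utils.exploration import GaussianNoise  # same helper as used above

decoded_action = torch.rand(4, 2) * 2 - 1             # decoded original actions, in [-1, 1]
gaussian_noise = GaussianNoise(mu=0.0, sigma=0.1)
noise = gaussian_noise(decoded_action.shape, decoded_action.device)
noise = noise.clamp(-0.5, 0.5)                        # cfg.policy.learn.noise_range (max assumed symmetric)
action = (decoded_action + noise).clamp(-1, 1)        # self.action_range = {'min': -1, 'max': 1}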
......@@ -569,7 +561,6 @@ class TD3VAEPolicy(DDPGPolicy):
Return:
- transition (:obj:`Dict[str, Any]`): Dict type transition data.
"""
# if hasattr(model_output, 'latent_action'):
if 'latent_action' in model_output.keys():
transition = {
'obs': obs,
......@@ -584,7 +575,7 @@ class TD3VAEPolicy(DDPGPolicy):
'obs': obs,
'next_obs': timestep.obs,
'action': model_output['action'],
'latent_action': False,
'latent_action': 999,
'reward': timestep.reward,
'done': timestep.done,
}
......@@ -629,9 +620,10 @@ class TD3VAEPolicy(DDPGPolicy):
output['latent_action'] = output['action']
# latent space constraint (LSC)
# for i in range(output['action'].shape[-1]):
# output['action'][:, i].clamp_(self.c_percentage_bound_lower[i].item(),
# self.c_percentage_bound_upper[i].item())
for i in range(output['action'].shape[-1]):
output['action'][:, i].clamp_(
self.c_percentage_bound_lower[i].item(), self.c_percentage_bound_upper[i].item()
)
# TODO(pu): decode the latent action back into the original hybrid action; ``data`` here is the obs
# this call is essential: it produces self.obs_encoding, which is used in the decode phase
......
......@@ -227,6 +227,11 @@ class SampleSerialCollector(ISerialCollector):
if self._transform_obs:
obs = to_tensor(obs, dtype=torch.float32)
policy_output = self._policy.forward(obs, **policy_kwargs)
if 'latent_action' in policy_output[0].keys():
    if policy_output[0]['latent_action'] is False:
        print('here')
self._policy_output_pool.update(policy_output)
# Interact with env.
actions = {env_id: output['action'] for env_id, output in policy_output.items()}
......
......@@ -2,7 +2,7 @@ from easydict import EasyDict
from ding.entry import serial_pipeline
lunarlander_td3_config = dict(
exp_name='lunarlander_cont_td3',
exp_name='lunarlander_cont_td3_ns256_upcr256_lr3e-4_rbs1e5',
env=dict(
env_id='LunarLanderContinuous-v2',
collector_env_num=8,
......@@ -15,7 +15,7 @@ lunarlander_td3_config = dict(
policy=dict(
cuda=False,
priority=False,
random_collect_size=800,
random_collect_size=0,
model=dict(
obs_shape=8,
action_shape=2,
......@@ -23,11 +23,11 @@ lunarlander_td3_config = dict(
actor_head_type='regression',
),
learn=dict(
update_per_collect=2,
update_per_collect=256,
batch_size=128,
learning_rate_actor=0.001,
learning_rate_critic=0.001,
ignore_done=False, # TODO(pu)
learning_rate_actor=3e-4,
learning_rate_critic=3e-4,
ignore_done=False,
actor_update_freq=2,
noise=True,
noise_sigma=0.1,
......@@ -37,12 +37,13 @@ lunarlander_td3_config = dict(
),
),
collect=dict(
n_sample=48,
n_sample=256,
noise_sigma=0.1,
collector=dict(collect_print_freq=1000, ),
),
eval=dict(evaluator=dict(eval_freq=100, ), ),
other=dict(replay_buffer=dict(replay_buffer_size=20000, ), ),
other=dict(replay_buffer=dict(replay_buffer_size=int(1e5), ), ),
),
)
lunarlander_td3_config = EasyDict(lunarlander_td3_config)
......
......@@ -2,15 +2,10 @@ from easydict import EasyDict
from ding.entry import serial_pipeline_td3_vae
lunarlander_td3vae_config = dict(
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc1000_upcr20_upcv1000_notargetnoise_collectoriginalnoise_rbs2e4_rsc',
exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc100_upcr20_upcv100_notargetnoise_collectoriginalnoise_rbs2e4_rsc',  # TODO(pu): debug
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc100_upcr2_upcv100_notargetnoise_collectoriginalnoise_rbs2e4_rsc_lsc',# TODO(pu)
exp_name='lunarlander_cont_td3_vae_lad6_rcs1e4_wu1e4_ns256_bs128_auf2_targetnoise_collectoriginalnoise_rbs1e5_rsc_lsc_rvuc3_upcr256_upcv100_kw0.01_pw0.01_dot_tanh',
env=dict(
env_id='LunarLanderContinuous-v2',
# collector_env_num=8,
# evaluator_env_num=5,
collector_env_num=1,
collector_env_num=8,
evaluator_env_num=5,
# (bool) Scale output action into legal range.
act_scale=True,
......@@ -18,48 +13,29 @@ lunarlander_td3vae_config = dict(
stop_value=200,
),
policy=dict(
cuda=False,
cuda=True,
priority=False,
random_collect_size=12800,
# random_collect_size=0,
random_collect_size=10000,
original_action_shape=2,
model=dict(
obs_shape=8,
action_shape=6, # latent_action_dim
twin_critic=True,
actor_head_type='regression',
),
learn=dict(
# warm_up_update=0,
warm_up_update=1000,
# vae_train_times_per_update=1, # TODO(pu)
# rl_vae_update_circle=1000,
rl_vae_update_circle=100, # train rl 100 iter, vae 1 iter
# rl_vae_update_circle=1,
# update_per_collect_rl=50,
update_per_collect_rl=20,
# update_per_collect_rl=2,
warm_up_update=int(1e4),
rl_vae_update_circle=3,  # per 3-iteration cycle: RL updates run every iteration, VAE updates run on the last one
update_per_collect_rl=256,
update_per_collect_vae=100,
# update_per_collect_vae=20,
# update_per_collect_vae=1,
# update_per_collect_vae=0,
batch_size=128,
learning_rate_actor=1e-3,
learning_rate_critic=1e-3,
# learning_rate_actor=3e-4,
# learning_rate_critic=3e-4,
learning_rate_vae=3e-4,
ignore_done=False, # TODO(pu)
learning_rate_actor=3e-4,
learning_rate_critic=3e-4,
learning_rate_vae=1e-4,
ignore_done=False,
target_theta=0.005,
# actor_update_freq=2,
actor_update_freq=1,
# noise=True,
noise=False, # TODO(pu)
actor_update_freq=2,
noise=True,
noise_sigma=0.1,
noise_range=dict(
min=-0.5,
......@@ -67,18 +43,13 @@ lunarlander_td3vae_config = dict(
),
),
collect=dict(
# each_iter_n_sample=48, # 1280
n_sample=48,
unroll_len=1, # TODO(pu)
# noise_sigma=0.1,
noise_sigma=0, # TODO(pu)
n_sample=256,
unroll_len=1,
noise_sigma=0,  # NOTE: exploration noise is added to the original action in the _forward_collect method of the td3_vae policy instead
collector=dict(collect_print_freq=1000, ),
),
eval=dict(evaluator=dict(eval_freq=100, ), ),
other=dict(replay_buffer=dict(replay_buffer_size=int(2e4), ), ),
# other=dict(replay_buffer=dict(replay_buffer_size=int(1e5), ), ),
# other=dict(replay_buffer=dict(replay_buffer_size=int(5e5), ), ),
other=dict(replay_buffer=dict(replay_buffer_size=int(1e5), ), ),
),
)
lunarlander_td3vae_config = EasyDict(lunarlander_td3vae_config)
......
from typing import Any, List, Union, Optional
import time
import gym
......@@ -6,6 +7,7 @@ from ding.envs import BaseEnv, BaseEnvTimestep, BaseEnvInfo
from ding.envs.common.env_element import EnvElement, EnvElementInfo
from ding.torch_utils import to_ndarray, to_list
from ding.utils import ENV_REGISTRY
from ding.envs.common import affine_transform
@ENV_REGISTRY.register('lunarlander')
......@@ -14,6 +16,7 @@ class LunarLanderEnv(BaseEnv):
def __init__(self, cfg: dict) -> None:
self._cfg = cfg
self._init_flag = False
self._act_scale = cfg.act_scale
def reset(self) -> np.ndarray:
if not self._init_flag:
......@@ -49,6 +52,8 @@ class LunarLanderEnv(BaseEnv):
assert isinstance(action, np.ndarray), type(action)
if action.shape == (1, ):
action = action.squeeze() # 0-dim array
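# NOTE: with act_scale enabled ("(bool) Scale output action into legal range" in the config),
# the incoming action, assumed to lie in [-1, 1], is mapped onto the env's legal action range
# by ding's affine_transform below.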
if self._act_scale:
action = affine_transform(action, min_val=-1, max_val=1)
obs, rew, done, info = self._env.step(action)
# self._env.render()
rew = float(rew)
......