...
 
Commits (7)

- 9e6de548 polish(pu): polish td3_vae config — puyuan1996 <2402552459@qq.com>, 2021-12-22T10:38:06+08:00
  https://gitcode.net/opendilab/DI-engine/-/commit/9e6de54883c90d77d84f86fc07eebb77200ee54b
- 9dd84dd3 polish(pu): polish vae structure, use add not concat between the embeddings of obs and action, use tanh after sample z and after the reconstruction_action head — puyuan1996 <2402552459@qq.com>, 2021-12-23T20:36:37+08:00
  https://gitcode.net/opendilab/DI-engine/-/commit/9dd84dd319720a6fdce53d4f43429aca32cae5ab
- 3f7e2130 polish(pu): polish kl weight and prediction weight — puyuan1996 <2402552459@qq.com>, 2021-12-24T18:11:04+08:00
  https://gitcode.net/opendilab/DI-engine/-/commit/3f7e2130e1e0ff38cb213f92f31cb36a9aa01098
- c7d85c97 polish(pu): polish td3_vae using the best setting — puyuan1996 <2402552459@qq.com>, 2021-12-26T15:25:57+08:00
  https://gitcode.net/opendilab/DI-engine/-/commit/c7d85c97bd5daba2a4d232a97983017be8857767
- 96ea3624 style(pu): yapf format — puyuan1996 <2402552459@qq.com>, 2021-12-26T15:59:17+08:00
  https://gitcode.net/opendilab/DI-engine/-/commit/96ea36240859d9104d69808be205b9c11b2802c4
- 6ca77642 polish(pu): polish config — puyuan1996 <2402552459@qq.com>, 2021-12-26T16:01:43+08:00
  https://gitcode.net/opendilab/DI-engine/-/commit/6ca776427043da8f4e8ef59c6bc98c4fe572a0fd
- 70328aab fix(pu): fix bug when collector_env_num>1, the self._traj_buffer is not empty and will leave over the data in random collect phase — puyuan1996 <2402552459@qq.com>, 2021-12-26T19:09:44+08:00
  https://gitcode.net/opendilab/DI-engine/-/commit/70328aabcd696d3c7691fb8beb3394198a062420
@@ -13,6 +13,7 @@ from ding.policy import create_policy, PolicyFactory
from ding.utils import set_pkg_seed
import copy


def serial_pipeline_td3_vae(
        input_cfg: Union[str, Tuple[dict, dict]],
        seed: int = 0,

@@ -92,13 +93,13 @@ def serial_pipeline_td3_vae(
        new_data = collector.collect(n_sample=cfg.policy.random_collect_size, policy_kwargs=collect_kwargs)
        for item in new_data:
            item['warm_up'] = True
-        replay_buffer_recent.push(new_data, cur_collector_envstep=0)
+        replay_buffer.push(new_data, cur_collector_envstep=0)
        collector.reset_policy(policy.collect_mode)

-        ### warm_up ###
+        # warm_up
        # Learn policy from collected data
        for i in range(cfg.policy.learn.warm_up_update):
            # Learner will train ``update_per_collect`` times in one iteration.
-            train_data = replay_buffer_recent.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
+            train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
            if train_data is None:
                # It is possible that replay buffer's data count is too few to train ``update_per_collect`` times
                logging.warning(
@@ -109,8 +110,13 @@ def serial_pipeline_td3_vae(
            learner.train(train_data, collector.envstep)
            if learner.policy.get_attribute('priority'):
-                replay_buffer_recent.update(learner.priority_info)
-        replay_buffer_recent.clear()  # TODO(pu)
+                replay_buffer.update(learner.priority_info)
+        replay_buffer.clear()  # TODO(pu): NOTE
+        # NOTE: for the case collector_env_num>1, because after the random collect phase, self._traj_buffer[env_id] may be not empty. Only
+        # if the condition "timestep.done or len(self._traj_buffer[env_id]) == self._traj_len" is satisfied, the self._traj_buffer will be clear.
+        # For our alg., the data in self._traj_buffer[env_id], latent_action=False, cannot be used in rl_vae phase.
+        collector.reset(policy.collect_mode)

    for iter in range(max_iterations):
        collect_kwargs = commander.step()
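The NOTE above is the core of the final fix commit: when collector_env_num > 1, per-env trajectory fragments left over from the random-collect phase carry no usable latent action, so the collector is reset before the main loop starts. A minimal standalone sketch of the same guard applied on the data side (the helper name and the transition layout are illustrative, not DI-engine's collector API):

    import torch

    def drop_leftover_random_transitions(transitions):
        """Keep only transitions that already carry a relabelable latent action.

        Transitions produced during the random-collect phase mark 'latent_action'
        with a placeholder (False in the old code, 999 after this change), so they
        cannot be used in the rl/vae phase.
        """
        return [t for t in transitions if torch.is_tensor(t.get('latent_action'))]

    batch = [
        {'obs': torch.zeros(8), 'action': torch.zeros(2), 'latent_action': False},
        {'obs': torch.zeros(8), 'action': torch.zeros(2), 'latent_action': torch.zeros(6)},
    ]
    assert len(drop_leftover_random_transitions(batch)) == 1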
@@ -120,7 +126,7 @@ def serial_pipeline_td3_vae(
        if stop:
            break
        # Collect data by default config n_sample/n_episode
-        if hasattr(cfg.policy.collect, "each_iter_n_sample"):  # TODO(pu)
+        if hasattr(cfg.policy.collect, "each_iter_n_sample"):
            new_data = collector.collect(
                n_sample=cfg.policy.collect.each_iter_n_sample,
                train_iter=learner.train_iter,

@@ -134,11 +140,9 @@ def serial_pipeline_td3_vae(
        replay_buffer_recent.push(copy.deepcopy(new_data), cur_collector_envstep=collector.envstep)

        # rl phase
-        # if iter % cfg.policy.learn.rl_vae_update_circle in range(0,20):
        if iter % cfg.policy.learn.rl_vae_update_circle in range(0, cfg.policy.learn.rl_vae_update_circle):
            # Learn policy from collected data
            for i in range(cfg.policy.learn.update_per_collect_rl):
-                # print('update_per_collect_rl')
                # Learner will train ``update_per_collect`` times in one iteration.
                train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
                if train_data is not None:
@@ -157,15 +161,17 @@ def serial_pipeline_td3_vae(
                        replay_buffer.update(learner.priority_info)

        # vae phase
-        # if iter % cfg.policy.learn.rl_vae_update_circle in range(19, 20):
-        if iter % cfg.policy.learn.rl_vae_update_circle in range(cfg.policy.learn.rl_vae_update_circle - 1, cfg.policy.learn.rl_vae_update_circle):
+        if iter % cfg.policy.learn.rl_vae_update_circle in range(cfg.policy.learn.rl_vae_update_circle - 1,
+                                                                 cfg.policy.learn.rl_vae_update_circle):
            for i in range(cfg.policy.learn.update_per_collect_vae):
-                # print('update_per_collect_vae')
                # Learner will train ``update_per_collect`` times in one iteration.
-                train_data_history = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
-                # train_data_recent = replay_buffer_recent.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)  # TODO(pu)
-                # train_data = train_data_history + train_data_recent  # TODO(pu)
-                train_data = train_data_history  # TODO(pu)
+                train_data_history = replay_buffer.sample(
+                    int(learner.policy.get_attribute('batch_size') / 2), learner.train_iter
+                )
+                train_data_recent = replay_buffer_recent.sample(
+                    int(learner.policy.get_attribute('batch_size') / 2), learner.train_iter
+                )
+                train_data = train_data_history + train_data_recent  # TODO(pu):
                if train_data is not None:
                    for item in train_data:

@@ -181,7 +187,7 @@ def serial_pipeline_td3_vae(
                    learner.train(train_data, collector.envstep)
                    # if learner.policy.get_attribute('priority'):
                    #     replay_buffer.update(learner.priority_info)
-            # replay_buffer_recent.clear()  # TODO(pu)
+            replay_buffer_recent.clear()  # TODO(pu)

    # Learner's after_run hook.
    learner.call_hook('after_run')
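To make the schedule in this entry file concrete: RL updates run on every iteration (iter % rl_vae_update_circle is always inside range(0, rl_vae_update_circle)), while the VAE update only fires on the last iteration of each circle, and its batch is drawn half from the long-term buffer and half from the buffer holding only the latest collected data. A standalone sketch of that logic, with plain lists standing in for the two replay buffers (not DI-engine's buffer classes):

    import random

    rl_vae_update_circle = 3   # as in the new config: RL every iteration, VAE once per circle

    def vae_batch(history, recent, batch_size):
        # Half of the VAE batch comes from the long-term buffer, half from the
        # buffer that only holds the most recently collected transitions.
        half = batch_size // 2
        return random.sample(history, half) + random.sample(recent, half)

    history, recent = list(range(1000)), list(range(1000, 1256))
    for it in range(6):
        run_rl = it % rl_vae_update_circle in range(0, rl_vae_update_circle)    # True every iteration
        run_vae = it % rl_vae_update_circle == rl_vae_update_circle - 1         # True on iterations 2, 5, ...
        if run_vae:
            batch = vae_batch(history, recent, 128)
            assert len(batch) == 128
        print(it, run_rl, run_vae)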
......
@@ -38,12 +38,7 @@ class BaseVAE(nn.Module):

class VanillaVAE(BaseVAE):

-    def __init__(self,
-                 action_dim: int,
-                 obs_dim: int,
-                 latent_dim: int,
-                 hidden_dims: List = None,
-                 **kwargs) -> None:
+    def __init__(self, action_dim: int, obs_dim: int, latent_dim: int, hidden_dims: List = None, **kwargs) -> None:
        super(VanillaVAE, self).__init__()

        self.action_dim = action_dim

@@ -53,83 +48,30 @@ class VanillaVAE(BaseVAE):
        modules = []
        if hidden_dims is None:
-            hidden_dims = [32, 64, 128, 256, 512]
+            hidden_dims = [256]

        # Build Encoder
        # action
-        self.action_head = nn.Sequential(
-            nn.Linear(self.action_dim, hidden_dims[0]),
-            nn.ReLU())
+        self.action_head = nn.Sequential(nn.Linear(self.action_dim, hidden_dims[0]), nn.ReLU())
        # obs
-        self.obs_head = nn.Sequential(
-            nn.Linear(self.obs_dim, hidden_dims[0]),
-            nn.ReLU())
-        in_dim = hidden_dims[0] + hidden_dims[0]
-        for h_dim in hidden_dims[1:-1]:
-            # modules.append(
-            #     nn.Sequential(
-            #         nn.Conv2d(in_channels, out_channels=h_dim,
-            #                   kernel_size=3, stride=2, padding=1),
-            #         nn.BatchNorm2d(h_dim),
-            #         nn.LeakyReLU())
-            # )
-            # in_channels = h_dim
-            modules.append(
-                nn.Sequential(
-                    nn.Linear(in_dim, h_dim),
-                    nn.ReLU())
-            )
-            in_dim = h_dim
-        self.encoder = nn.Sequential(*modules)
-        self.fc_mu = nn.Linear(hidden_dims[-1], latent_dim)
-        self.fc_var = nn.Linear(hidden_dims[-1], latent_dim)
+        self.obs_head = nn.Sequential(nn.Linear(self.obs_dim, hidden_dims[0]), nn.ReLU())
+        self.encoder = nn.Sequential(nn.Linear(hidden_dims[0], hidden_dims[0]), nn.ReLU())
+        self.mu_head = nn.Linear(hidden_dims[0], latent_dim)
+        self.var_head = nn.Linear(hidden_dims[0], latent_dim)

        # Build Decoder
-        modules = []
-        hidden_dims.reverse()
-        # for i in range(len(hidden_dims) - 1):
-        #     modules.append(
-        #         nn.Sequential(
-        #             nn.ConvTranspose2d(hidden_dims[i],
-        #                                hidden_dims[i + 1],
-        #                                kernel_size=3,
-        #                                stride=2,
-        #                                padding=1,
-        #                                output_padding=1),
-        #             nn.BatchNorm2d(hidden_dims[i + 1]),
-        #             nn.LeakyReLU())
-        #     )
-        in_dim = self.latent_dim + hidden_dims[0]
-        for h_dim in hidden_dims[1:-1]:
-            modules.append(
-                nn.Sequential(
-                    nn.Linear(in_dim, h_dim),
-                    nn.ReLU())
-            )
-            in_dim = h_dim
-        self.decoder = nn.Sequential(*modules)
-        # self.reconstruction_layer = nn.Linear(hidden_dims[-1], self.action_dim)  # TODO(pu)
-        self.reconstruction_layer = nn.Sequential(nn.Linear(hidden_dims[-1], self.action_dim), nn.Tanh())
+        self.condition_obs = nn.Sequential(nn.Linear(self.obs_dim, hidden_dims[0]), nn.ReLU())
+        self.decoder_action = nn.Sequential(nn.Linear(latent_dim, hidden_dims[0]), nn.ReLU())
+        self.decoder_common = nn.Sequential(nn.Linear(hidden_dims[0], hidden_dims[0]), nn.ReLU())
+        # TODO(pu): tanh
+        self.reconstruction_layer = nn.Sequential(nn.Linear(hidden_dims[0], self.action_dim), nn.Tanh())
+        # self.reconstruction_layer = nn.Linear(hidden_dims[0], self.action_dim)
        # residual prediction
-        self.prediction_layer_1 = nn.Sequential(nn.Linear(hidden_dims[-1], hidden_dims[-1]), nn.ReLU())
-        self.prediction_layer_2 = nn.Linear(hidden_dims[-1], self.obs_dim)
-        # self.final_layer = nn.Sequential(
-        #                     nn.ConvTranspose2d(hidden_dims[-1],
-        #                                        hidden_dims[-1],
-        #                                        kernel_size=3,
-        #                                        stride=2,
-        #                                        padding=1,
-        #                                        output_padding=1),
-        #                     nn.BatchNorm2d(hidden_dims[-1]),
-        #                     nn.LeakyReLU(),
-        #                     nn.Conv2d(hidden_dims[-1], out_channels=3,
-        #                               kernel_size=3, padding=1),
-        #                     nn.Tanh())
+        self.prediction_head_1 = nn.Sequential(nn.Linear(hidden_dims[0], hidden_dims[0]), nn.ReLU())
+        self.prediction_head_2 = nn.Linear(hidden_dims[0], self.obs_dim)

        self.obs_encoding = None
@@ -142,15 +84,20 @@ class VanillaVAE(BaseVAE):
        """
        action_encoding = self.action_head(input['action'])
        obs_encoding = self.obs_head(input['obs'])
+        # obs_encoding = self.condition_obs(input['obs'])  # TODO(pu): using a different network

        self.obs_encoding = obs_encoding

-        input = torch.cat([obs_encoding, action_encoding], dim=-1)
+        # input = torch.cat([obs_encoding, action_encoding], dim=-1)
+        # input = obs_encoding + action_encoding  # TODO(pu): what about add, cat?
+        input = obs_encoding * action_encoding
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
-        mu = self.fc_mu(result)
-        log_var = self.fc_var(result)
+        mu = self.mu_head(result)
+        log_var = self.var_head(result)

        return [mu, log_var]

@@ -161,15 +108,14 @@ class VanillaVAE(BaseVAE):
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x C x H x W]
        """
-        # for compatiability collect and eval in fist iteration
-        if self.obs_encoding is None or z.shape[:-1] != self.obs_encoding.shape[:-1]:
-            self.obs_encoding = torch.zeros(list(z.shape[:-1]) + [self.hidden_dims[1]])
-
-        input = torch.cat([self.obs_encoding, torch.tanh(z)], dim=-1)  # TODO(pu): here z is not bounded
-        decoding_tmp = self.decoder(input)
-        reconstruction_action = self.reconstruction_layer(decoding_tmp)
-        predition_residual_tmp = self.prediction_layer_1(decoding_tmp)
-        predition_residual = self.prediction_layer_2(predition_residual_tmp)
+        action_decoding = self.decoder_action(torch.tanh(z))  # NOTE: tanh, here z is not bounded
+        # action_decoding = self.decoder_action(z)  # NOTE: tanh, here z is not bounded
+        action_obs_decoding = action_decoding * self.obs_encoding
+        action_obs_decoding_tmp = self.decoder_common(action_obs_decoding)
+        reconstruction_action = self.reconstruction_layer(action_obs_decoding_tmp)
+        predition_residual_tmp = self.prediction_head_1(action_obs_decoding_tmp)
+        predition_residual = self.prediction_head_2(predition_residual_tmp)
        return [reconstruction_action, predition_residual]
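The polished VanillaVAE above conditions both the encoder and the decoder on the observation embedding, fuses the obs/action embeddings elementwise (the commit message says "add", the committed code uses "*", matching the "_dot_" tag in the config's exp_name), squashes the unbounded z with tanh before decoding, and decodes into two heads: a Tanh-bounded action reconstruction and an obs-residual prediction. A condensed, self-contained sketch of that layout with the LunarLander dimensions (an illustration, not the repository class):

    import torch
    import torch.nn as nn

    class MiniConditionalVAE(nn.Module):
        """Condensed sketch of the polished VanillaVAE layout: hidden_dims=[256]."""

        def __init__(self, action_dim=2, obs_dim=8, latent_dim=6, hidden=256):
            super().__init__()
            self.action_head = nn.Sequential(nn.Linear(action_dim, hidden), nn.ReLU())
            self.obs_head = nn.Sequential(nn.Linear(obs_dim, hidden), nn.ReLU())
            self.encoder = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU())
            self.mu_head = nn.Linear(hidden, latent_dim)
            self.var_head = nn.Linear(hidden, latent_dim)
            self.decoder_action = nn.Sequential(nn.Linear(latent_dim, hidden), nn.ReLU())
            self.decoder_common = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU())
            self.reconstruction_layer = nn.Sequential(nn.Linear(hidden, action_dim), nn.Tanh())
            self.prediction_head_1 = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU())
            self.prediction_head_2 = nn.Linear(hidden, obs_dim)

        def forward(self, obs, action):
            obs_e, act_e = self.obs_head(obs), self.action_head(action)
            h = self.encoder(obs_e * act_e)            # elementwise fusion of the two embeddings
            mu, log_var = self.mu_head(h), self.var_head(h)
            z = mu + torch.randn_like(mu) * torch.exp(0.5 * log_var)   # reparameterization trick
            d = self.decoder_common(self.decoder_action(torch.tanh(z)) * obs_e)  # tanh bounds the unbounded z
            recons_action = self.reconstruction_layer(d)
            residual = self.prediction_head_2(self.prediction_head_1(d))
            return recons_action, residual, mu, log_var, z

    obs, action = torch.randn(4, 8), torch.randn(4, 2)
    recons, residual, *_ = MiniConditionalVAE()(obs, action)
    print(recons.shape, residual.shape)   # torch.Size([4, 2]) torch.Size([4, 8])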
@@ -181,11 +127,13 @@ class VanillaVAE(BaseVAE):
        :return: (Tensor) [B x C x H x W]
        """
        self.obs_encoding = self.obs_head(obs)
-        input = torch.cat([self.obs_encoding, z], dim=-1)  # TODO(pu): here z is already bounded
-        decoding_tmp = self.decoder(input)
-        reconstruction_action = self.reconstruction_layer(decoding_tmp)
-        predition_residual_tmp = self.prediction_layer_1(decoding_tmp)
-        predition_residual = self.prediction_layer_2(predition_residual_tmp)
+        # TODO(pu): here z is already bounded, z is produced by td3 policy, it has been operated by tanh
+        action_decoding = self.decoder_action(z)
+        action_obs_decoding = action_decoding * self.obs_encoding
+        action_obs_decoding_tmp = self.decoder_common(action_obs_decoding)
+        reconstruction_action = self.reconstruction_layer(action_obs_decoding_tmp)
+        predition_residual_tmp = self.prediction_head_1(action_obs_decoding_tmp)
+        predition_residual = self.prediction_head_2(predition_residual_tmp)
        return [reconstruction_action, predition_residual]

@@ -204,11 +152,16 @@ class VanillaVAE(BaseVAE):
    def forward(self, input: Tensor, **kwargs) -> dict:
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
-        return {'recons_action': self.decode(z)[0], 'prediction_residual': self.decode(z)[1], 'input': input, 'mu': mu, 'log_var': log_var, 'z': z}  # recons_action, prediction_residual
+        return {
+            'recons_action': self.decode(z)[0],
+            'prediction_residual': self.decode(z)[1],
+            'input': input,
+            'mu': mu,
+            'log_var': log_var,
+            'z': z
+        }

-    def loss_function(self,
-                      args,
-                      **kwargs) -> dict:
+    def loss_function(self, args, **kwargs) -> dict:
        """
        Computes the VAE loss function.
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
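The docstring's KL term matches the implementation in the next hunk, kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0); with log_var = log sigma^2, the per-dimension closed form is

    D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \sigma^2)\,\|\,\mathcal{N}(0, 1)\big)
        = \log\frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        = -\frac{1}{2}\left(1 + \log\sigma^2 - \mu^2 - \sigma^2\right)

which is then summed over the latent dimensions and averaged over the batch, exactly as the code does.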
@@ -227,16 +180,13 @@ class VanillaVAE(BaseVAE):
        predict_weight = kwargs['predict_weight']

        recons_loss = F.mse_loss(recons_action, original_action)
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
        predict_loss = F.mse_loss(prediction_residual, true_residual)

        loss = recons_loss + kld_weight * kld_loss + predict_weight * predict_loss

        return {'loss': loss, 'reconstruction_loss': recons_loss, 'kld_loss': kld_loss, 'predict_loss': predict_loss}

-    def sample(self,
-               num_samples: int,
-               current_device: int, **kwargs) -> Tensor:
+    def sample(self, num_samples: int, current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.

@@ -244,11 +194,8 @@ class VanillaVAE(BaseVAE):
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
-        z = torch.randn(num_samples,
-                        self.latent_dim)
+        z = torch.randn(num_samples, self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples
......
@@ -10,7 +10,6 @@ from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
from .base_policy import Policy
from .common_utils import default_preprocess_learn
-from ding.utils import POLICY_REGISTRY
from .ddpg import DDPGPolicy
from ding.model.template.vae import VanillaVAE
from ding.utils import RunningMeanStd

@@ -220,18 +219,18 @@ class TD3VAEPolicy(DDPGPolicy):
        self._target_model.reset()

        self._forward_learn_cnt = 0  # count iterations
-        # action_shape, obs_shape, action_latent_dim, hidden_size_list
-        # self._vae_model = VanillaVAE(self._cfg.original_action_shape, self._cfg.model.obs_shape, self._cfg.model.action_shape, [256, 256, 256])
+        # action_shape, obs_shape, latent_action_dim, hidden_size_list
+        self._vae_model = VanillaVAE(
+            self._cfg.original_action_shape, self._cfg.model.obs_shape, self._cfg.model.action_shape, [256]
+        )
        # self._vae_model = VanillaVAE(2, 8, 6, [256, 256, 256])
-        self._vae_model = VanillaVAE(2, 8, 6, [256, 256, 256])
-        # self._vae_model = VanillaVAE(2, 8, 2, [256, 256, 256])
        self._optimizer_vae = Adam(
            self._vae_model.parameters(),
            lr=self._cfg.learn.learning_rate_vae,
        )
        self._running_mean_std_predict_loss = RunningMeanStd(epsilon=1e-4)
-        self.c_percentage_bound_lower = -1*torch.ones([6])
+        self.c_percentage_bound_lower = -1 * torch.ones([6])
        self.c_percentage_bound_upper = torch.ones([6])

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
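With the LunarLander config below (original_action_shape=2, obs_shape=8, latent action_shape=6), the new constructor call resolves to the same values as the previously hard-coded line, now with a single 256-unit hidden layer. A short sketch of what _init_learn builds, assuming this branch of DI-engine is importable (the import path is the one used in the diff above):

    import torch
    from ding.model.template.vae import VanillaVAE

    # Real (original) action dim 2, obs dim 8, latent action dim 6, hidden_dims=[256].
    vae_model = VanillaVAE(2, 8, 6, [256])

    # Latent space constraint (LSC) bounds start at [-1, 1] per latent dimension and are
    # later tightened to the 2% / 98% percentiles of the encoded, tanh-squashed z.
    c_percentage_bound_lower = -1 * torch.ones([6])
    c_percentage_bound_upper = torch.ones([6])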
@@ -259,24 +258,13 @@ class TD3VAEPolicy(DDPGPolicy):
            # ====================
            # train vae
            # ====================
-            result = self._vae_model(
-                {'action': data['action'],
-                 'obs': data['obs']})  # [self.decode(z)[0], self.decode(z)[1], input, mu, log_var, z]
-            # data['latent_action'] = result[5].detach()  # TODO(pu): update latent_action mu
-            # data['latent_action'] = result[3].detach()  # TODO(pu): update latent_action mu
+            result = self._vae_model({'action': data['action'], 'obs': data['obs']})

            result['original_action'] = data['action']
            result['true_residual'] = data['next_obs'] - data['obs']

-            vae_loss = self._vae_model.loss_function(result, kld_weight=0.5, predict_weight=10)  # TODO(pu):weight
-            # recons = args[0]
-            # prediction_residual = args[1]
-            # input_action = args[2]
-            # mu = args[3]
-            # log_var = args[4]
-            # true_residual = args[5]
-            # print(vae_loss)
+            vae_loss = self._vae_model.loss_function(result, kld_weight=0.01, predict_weight=0.01)  # TODO(pu): weight

            loss_dict['vae_loss'] = vae_loss['loss'].item()
            loss_dict['reconstruction_loss'] = vae_loss['reconstruction_loss'].item()
            loss_dict['kld_loss'] = vae_loss['kld_loss'].item()
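During warm-up the VAE is fitted to purely random transitions through the forward/loss interface shown in the VAE diff above. A hedged sketch of a single update step; the random batch and the explicit Adam zero_grad/backward/step calls are assumptions (they sit outside the visible hunks), only the call pattern is taken from the diff:

    import torch
    from torch.optim import Adam
    from ding.model.template.vae import VanillaVAE

    vae_model = VanillaVAE(2, 8, 6, [256])
    optimizer_vae = Adam(vae_model.parameters(), lr=1e-4)

    # Stand-in for a collated warm-up batch: (B, obs_dim), (B, obs_dim), (B, action_dim).
    data = {'obs': torch.randn(128, 8), 'next_obs': torch.randn(128, 8), 'action': torch.rand(128, 2) * 2 - 1}

    result = vae_model({'action': data['action'], 'obs': data['obs']})
    result['original_action'] = data['action']
    result['true_residual'] = data['next_obs'] - data['obs']   # the VAE also predicts the dynamics residual
    vae_loss = vae_model.loss_function(result, kld_weight=0.01, predict_weight=0.01)

    optimizer_vae.zero_grad()
    vae_loss['loss'].backward()
    optimizer_vae.step()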
@@ -316,29 +304,31 @@ class TD3VAEPolicy(DDPGPolicy):
                use_nstep=False
            )
        if data['vae_phase'][0].item() is True:
-            # for i in range(self._cfg.learn.vae_train_times_per_update):
            if self._cuda:
                data = to_device(data, self._device)

            # ====================
            # train vae
            # ====================
-            result = self._vae_model(
-                {'action': data['action'],
-                 'obs': data['obs']})  # [self.decode(z)[0], self.decode(z)[1], input, mu, log_var, z]
-            # data['latent_action'] = result['z'].detach()  # TODO(pu): update latent_action z
-            # data['latent_action'] = result['mu'].detach()  # TODO(pu): update latent_action mu
+            result = self._vae_model({'action': data['action'], 'obs': data['obs']})

            result['original_action'] = data['action']
            result['true_residual'] = data['next_obs'] - data['obs']

            # latent space constraint (LSC)
-            # data['latent_action'] = torch.tanh(result['z'].detach())  # TODO(pu): update latent_action z, shape (128,6)
-            # self.c_percentage_bound_lower = data['latent_action'].sort(dim=0)[0][int(result['recons_action'].shape[0] * 0.02), :]  # values, indices
-            # self.c_percentage_bound_upper = data['latent_action'].sort(dim=0)[0][int(result['recons_action'].shape[0] * 0.98), :]
-
-            vae_loss = self._vae_model.loss_function(result, kld_weight=0.5, predict_weight=10)  # TODO(pu):weight
+            # NOTE: using tanh is important, update latent_action using z, shape (128,6)
+            data['latent_action'] = torch.tanh(result['z'].clone().detach())
+            # data['latent_action'] = result['z'].clone().detach()
+            self.c_percentage_bound_lower = data['latent_action'].sort(dim=0)[0][int(
+                result['recons_action'].shape[0] * 0.02
+            ), :]  # values, indices
+            self.c_percentage_bound_upper = data['latent_action'].sort(
+                dim=0
+            )[0][int(result['recons_action'].shape[0] * 0.98), :]

+            vae_loss = self._vae_model.loss_function(
+                result, kld_weight=0.01, predict_weight=0.01
+            )  # TODO(pu): weight

            loss_dict['vae_loss'] = vae_loss['loss']
            loss_dict['reconstruction_loss'] = vae_loss['reconstruction_loss']
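The latent space constraint (LSC) block above turns the relabeled latent actions into per-dimension clamp bounds: the 2% and 98% order statistics of tanh(z) over the batch. A standalone numeric sketch of the same computation, plus how the bounds are later applied at collect time:

    import torch

    torch.manual_seed(0)
    latent_action = torch.tanh(torch.randn(128, 6))        # relabeled latent actions, shape (B, 6)

    sorted_values = latent_action.sort(dim=0)[0]           # sort each latent dimension independently
    lower = sorted_values[int(latent_action.shape[0] * 0.02), :]   # per-dimension 2% quantile
    upper = sorted_values[int(latent_action.shape[0] * 0.98), :]   # per-dimension 98% quantile

    # At collect time the actor's latent output is clamped dimension-wise into [lower, upper].
    proposal = torch.randn(4, 6)
    for i in range(proposal.shape[-1]):
        proposal[:, i].clamp_(lower[i].item(), upper[i].item())
    print(lower, upper)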
@@ -375,22 +365,28 @@ class TD3VAEPolicy(DDPGPolicy):
            # ====================
            if self._cuda:
                data = to_device(data, self._device)
-            result = self._vae_model(
-                {'action': data['action'],
-                 'obs': data['obs']})
+            result = self._vae_model({'action': data['action'], 'obs': data['obs']})
            true_residual = data['next_obs'] - data['obs']

            # Representation shift correction (RSC)
            for i in range(result['recons_action'].shape[0]):
-                if F.mse_loss(result['prediction_residual'][i], true_residual[i]).item() > 4 * self._running_mean_std_predict_loss.mean:
-                    data['latent_action'][i] = torch.tanh(result['z'][i].detach())  # TODO(pu): update latent_action z tanh
-                    # data['latent_action'] = result['mu'].detach()  # TODO(pu): update latent_action mu
+                if F.mse_loss(result['prediction_residual'][i],
+                              true_residual[i]).item() > 4 * self._running_mean_std_predict_loss.mean:
+                    # NOTE: using tanh is important, update latent_action using z
+                    data['latent_action'][i] = torch.tanh(result['z'][i].clone().detach())
+                    # data['latent_action'][i] = result['z'][i].clone().detach()

            if self._reward_batch_norm:
                reward = (reward - reward.mean()) / (reward.std() + 1e-8)

            # current q value
-            q_value = self._learn_model.forward({'obs': data['obs'], 'action': data['latent_action']}, mode='compute_critic')['q_value']
+            q_value = self._learn_model.forward(
+                {
+                    'obs': data['obs'],
+                    'action': data['latent_action']
+                }, mode='compute_critic'
+            )['q_value']
            q_value_dict = {}
            if self._twin_critic:
                q_value_dict['q_value'] = q_value[0].mean()
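Representation shift correction (RSC) relabels a stored latent action whenever the VAE's residual prediction for that sample has drifted, i.e. its MSE exceeds 4x the running mean of the prediction loss; the replacement is tanh of the freshly encoded z. A standalone sketch of the relabeling rule, with a plain float in place of DI-engine's RunningMeanStd:

    import torch
    import torch.nn.functional as F

    def rsc_relabel(latent_action, z, prediction_residual, true_residual, running_mean, threshold=4.0):
        """Relabel latent actions whose residual prediction error exceeds threshold * running_mean."""
        relabeled = latent_action.clone()
        for i in range(latent_action.shape[0]):
            err = F.mse_loss(prediction_residual[i], true_residual[i]).item()
            if err > threshold * running_mean:
                relabeled[i] = torch.tanh(z[i].detach())   # tanh keeps the relabeled action in [-1, 1]
        return relabeled

    B, latent_dim, obs_dim = 8, 6, 8
    out = rsc_relabel(
        latent_action=torch.rand(B, latent_dim) * 2 - 1,
        z=torch.randn(B, latent_dim),
        prediction_residual=torch.randn(B, obs_dim),
        true_residual=torch.randn(B, obs_dim),
        running_mean=0.5,
    )
    print(out.shape)   # torch.Size([8, 6])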
@@ -399,7 +395,8 @@ class TD3VAEPolicy(DDPGPolicy):
                q_value_dict['q_value'] = q_value.mean()
            # target q value.
            with torch.no_grad():
-                next_actor_data = self._target_model.forward(next_obs, mode='compute_actor')  # latent action
+                # NOTE: here next_actor_data['action'] is latent action
+                next_actor_data = self._target_model.forward(next_obs, mode='compute_actor')
                next_actor_data['obs'] = next_obs
                target_q_value = self._target_model.forward(next_actor_data, mode='compute_critic')['q_value']
                if self._twin_critic:

@@ -432,9 +429,8 @@ class TD3VAEPolicy(DDPGPolicy):
            # ===============================
            # actor updates every ``self._actor_update_freq`` iters
            if (self._forward_learn_cnt + 1) % self._actor_update_freq == 0:
-                actor_data = self._learn_model.forward(data['obs'], mode='compute_actor')  # latent action
+                # NOTE: actor_data['action] is latent action
+                actor_data = self._learn_model.forward(data['obs'], mode='compute_actor')
                actor_data['obs'] = data['obs']
                if self._twin_critic:
                    actor_loss = -self._learn_model.forward(actor_data, mode='compute_critic')['q_value'][0].mean()

@@ -460,7 +456,6 @@ class TD3VAEPolicy(DDPGPolicy):
            return {
                'cur_lr_actor': self._optimizer_actor.defaults['lr'],
                'cur_lr_critic': self._optimizer_critic.defaults['lr'],
-                # 'q_value': np.array(q_value).mean(),
                'action': action_log_value,
                'priority': td_error_per_sample.abs().tolist(),
                'td_error': td_error_per_sample.abs().mean(),

@@ -527,27 +522,24 @@ class TD3VAEPolicy(DDPGPolicy):
        output['latent_action'] = output['action']

        # latent space constraint (LSC)
-        # for i in range(output['action'].shape[-1]):
-        #     output['action'][:, i].clamp_(self.c_percentage_bound_lower[i].item(),
-        #                                   self.c_percentage_bound_upper[i].item())
+        for i in range(output['action'].shape[-1]):
+            output['action'][:, i].clamp_(
+                self.c_percentage_bound_lower[i].item(), self.c_percentage_bound_upper[i].item()
+            )

        # TODO(pu): decode into original hybrid actions, here data is obs
        # this is very important to generate self.obs_encoding using in decode phase
        output['action'] = self._vae_model.decode_with_obs(output['action'], data)[0]

-        # add noise in the original actions
+        # NOTE: add noise in the original actions
        from ding.rl_utils.exploration import GaussianNoise
        action = output['action']
        gaussian_noise = GaussianNoise(mu=0.0, sigma=0.1)
-        # gaussian_noise = GaussianNoise(mu=0.0, sigma=0.5)
-        noise = gaussian_noise( output['action'].shape, output['action'].device)
+        noise = gaussian_noise(output['action'].shape, output['action'].device)
        if self._cfg.learn.noise_range is not None:
            noise = noise.clamp(self._cfg.learn.noise_range['min'], self._cfg.learn.noise_range['max'])
        action += noise
-        self.action_range = {
-            'min': -1,
-            'max': 1
-        }
+        self.action_range = {'min': -1, 'max': 1}
        if self.action_range is not None:
            action = action.clamp(self.action_range['min'], self.action_range['max'])
        output['action'] = action
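At collect time the pipeline is: clamp the latent action into the LSC bounds, decode it with decode_with_obs (the Tanh reconstruction head keeps the result in [-1, 1]), then add exploration noise in the original action space rather than in the latent space, clip the noise, and clamp the final action back into [-1, 1]. A standalone sketch of that last "noise in original action" step (plain torch; the diff itself uses ding.rl_utils.exploration.GaussianNoise):

    import torch

    def noisy_decoded_action(decoded_action, sigma=0.1, noise_min=-0.5, noise_max=0.5):
        """Add clipped Gaussian noise to the decoded (original-space) action, then clamp to [-1, 1]."""
        noise = torch.randn_like(decoded_action) * sigma
        noise = noise.clamp(noise_min, noise_max)          # mirrors cfg.policy.learn.noise_range
        return (decoded_action + noise).clamp(-1, 1)       # mirrors self.action_range = {'min': -1, 'max': 1}

    action = noisy_decoded_action(torch.tanh(torch.randn(4, 2)))
    print(action.min().item() >= -1, action.max().item() <= 1)   # True True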
@@ -569,7 +561,6 @@ class TD3VAEPolicy(DDPGPolicy):
        Return:
            - transition (:obj:`Dict[str, Any]`): Dict type transition data.
        """
-        # if hasattr(model_output, 'latent_action'):
        if 'latent_action' in model_output.keys():
            transition = {
                'obs': obs,

@@ -584,7 +575,7 @@ class TD3VAEPolicy(DDPGPolicy):
                'obs': obs,
                'next_obs': timestep.obs,
                'action': model_output['action'],
-                'latent_action': False,
+                'latent_action': 999,
                'reward': timestep.reward,
                'done': timestep.done,
            }

@@ -629,9 +620,10 @@ class TD3VAEPolicy(DDPGPolicy):
        output['latent_action'] = output['action']

        # latent space constraint (LSC)
-        # for i in range(output['action'].shape[-1]):
-        #     output['action'][:, i].clamp_(self.c_percentage_bound_lower[i].item(),
-        #                                   self.c_percentage_bound_upper[i].item())
+        for i in range(output['action'].shape[-1]):
+            output['action'][:, i].clamp_(
+                self.c_percentage_bound_lower[i].item(), self.c_percentage_bound_upper[i].item()
+            )

        # TODO(pu): decode into original hybrid actions, here data is obs
        # this is very important to generate self.obs_encoding using in decode phase
......
@@ -227,6 +227,11 @@ class SampleSerialCollector(ISerialCollector):
                if self._transform_obs:
                    obs = to_tensor(obs, dtype=torch.float32)
                policy_output = self._policy.forward(obs, **policy_kwargs)
+                if 'latent_action' in policy_output[0].keys():
+                    if policy_output[0]['latent_action'] is False:
+                        print('here')
                self._policy_output_pool.update(policy_output)
                # Interact with env.
                actions = {env_id: output['action'] for env_id, output in policy_output.items()}
......
@@ -2,7 +2,7 @@ from easydict import EasyDict
from ding.entry import serial_pipeline

lunarlander_td3_config = dict(
-    exp_name='lunarlander_cont_td3',
+    exp_name='lunarlander_cont_td3_ns256_upcr256_lr3e-4_rbs1e5',
    env=dict(
        env_id='LunarLanderContinuous-v2',
        collector_env_num=8,

@@ -15,7 +15,7 @@ lunarlander_td3_config = dict(
    policy=dict(
        cuda=False,
        priority=False,
-        random_collect_size=800,
+        random_collect_size=0,
        model=dict(
            obs_shape=8,
            action_shape=2,

@@ -23,11 +23,11 @@ lunarlander_td3_config = dict(
            actor_head_type='regression',
        ),
        learn=dict(
-            update_per_collect=2,
+            update_per_collect=256,
            batch_size=128,
-            learning_rate_actor=0.001,
-            learning_rate_critic=0.001,
-            ignore_done=False,  # TODO(pu)
+            learning_rate_actor=3e-4,
+            learning_rate_critic=3e-4,
+            ignore_done=False,
            actor_update_freq=2,
            noise=True,
            noise_sigma=0.1,

@@ -37,12 +37,13 @@ lunarlander_td3_config = dict(
            ),
        ),
        collect=dict(
-            n_sample=48,
+            n_sample=256,
            noise_sigma=0.1,
            collector=dict(collect_print_freq=1000, ),
        ),
        eval=dict(evaluator=dict(eval_freq=100, ), ),
-        other=dict(replay_buffer=dict(replay_buffer_size=20000, ), ),
+        other=dict(replay_buffer=dict(replay_buffer_size=int(1e5), ), ),
    ),
)
lunarlander_td3_config = EasyDict(lunarlander_td3_config)
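The EasyDict wrapping on the last line is what lets the rest of the pipeline read these fields with attribute access. A small self-contained check of the updated values (a trimmed copy of the config; the commented launch line is only indicative, since the matching create config is not part of this diff):

    from easydict import EasyDict

    cfg = EasyDict(dict(
        policy=dict(
            learn=dict(update_per_collect=256, batch_size=128, learning_rate_actor=3e-4),
            collect=dict(n_sample=256, noise_sigma=0.1),
            other=dict(replay_buffer=dict(replay_buffer_size=int(1e5))),
        ),
    ))
    assert cfg.policy.learn.update_per_collect == 256
    assert cfg.policy.other.replay_buffer.replay_buffer_size == 100000
    # The full config would then be launched with something like (create config not shown here):
    # serial_pipeline([lunarlander_td3_config, lunarlander_td3_create_config], seed=0)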
......
@@ -2,15 +2,10 @@ from easydict import EasyDict
from ding.entry import serial_pipeline_td3_vae

lunarlander_td3vae_config = dict(
-    # exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc1000_upcr20_upcv1000_notargetnoise_collectoriginalnoise_rbs2e4_rsc',
-    exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc100_upcr20_upcv100_notargetnoise_collectoriginalnoise_rbs2e4_rsc',  # TODO(pu) deubg
-    # exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc100_upcr2_upcv100_notargetnoise_collectoriginalnoise_rbs2e4_rsc_lsc',  # TODO(pu)
+    exp_name='lunarlander_cont_td3_vae_lad6_rcs1e4_wu1e4_ns256_bs128_auf2_targetnoise_collectoriginalnoise_rbs1e5_rsc_lsc_rvuc3_upcr256_upcv100_kw0.01_pw0.01_dot_tanh',
    env=dict(
        env_id='LunarLanderContinuous-v2',
-        # collector_env_num=8,
-        # evaluator_env_num=5,
-        collector_env_num=1,
+        collector_env_num=8,
        evaluator_env_num=5,
        # (bool) Scale output action into legal range.
        act_scale=True,

@@ -18,48 +13,29 @@ lunarlander_td3vae_config = dict(
        stop_value=200,
    ),
    policy=dict(
-        cuda=False,
+        cuda=True,
        priority=False,
-        random_collect_size=12800,
-        # random_collect_size=0,
+        random_collect_size=10000,
        original_action_shape=2,
        model=dict(
            obs_shape=8,
            action_shape=6,  # latent_action_dim
            twin_critic=True,
            actor_head_type='regression',
        ),
        learn=dict(
-            # warm_up_update=0,
-            warm_up_update=1000,
-            # vae_train_times_per_update=1,  # TODO(pu)
-            # rl_vae_update_circle=1000,
-            rl_vae_update_circle=100,  # train rl 100 iter, vae 1 iter
-            # rl_vae_update_circle=1,
-            # update_per_collect_rl=50,
-            update_per_collect_rl=20,
-            # update_per_collect_rl=2,
+            warm_up_update=int(1e4),
+            rl_vae_update_circle=3,  # train rl 3 iter, vae 1 iter
+            update_per_collect_rl=256,
            update_per_collect_vae=100,
-            # update_per_collect_vae=20,
-            # update_per_collect_vae=1,
-            # update_per_collect_vae=0,
            batch_size=128,
-            learning_rate_actor=1e-3,
-            learning_rate_critic=1e-3,
-            # learning_rate_actor=3e-4,
-            # learning_rate_critic=3e-4,
-            learning_rate_vae=3e-4,
-            ignore_done=False,  # TODO(pu)
+            learning_rate_actor=3e-4,
+            learning_rate_critic=3e-4,
+            learning_rate_vae=1e-4,
+            ignore_done=False,
            target_theta=0.005,
-            # actor_update_freq=2,
-            actor_update_freq=1,
-            # noise=True,
-            noise=False,  # TODO(pu)
+            actor_update_freq=2,
+            noise=True,
            noise_sigma=0.1,
            noise_range=dict(
                min=-0.5,

@@ -67,18 +43,13 @@ lunarlander_td3vae_config = dict(
            ),
        ),
        collect=dict(
-            # each_iter_n_sample=48,  # 1280
-            n_sample=48,
-            unroll_len=1,  # TODO(pu)
-            # noise_sigma=0.1,
-            noise_sigma=0,  # TODO(pu)
+            n_sample=256,
+            unroll_len=1,
+            noise_sigma=0,  # NOTE: add noise in original action in _forward_collect method of td3_vae policy
            collector=dict(collect_print_freq=1000, ),
        ),
        eval=dict(evaluator=dict(eval_freq=100, ), ),
-        other=dict(replay_buffer=dict(replay_buffer_size=int(2e4), ), ),
-        # other=dict(replay_buffer=dict(replay_buffer_size=int(1e5), ), ),
-        # other=dict(replay_buffer=dict(replay_buffer_size=int(5e5), ), ),
+        other=dict(replay_buffer=dict(replay_buffer_size=int(1e5), ), ),
    ),
)
lunarlander_td3vae_config = EasyDict(lunarlander_td3vae_config)
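The new exp_name packs the chosen setting into suffix tokens; the following decoding is an assumption, reconstructed from the values set in this config and in the policy diff above:

    exp_name_tokens = {
        'lad6': "model.action_shape=6 (latent action dim)",
        'rcs1e4': "random_collect_size=10000",
        'wu1e4': "learn.warm_up_update=int(1e4)",
        'ns256': "collect.n_sample=256",
        'bs128': "learn.batch_size=128",
        'auf2': "learn.actor_update_freq=2",
        'targetnoise': "learn.noise=True (target policy smoothing)",
        'collectoriginalnoise': "exploration noise added to the decoded original action at collect time",
        'rbs1e5': "other.replay_buffer.replay_buffer_size=int(1e5)",
        'rsc': "representation shift correction enabled",
        'lsc': "latent space constraint enabled",
        'rvuc3': "learn.rl_vae_update_circle=3",
        'upcr256': "learn.update_per_collect_rl=256",
        'upcv100': "learn.update_per_collect_vae=100",
        'kw0.01': "VAE loss kld_weight=0.01",
        'pw0.01': "VAE loss predict_weight=0.01",
        'dot': "obs/action embeddings fused by elementwise product",
        'tanh': "z squashed by tanh before decoding and relabeling",
    }
    for token, meaning in exp_name_tokens.items():
        print(f"{token:>22}: {meaning}")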
......
from typing import Any, List, Union, Optional
import time
import gym

@@ -6,6 +7,7 @@ from ding.envs import BaseEnv, BaseEnvTimestep, BaseEnvInfo
from ding.envs.common.env_element import EnvElement, EnvElementInfo
from ding.torch_utils import to_ndarray, to_list
from ding.utils import ENV_REGISTRY
+from ding.envs.common import affine_transform


@ENV_REGISTRY.register('lunarlander')

@@ -14,6 +16,7 @@ class LunarLanderEnv(BaseEnv):
    def __init__(self, cfg: dict) -> None:
        self._cfg = cfg
        self._init_flag = False
+        self._act_scale = cfg.act_scale

    def reset(self) -> np.ndarray:
        if not self._init_flag:

@@ -49,6 +52,8 @@ class LunarLanderEnv(BaseEnv):
        assert isinstance(action, np.ndarray), type(action)
        if action.shape == (1, ):
            action = action.squeeze()  # 0-dim array
+        if self._act_scale:
+            action = affine_transform(action, min_val=-1, max_val=1)
        obs, rew, done, info = self._env.step(action)
        # self._env.render()
        rew = float(rew)
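act_scale maps the policy's [-1, 1] output onto the environment's action bounds before stepping; for LunarLanderContinuous-v2 the bounds are already [-1, 1], so the call is effectively an identity there, but the hook matters for environments with other ranges. A standalone sketch of such an affine rescaling (illustrative, not the exact ding.envs.common.affine_transform implementation):

    import numpy as np

    def affine_scale(action: np.ndarray, min_val: float, max_val: float) -> np.ndarray:
        """Map an action from [-1, 1] onto [min_val, max_val] by a linear (affine) rescaling."""
        action = np.clip(action, -1.0, 1.0)
        return 0.5 * (action + 1.0) * (max_val - min_val) + min_val

    print(affine_scale(np.array([-1.0, 0.0, 1.0]), min_val=-1, max_val=1))   # [-1.  0.  1.]
    print(affine_scale(np.array([-1.0, 0.0, 1.0]), min_val=0, max_val=2))    # [0. 1. 2.]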
......