提交 92129676 编写于 作者: P puyuan1996

feature(pu): representation shift correction for each transition

上级 2cfc411e
......@@ -384,12 +384,13 @@ class TD3VAEPolicy(DDPGPolicy):
{'action': data['action'],
'obs': data['obs']}) # [self.decode(z)[0], self.decode(z)[1], input, mu, log_var, z]
# if result[1].detach()
data['latent_action'] = result[5].detach() # TODO(pu): update latent_action z
# data['latent_action'] = result[5].detach() # TODO(pu): update latent_action z
# data['latent_action'] = result[3].detach() # TODO(pu): update latent_action mu
true_residual = data['next_obs'] - data['obs']
if F.mse_loss(result[1], true_residual).item() > 4 * self._running_mean_std_predict_loss.mean:
data['latent_action'] = result[5].detach() # TODO(pu): update latent_action z
# data['latent_action'] = result[3].detach() # TODO(pu): update latent_action mu
for i in range(result[1].shape[0]):
if F.mse_loss(result[1][i], true_residual[i]).item() > 4 * self._running_mean_std_predict_loss.mean:
data['latent_action'][i] = result[5][i].detach() # TODO(pu): update latent_action z
# data['latent_action'] = result[3].detach() # TODO(pu): update latent_action mu
if self._reward_batch_norm:
reward = (reward - reward.mean()) / (reward.std() + 1e-8)
......
......@@ -30,10 +30,13 @@ lunarlander_td3vae_config = dict(
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_relabelz_novaeupdatez_ns48_rvuc10000_upcr2_upcv10000_notargetnoise_collectoriginalnoise_rbs5e5_rsc',
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc100_upcr2_upcv100_notargetnoise_collectoriginalnoise_rbs2e4_rsc', # run4
exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc1000_upcr2_upcv1000_notargetnoise_collectoriginalnoise_rbs2e4_rsc', # TODO(pu) run3 1.5m collect rew_max eval rew_mean
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_relabelz_novaeupdatez_ns48_rvuc1_upcr2_upcv1_notargetnoise_collectoriginalnoise_rbs2e4_rsc', # run2
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu0_relabelz_novaeupdatez_ns48_rvuc1_upcr2_upcv0_notargetnoise_collectoriginalnoise_rbs2e4_rsc', # run6
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc100_upcr2_upcv100_notargetnoise_collectoriginalnoise_rbs2e4_rsc', #
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc1000_upcr2_upcv1000_notargetnoise_collectoriginalnoise_rbs2e4_rsc', # TODO(pu) run3 1.5m collect rew_max eval rew_mean
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_relabelz_novaeupdatez_ns48_rvuc1_upcr2_upcv1_notargetnoise_collectoriginalnoise_rbs2e4_rsc', #
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu0_relabelz_novaeupdatez_ns48_rvuc1_upcr2_upcv0_notargetnoise_collectoriginalnoise_rbs2e4_rsc', #
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc100_upcr20_upcv100_notargetnoise_collectoriginalnoise_rbs2e4_rsc', # run2
exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc1000_upcr20_upcv1000_notargetnoise_collectoriginalnoise_rbs2e4_rsc', # run3
env=dict(
env_id='LunarLanderContinuous-v2',
......@@ -70,8 +73,8 @@ lunarlander_td3vae_config = dict(
# rl_vae_update_circle=1,
# update_per_collect_rl=50,
# update_per_collect_rl=20,
update_per_collect_rl=2,
update_per_collect_rl=20,
# update_per_collect_rl=2,
update_per_collect_vae=1000, # each mini-batch: replay_buffer_recent sample 128, replay_buffer sample 128
# update_per_collect_vae=20,
......
......@@ -7,7 +7,7 @@ evaluator_env_num = 8
special_global_state = True
SMAC_5m6m_masac_default_config = dict(
exp_name='smac_5m6m_masac_alpha_learn_rate_4',
exp_name='debug_smac_5m6m_masac_d5e4',
env=dict(
map_name='5m_vs_6m',
difficulty=7,
......@@ -27,7 +27,8 @@ SMAC_5m6m_masac_default_config = dict(
),
policy=dict(
cuda=True,
random_collect_size=0,
# random_collect_size=0,
random_collect_size=int(1e4),
model=dict(
agent_obs_shape=72,
global_obs_shape=152,
......@@ -63,10 +64,9 @@ SMAC_5m6m_masac_default_config = dict(
type='linear',
start=1,
end=0.05,
decay=50000,
decay=int(5e4),
),
replay_buffer=dict(replay_buffer_size=1000000, ),
),
replay_buffer=dict(replay_buffer_size=int(1e6), ), ),
),
)
......@@ -84,5 +84,22 @@ SMAC_5m6m_masac_default_create_config = dict(
SMAC_5m6m_masac_default_create_config = EasyDict(SMAC_5m6m_masac_default_create_config)
create_config = SMAC_5m6m_masac_default_create_config
# if __name__ == "__main__":
# serial_pipeline([main_config, create_config], seed=0)
def train(args):
    """Launch one serial-pipeline run using the seed carried by ``args``.

    The module-level ``main_config.exp_name`` is overwritten with a name that
    embeds ``args.seed``; ``serial_pipeline`` is then called with deep copies
    of ``main_config`` and ``create_config`` and ``seed=args.seed``.

    Args:
        args: Object exposing a ``seed`` attribute (e.g. an argparse
            namespace).
    """
    import copy

    # Name pieces are kept verbatim; note the joined result contains a
    # double underscore between "masac_" and "_seed".
    name_parts = ('debug_smac_5m6m_masac_', '_seed', f'{args.seed}', '_rcs1e4')
    main_config.exp_name = ''.join(name_parts)

    # Presumably copied so the pipeline cannot mutate the module-level
    # configs between runs -- confirm against serial_pipeline.
    run_main_cfg = copy.deepcopy(main_config)
    run_create_cfg = copy.deepcopy(create_config)
    serial_pipeline([run_main_cfg, run_create_cfg], seed=args.seed)
# Script entry point: reads --seed / -s from the command line and calls train().
if __name__ == "__main__":
# NOTE(review): this direct seed=0 call sits next to the argparse-driven path
# below; in this diff view it looks like the pre-change body -- confirm whether
# it is still present in the file.
serial_pipeline([main_config, create_config], seed=0)
import argparse
# The seed list has a single entry; its value is used as the default for --seed.
for seed in [1]:
parser = argparse.ArgumentParser()
parser.add_argument('--seed', '-s', type=int, default=seed)
# parse_args() reads sys.argv, so a seed given on the command line overrides
# the default above.
args = parser.parse_args()
train(args)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册