Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
OpenDILab开源决策智能平台
DI-engine
提交
b93a380b
D
DI-engine
项目概览
OpenDILab开源决策智能平台
/
DI-engine
上一次同步 2 年多
通知
56
Star
321
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
1
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-engine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
1
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
b93a380b
编写于
12月 20, 2021
作者:
P
puyuan1996
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
polish(pu): polish config
上级
ec3a3618
变更
3
隐藏空白更改
内联
并排
Showing
3 changed files
with
26 additions
and
60 deletions
+26
-60
ding/model/template/vae.py
ding/model/template/vae.py
+2
-1
ding/policy/td3_vae.py
ding/policy/td3_vae.py
+22
-16
dizoo/box2d/lunarlander/config/lunarlander_cont_td3_vae_config.py
...x2d/lunarlander/config/lunarlander_cont_td3_vae_config.py
+2
-43
未找到文件。
ding/model/template/vae.py
浏览文件 @
b93a380b
...
...
@@ -204,7 +204,8 @@ class VanillaVAE(BaseVAE):
def
forward
(
self
,
input
:
Tensor
,
**
kwargs
)
->
List
[
Tensor
]:
mu
,
log_var
=
self
.
encode
(
input
)
z
=
self
.
reparameterize
(
mu
,
log_var
)
return
[
self
.
decode
(
z
)[
0
],
self
.
decode
(
z
)[
1
],
input
,
mu
,
log_var
,
z
]
return
[
self
.
decode
(
z
)[
0
],
self
.
decode
(
z
)[
1
],
input
,
mu
,
log_var
,
z
]
# recons_action, prediction_residual
def
loss_function
(
self
,
*
args
,
...
...
ding/policy/td3_vae.py
浏览文件 @
b93a380b
...
...
@@ -231,6 +231,8 @@ class TD3VAEPolicy(DDPGPolicy):
lr
=
self
.
_cfg
.
learn
.
learning_rate_vae
,
)
self
.
_running_mean_std_predict_loss
=
RunningMeanStd
(
epsilon
=
1e-4
)
self
.
c_percentage_bound_lower
=
-
1
*
torch
.
ones
([
6
])
self
.
c_percentage_bound_upper
=
torch
.
ones
([
6
])
def
_forward_learn
(
self
,
data
:
dict
)
->
Dict
[
str
,
Any
]:
r
"""
...
...
@@ -314,8 +316,6 @@ class TD3VAEPolicy(DDPGPolicy):
ignore_done
=
self
.
_cfg
.
learn
.
ignore_done
,
use_nstep
=
False
)
# if (self._forward_learn_cnt + 1) % self._cfg.learn.vae_rl_update_circle in range(10,15):
# if (self._forward_learn_cnt + 1) % self._cfg.learn.vae_update_freq == 0:
if
data
[
'vae_phase'
][
0
].
item
()
is
True
:
# for i in range(self._cfg.learn.vae_train_times_per_update):
if
self
.
_cuda
:
...
...
@@ -335,6 +335,11 @@ class TD3VAEPolicy(DDPGPolicy):
true_residual
=
data
[
'next_obs'
]
-
data
[
'obs'
]
result
=
result
+
[
true_residual
]
# latent space constraint (LSC)
# data['latent_action'] = torch.tanh(result[5].detach()) # TODO(pu): update latent_action z, shape (128,6)
# self.c_percentage_bound_lower = data['latent_action'].sort(dim=0)[0][int(result[0].shape[0] * 0.02), :] # values, indices
# self.c_percentage_bound_upper = data['latent_action'].sort(dim=0)[0][int(result[0].shape[0] * 0.98), :]
vae_loss
=
self
.
_vae_model
.
loss_function
(
*
result
,
kld_weight
=
0.5
,
predict_weight
=
10
)
# TODO(pu):weight
# recons = args[0]
# prediction_residual = args[1]
...
...
@@ -364,8 +369,6 @@ class TD3VAEPolicy(DDPGPolicy):
**
q_value_dict
,
}
# if (self._forward_learn_cnt + 1) % self._cfg.learn.vae_rl_update_circle in range(0,10):
# if data[0]['rl_phase'] is True:
else
:
# ====================
# critic learn forward
...
...
@@ -384,6 +387,8 @@ class TD3VAEPolicy(DDPGPolicy):
{
'action'
:
data
[
'action'
],
'obs'
:
data
[
'obs'
]})
# [self.decode(z)[0], self.decode(z)[1], input, mu, log_var, z]
true_residual
=
data
[
'next_obs'
]
-
data
[
'obs'
]
# Representation shift correction (RSC)
for
i
in
range
(
result
[
1
].
shape
[
0
]):
if
F
.
mse_loss
(
result
[
1
][
i
],
true_residual
[
i
]).
item
()
>
4
*
self
.
_running_mean_std_predict_loss
.
mean
:
data
[
'latent_action'
][
i
]
=
torch
.
tanh
(
result
[
5
][
i
].
detach
())
# TODO(pu): update latent_action z tanh
...
...
@@ -435,20 +440,9 @@ class TD3VAEPolicy(DDPGPolicy):
# ===============================
# actor updates every ``self._actor_update_freq`` iters
if
(
self
.
_forward_learn_cnt
+
1
)
%
self
.
_actor_update_freq
==
0
:
# latent space constraint (LSC)
# result = self._vae_model(
# {'action': data['action'],
# 'obs': data['obs']}) # [self.decode(z)[0], self.decode(z)[1], input, mu, log_var, z]
# data['latent_action'] = torch.tanh(result[5].detach()) # TODO(pu): update latent_action z
# c_percentage_bound_low = data['latent_action'].sort(dim=0)[0][int(128 * 0.02), :]
# c_percentage_bound_upper = data['latent_action'].sort(dim=0)[0][int(128 * 0.98), :]
actor_data
=
self
.
_learn_model
.
forward
(
data
[
'obs'
],
mode
=
'compute_actor'
)
# latent action
# latent space constraint (LSC)
# for i in range(actor_data['action'].shape[-1]):
# # actor_data['action'][:, i] = copy.deepcopy(actor_data['action'][:, i].clamp(c_percentage_bound_low[i].item(), c_percentage_bound_upper[i].item()))
# actor_data['action'][:, i].clamp(c_percentage_bound_low[i].item(),
# c_percentage_bound_upper[i].item())
actor_data
[
'obs'
]
=
data
[
'obs'
]
if
self
.
_twin_critic
:
...
...
@@ -540,6 +534,12 @@ class TD3VAEPolicy(DDPGPolicy):
with
torch
.
no_grad
():
output
=
self
.
_collect_model
.
forward
(
data
,
mode
=
'compute_actor'
,
**
kwargs
)
output
[
'latent_action'
]
=
output
[
'action'
]
# latent space constraint (LSC)
# for i in range(output['action'].shape[-1]):
# output['action'][:, i].clamp_(self.c_percentage_bound_lower[i].item(),
# self.c_percentage_bound_upper[i].item())
# TODO(pu): decode into original hybrid actions, here data is obs
# this is very important for generating self.obs_encoding, which is used in the decode phase
output
[
'action'
]
=
self
.
_vae_model
.
decode_with_obs
(
output
[
'action'
],
data
)[
0
]
...
...
@@ -636,6 +636,12 @@ class TD3VAEPolicy(DDPGPolicy):
with
torch
.
no_grad
():
output
=
self
.
_eval_model
.
forward
(
data
,
mode
=
'compute_actor'
)
output
[
'latent_action'
]
=
output
[
'action'
]
# latent space constraint (LSC)
# for i in range(output['action'].shape[-1]):
# output['action'][:, i].clamp_(self.c_percentage_bound_lower[i].item(),
# self.c_percentage_bound_upper[i].item())
# TODO(pu): decode into original hybrid actions, here data is obs
# this is very important for generating self.obs_encoding, which is used in the decode phase
output
[
'action'
]
=
self
.
_vae_model
.
decode_with_obs
(
output
[
'action'
],
data
)[
0
]
...
...
dizoo/box2d/lunarlander/config/lunarlander_cont_td3_vae_config.py
浏览文件 @
b93a380b
...
...
@@ -2,50 +2,9 @@ from easydict import EasyDict
from
ding.entry
import
serial_pipeline_td3_vae
lunarlander_td3vae_config
=
dict
(
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_zrelabel_eins1280_rvuc10_upcr20_upcv100_noisefalse_rbs1e5', # TODO(pu)
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_murelabel_eins1280_rvuc10_upcr20_upcv100_noisefalse_rbs1e5', # TODO(pu)
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_murelabel_eins48_rvuc1_upcr2_upcv2_noisetrue_rbs2e4', # TODO(pu) lr 3e-4 loss explode 45000iters
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_murelabel_eins48_rvuc20_upcr2_upcv200_noisetrue_rbs2e4', # TODO(pu) lr 3e-4 loss explode 10000iters
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_murelabel_eins48_rvuc1000_upcr2_upcv1000_noisetrue_rbs1e5', # TODO(pu) loss explode 10000iters
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_murelabel_eins48_rvuc1000_upcr2_upcv0_noisetrue_rbs1e5', # TODO(pu) loss explode 3000iters
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_zrelabel_eins48_rvuc1000_upcr2_upcv0_noisetrue_rbs1e5', # TODO(pu) 80000iters eval rew_mean -278
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_norelabel_eins48_rvuc1000_upcr2_upcv0_noisetrue_rbs1e5', # TODO(pu)
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_norelabel_eins48_rvuc1000_upcr2_upcv0_noisetrue_rbs2e4', # TODO(pu)
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_norelabel_eins48_rvuc100_upcr2_upcv0_notargetnoise_nocollectnoise_rbs2e4', # TODO(pu) debug 2m collect rew_max 200
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_norelabel_eins48_rvuc100_upcr2_upcv0_targetnoise_nocollectnoise_rbs2e4', # TODO(pu) 2m collect rew_mean -120 不变
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_norelabel_vaeupdatez_eins48_rvuc100_upcr2_upcv100_noisetrue_rbs2e4', # TODO(pu) 90000iters eval rew_mean -139
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_relabelz_novaeupdatez_eins48_rvuc100_upcr2_upcv100_notargetnoise_nocollectnoise_rbs2e4', # TODO(pu) 2m eval rew_mean -210 best
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_relabelz_novaeupdatez_eins48_rvuc100_upcr20_upcv100_notargetnoise_nocollectnoise_rbs2e4', # TODO(pu) 3m collect rew_mean -254 loss explode
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_relabelz_novaeupdatez_eins48_rvuc100_upcr50_upcv100_notargetnoise_nocollectnoise_rbs2e4', # TODO(pu) 1m collect rew_mean -277 loss explode
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_relabelz_novaeupdatez_eins48_rvuc1000_upcr2_upcv1000_notargetnoise_nocollectnoise_rbs2e4', # TODO(pu) 0.5m eval rew_mean -43 best
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_relabelz_novaeupdatez_eins48_rvuc1000_upcr20_upcv1000_notargetnoise_nocollectnoise_rbs2e4', # TODO(pu) 3m collect rew_mean -256 loss explode
# exp_name='lunarlander_cont_td3_vae_lad6_wu1000_relabelz_novaeupdatez_eins48_rvuc1000_upcr20_upcv1000_notargetnoise_nocollectnoise_rbs1e5', # TODO(pu) 3m collect rew_mean -259
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_relabelz_novaeupdatez_ns48_rvuc10000_upcr2_upcv10000_notargetnoise_collectoriginalnoise_rbs5e5_rsc',
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc100_upcr2_upcv100_notargetnoise_collectoriginalnoise_rbs2e4_rsc', #
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc1000_upcr2_upcv1000_notargetnoise_collectoriginalnoise_rbs2e4_rsc', # TODO(pu) run3 1.5m collect rew_max eval rew_mean
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_relabelz_novaeupdatez_ns48_rvuc1_upcr2_upcv1_notargetnoise_collectoriginalnoise_rbs2e4_rsc', #
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu0_relabelz_novaeupdatez_ns48_rvuc1_upcr2_upcv0_notargetnoise_collectoriginalnoise_rbs2e4_rsc', #
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc100_upcr20_upcv100_notargetnoise_collectoriginalnoise_rbs2e4_rsc', #
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc100_upcr50_upcv100_notargetnoise_collectoriginalnoise_rbs2e4_rsc', #
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc100_upcr50_upcv100_notargetnoise_collectoriginalnoise_rbs1e5_rsc',# TODO(pu) run6 best
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns1280_rvuc100_upcr50_upcv100_notargetnoise_collectoriginalnoise_rbs1e5_rsc', #
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc1000_upcr20_upcv1000_notargetnoise_collectoriginalnoise_rbs2e4_rsc',
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc1000_upcr50_upcv1000_notargetnoise_collectoriginalnoise_rbs1e5_rsc', # run4
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc100_upc2_upcv0_notargetnoise_collectoriginalnoise_rbs2e4_rsc',# TODO(pu) debug
exp_name
=
'lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc100_upc2_upcv100_notargetnoise_collectoriginalnoise_rbs2e4_rsc'
,
# TODO(pu) run3
exp_name
=
'lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc100_upcr2_upcv100_notargetnoise_collectoriginalnoise_rbs2e4_rsc'
,
# TODO(pu) debug
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc100_upcr2_upcv100_notargetnoise_collectoriginalnoise_rbs2e4_rsc_lsc',# TODO(pu)
env
=
dict
(
env_id
=
'LunarLanderContinuous-v2'
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录