Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
OpenDILab开源决策智能平台
DI-engine
提交
92129676
D
DI-engine
项目概览
OpenDILab开源决策智能平台
/
DI-engine
上一次同步 2 年多
通知
56
Star
321
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
1
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-engine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
1
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
92129676
编写于
12月 20, 2021
作者:
P
puyuan1996
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feature(pu): representaion shift correction for each transition
上级
2cfc411e
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
37 addition
and
16 deletion
+37
-16
ding/policy/td3_vae.py
ding/policy/td3_vae.py
+5
-4
dizoo/box2d/lunarlander/config/lunarlander_cont_td3_vae_config.py
...x2d/lunarlander/config/lunarlander_cont_td3_vae_config.py
+9
-6
dizoo/smac/config/smac_5m6m_masac_config.py
dizoo/smac/config/smac_5m6m_masac_config.py
+23
-6
未找到文件。
ding/policy/td3_vae.py
浏览文件 @
92129676
...
...
@@ -384,12 +384,13 @@ class TD3VAEPolicy(DDPGPolicy):
{
'action'
:
data
[
'action'
],
'obs'
:
data
[
'obs'
]})
# [self.decode(z)[0], self.decode(z)[1], input, mu, log_var, z]
# if result[1].detach()
data
[
'latent_action'
]
=
result
[
5
].
detach
()
# TODO(pu): update latent_action z
#
data['latent_action'] = result[5].detach() # TODO(pu): update latent_action z
# data['latent_action'] = result[3].detach() # TODO(pu): update latent_action mu
true_residual
=
data
[
'next_obs'
]
-
data
[
'obs'
]
if
F
.
mse_loss
(
result
[
1
],
true_residual
).
item
()
>
4
*
self
.
_running_mean_std_predict_loss
.
mean
:
data
[
'latent_action'
]
=
result
[
5
].
detach
()
# TODO(pu): update latent_action z
# data['latent_action'] = result[3].detach() # TODO(pu): update latent_action mu
for
i
in
range
(
result
[
1
].
shape
[
0
]):
if
F
.
mse_loss
(
result
[
1
][
i
],
true_residual
[
i
]).
item
()
>
4
*
self
.
_running_mean_std_predict_loss
.
mean
:
data
[
'latent_action'
][
i
]
=
result
[
5
][
i
].
detach
()
# TODO(pu): update latent_action z
# data['latent_action'] = result[3].detach() # TODO(pu): update latent_action mu
if
self
.
_reward_batch_norm
:
reward
=
(
reward
-
reward
.
mean
())
/
(
reward
.
std
()
+
1e-8
)
...
...
dizoo/box2d/lunarlander/config/lunarlander_cont_td3_vae_config.py
浏览文件 @
92129676
...
...
@@ -30,10 +30,13 @@ lunarlander_td3vae_config = dict(
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_relabelz_novaeupdatez_ns48_rvuc10000_upcr2_upcv10000_notargetnoise_collectoriginalnoise_rbs5e5_rsc',
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc100_upcr2_upcv100_notargetnoise_collectoriginalnoise_rbs2e4_rsc', # run4
exp_name
=
'lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc1000_upcr2_upcv1000_notargetnoise_collectoriginalnoise_rbs2e4_rsc'
,
# TODO(pu) run3 1.5m collect rew_max eval rew_mean
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_relabelz_novaeupdatez_ns48_rvuc1_upcr2_upcv1_notargetnoise_collectoriginalnoise_rbs2e4_rsc', # run2
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu0_relabelz_novaeupdatez_ns48_rvuc1_upcr2_upcv0_notargetnoise_collectoriginalnoise_rbs2e4_rsc', # run6
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc100_upcr2_upcv100_notargetnoise_collectoriginalnoise_rbs2e4_rsc', #
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc1000_upcr2_upcv1000_notargetnoise_collectoriginalnoise_rbs2e4_rsc', # TODO(pu) run3 1.5m collect rew_max eval rew_mean
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_relabelz_novaeupdatez_ns48_rvuc1_upcr2_upcv1_notargetnoise_collectoriginalnoise_rbs2e4_rsc', #
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu0_relabelz_novaeupdatez_ns48_rvuc1_upcr2_upcv0_notargetnoise_collectoriginalnoise_rbs2e4_rsc', #
# exp_name='lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc100_upcr20_upcv100_notargetnoise_collectoriginalnoise_rbs2e4_rsc', # run2
exp_name
=
'lunarlander_cont_ddpg_vae_lad6_wu1000_rlabelz_novaeupdatez_ns48_rvuc1000_upcr20_upcv1000_notargetnoise_collectoriginalnoise_rbs2e4_rsc'
,
# run3
env
=
dict
(
env_id
=
'LunarLanderContinuous-v2'
,
...
...
@@ -70,8 +73,8 @@ lunarlander_td3vae_config = dict(
# rl_vae_update_circle=1,
# update_per_collect_rl=50,
#
update_per_collect_rl=20,
update_per_collect_rl
=
2
,
update_per_collect_rl
=
20
,
#
update_per_collect_rl=2,
update_per_collect_vae
=
1000
,
# each mini-batch: replay_buffer_recent sample 128, replay_buffer sample 128
# update_per_collect_vae=20,
...
...
dizoo/smac/config/smac_5m6m_masac_config.py
浏览文件 @
92129676
...
...
@@ -7,7 +7,7 @@ evaluator_env_num = 8
special_global_state
=
True
SMAC_5m6m_masac_default_config
=
dict
(
exp_name
=
'
smac_5m6m_masac_alpha_learn_rate_
4'
,
exp_name
=
'
debug_smac_5m6m_masac_d5e
4'
,
env
=
dict
(
map_name
=
'5m_vs_6m'
,
difficulty
=
7
,
...
...
@@ -27,7 +27,8 @@ SMAC_5m6m_masac_default_config = dict(
),
policy
=
dict
(
cuda
=
True
,
random_collect_size
=
0
,
# random_collect_size=0,
random_collect_size
=
int
(
1e4
),
model
=
dict
(
agent_obs_shape
=
72
,
global_obs_shape
=
152
,
...
...
@@ -63,10 +64,9 @@ SMAC_5m6m_masac_default_config = dict(
type
=
'linear'
,
start
=
1
,
end
=
0.05
,
decay
=
50000
,
decay
=
int
(
5e4
)
,
),
replay_buffer
=
dict
(
replay_buffer_size
=
1000000
,
),
),
replay_buffer
=
dict
(
replay_buffer_size
=
int
(
1e6
),
),
),
),
)
...
...
@@ -84,5 +84,22 @@ SMAC_5m6m_masac_default_create_config = dict(
SMAC_5m6m_masac_default_create_config
=
EasyDict
(
SMAC_5m6m_masac_default_create_config
)
create_config
=
SMAC_5m6m_masac_default_create_config
# if __name__ == "__main__":
# serial_pipeline([main_config, create_config], seed=0)
def
train
(
args
):
main_config
.
exp_name
=
'debug_smac_5m6m_masac_'
+
'_seed'
+
f
'
{
args
.
seed
}
'
+
'_rcs1e4'
# serial_pipeline([main_config, create_config], seed=args.seed)
import
copy
serial_pipeline
([
copy
.
deepcopy
(
main_config
),
copy
.
deepcopy
(
create_config
)],
seed
=
args
.
seed
)
if
__name__
==
"__main__"
:
serial_pipeline
([
main_config
,
create_config
],
seed
=
0
)
import
argparse
for
seed
in
[
1
]:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--seed'
,
'-s'
,
type
=
int
,
default
=
seed
)
args
=
parser
.
parse_args
()
train
(
args
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录