OpenDILab Open-Source Decision Intelligence Platform / DI-engine
Commit 18f86f26
Authored Dec 14, 2021 by puyuan1996; committed by niuyazhe on Dec 15, 2021
test(pu): delete noise and change the data for updating vae
Parent 5112584b
Showing 3 changed files with 43 additions and 22 deletions (+43 −22)
ding/entry/serial_entry_td3_vae.py  (+9 −6)
ding/policy/td3_vae.py  (+10 −4)
dizoo/box2d/lunarlander/config/lunarlander_cont_td3_vae_config.py  (+24 −12)
ding/entry/serial_entry_td3_vae.py

@@ -131,10 +131,10 @@ def serial_pipeline_td3_vae(
         replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
         # rl phase
-        # if iter % cfg.policy.learn.rl_vae_update_circle in range(0, 10):
-        if iter % cfg.policy.learn.rl_vae_update_circle in range(0, cfg.policy.learn.rl_vae_update_circle - 1):
+        # if iter % cfg.policy.learn.rl_vae_update_circle in range(0, 20):
+        if iter % cfg.policy.learn.rl_vae_update_circle in range(0, cfg.policy.learn.rl_vae_update_circle):
             # Learn policy from collected data
-            for i in range(cfg.policy.learn.update_per_collect_rl):
+            for i in range(cfg.policy.learn.update_per_collect_rl):  # 2 -> 12
                 # Learner will train ``update_per_collect`` times in one iteration.
                 train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
                 for item in train_data:
@@ -151,11 +151,14 @@ def serial_pipeline_td3_vae(
                 if learner.policy.get_attribute('priority'):
                     replay_buffer.update(learner.priority_info)
         # vae phase
         # if iter % cfg.policy.learn.rl_vae_update_circle in range(10, 11):
-        if iter % cfg.policy.learn.rl_vae_update_circle in range(cfg.policy.learn.rl_vae_update_circle - 1, cfg.policy.learn.rl_vae_update_circle):
-            for i in range(cfg.policy.learn.update_per_collect_vae):
+        # if iter % cfg.policy.learn.rl_vae_update_circle in range(19, 20):
+        # if iter % cfg.policy.learn.rl_vae_update_circle in range(cfg.policy.learn.rl_vae_update_circle - 1, cfg.policy.learn.rl_vae_update_circle):
+        if iter % cfg.policy.learn.rl_vae_update_circle in range(cfg.policy.learn.rl_vae_update_circle - 1, cfg.policy.learn.rl_vae_update_circle):
+            for i in range(cfg.policy.learn.update_per_collect_vae):  # 40
                 # Learner will train ``update_per_collect`` times in one iteration.
                 train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
+                train_data = train_data + new_data  # TODO(pu)
                 for item in train_data:
                     item['rl_phase'] = False
                     item['vae_phase'] = True
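For orientation, a minimal sketch (not from the commit itself) of what the two guards above do: they split every `rl_vae_update_circle` collect iterations into an RL phase and a VAE phase. With the RL guard widened from `range(0, ... - 1)` to `range(0, ...)`, the last iteration of each window appears to run both phases, and the VAE phase now trains on replay samples plus the freshly collected `new_data`. The values below are taken from the lunarlander config in this commit; the variable names are illustrative, not the DI-engine API.

    # Illustrative sketch (not part of the commit): how the rl_vae_update_circle
    # guards partition collect iterations into RL-phase and VAE-phase updates.
    rl_vae_update_circle = 20    # window length, in collect iterations (assumed from the config)
    update_per_collect_rl = 12   # RL gradient steps per RL-phase iteration
    update_per_collect_vae = 40  # VAE gradient steps per VAE-phase iteration

    for iteration in range(40):
        idx = iteration % rl_vae_update_circle
        rl_phase = idx in range(0, rl_vae_update_circle)                          # every iteration
        vae_phase = idx in range(rl_vae_update_circle - 1, rl_vae_update_circle)  # last iteration only
        grad_steps = update_per_collect_rl * rl_phase + update_per_collect_vae * vae_phase
        print(f"iter {iteration:2d}: rl_phase={rl_phase}, vae_phase={vae_phase}, grad_steps={grad_steps}")

Running this prints 12 RL steps for iterations 0-18 of each window and 52 combined steps for iteration 19, consistent with the "train rl 20 iter, vae 1 iter" comment added to the config below.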
ding/policy/td3_vae.py

@@ -219,7 +219,10 @@ class TD3VAEPolicy(DDPGPolicy):
         self._forward_learn_cnt = 0  # count iterations
+        # action_shape, obs_shape, action_latent_dim, hidden_size_list
+        # self._vae_model = VanillaVAE(self._cfg.original_action_shape, self._cfg.model.obs_shape, self._cfg.model.action_shape, [256, 256, 256])
+        # self._vae_model = VanillaVAE(2, 8, 6, [256, 256, 256])
         self._vae_model = VanillaVAE(2, 8, 6, [256, 256, 256])
         # self._vae_model = VanillaVAE(2, 8, 2, [256, 256, 256])
         self._optimizer_vae = Adam(
             self._vae_model.parameters(),
             lr=self._cfg.learn.learning_rate_vae,
@@ -254,7 +257,8 @@ class TD3VAEPolicy(DDPGPolicy):
                 {'action': data['action'], 'obs': data['obs']})
             # [self.decode(z)[0], self.decode(z)[1], input, mu, log_var, z]
-            data['latent_action'] = result[5].detach()  # TODO(pu): update latent_action
+            data['latent_action'] = result[5].detach()  # TODO(pu): update latent_action mu
+            # data['latent_action'] = result[3].detach()  # TODO(pu): update latent_action mu
             result.pop(-1)  # remove z
             result[2] = data['action']
             true_residual = data['next_obs'] - data['obs']
@@ -307,7 +311,7 @@ class TD3VAEPolicy(DDPGPolicy):
         )
         # if (self._forward_learn_cnt + 1) % self._cfg.learn.vae_rl_update_circle in range(10,15):
         # if (self._forward_learn_cnt + 1) % self._cfg.learn.vae_update_freq == 0:
-        if data['vae_phase'][0] is True:
+        if data['vae_phase'][0].item() is True:
             # for i in range(self._cfg.learn.vae_train_times_per_update):
             if self._cuda:
                 data = to_device(data, self._device)
@@ -319,7 +323,8 @@ class TD3VAEPolicy(DDPGPolicy):
                 {'action': data['action'], 'obs': data['obs']})
             # [self.decode(z)[0], self.decode(z)[1], input, mu, log_var, z]
-            data['latent_action'] = result[5].detach()  # TODO(pu): update latent_action
+            data['latent_action'] = result[5].detach()  # TODO(pu): update latent_action z
+            # data['latent_action'] = result[3].detach()  # TODO(pu): update latent_action mu
             result.pop(-1)  # remove z
             result[2] = data['action']
             true_residual = data['next_obs'] - data['obs']
@@ -336,6 +341,7 @@ class TD3VAEPolicy(DDPGPolicy):
             loss_dict['vae_loss'] = vae_loss['loss']
             loss_dict['reconstruction_loss'] = vae_loss['reconstruction_loss']
             loss_dict['kld_loss'] = vae_loss['kld_loss']
+            loss_dict['predict_loss'] = vae_loss['predict_loss']
             # vae update
             self._optimizer_vae.zero_grad()
@@ -597,7 +603,7 @@ class TD3VAEPolicy(DDPGPolicy):
         """
         ret = [
             'cur_lr_actor', 'cur_lr_critic', 'critic_loss', 'actor_loss', 'total_loss', 'q_value', 'q_value_twin',
-            'action', 'td_error', 'vae_loss', 'reconstruction_loss', 'kld_loss'
+            'action', 'td_error', 'vae_loss', 'reconstruction_loss', 'kld_loss', 'predict_loss'
         ]
         if self._twin_critic:
             ret += ['critic_twin_loss']
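The change from `data['vae_phase'][0] is True` to `data['vae_phase'][0].item() is True` in the -307 hunk matters because indexing a bool tensor returns a zero-dimensional torch.Tensor, which is never identical to the Python singleton True, so the old branch could not fire; `.item()` unwraps it to a Python bool first. A minimal sketch of the difference (not from the commit), assuming only that PyTorch is installed:

    # Illustrative sketch (not part of the commit): why the .item() call matters.
    import torch

    vae_phase = torch.tensor([True, True, True])

    print(vae_phase[0] is True)         # False: a 0-dim tensor is never the bool singleton True
    print(vae_phase[0].item() is True)  # True: .item() unwraps it to a Python bool
    print(bool(vae_phase[0]))           # an equivalent, arguably clearer, conversion

Per the comment added in the -219 hunk, the hard-coded `VanillaVAE(2, 8, 6, [256, 256, 256])` arguments correspond to action_shape=2, obs_shape=8, action_latent_dim=6 and the hidden size list, matching the LunarLanderContinuous-v2 config below.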
dizoo/box2d/lunarlander/config/lunarlander_cont_td3_vae_config.py

@@ -2,10 +2,18 @@ from easydict import EasyDict
 from ding.entry import serial_pipeline_td3_vae

 lunarlander_td3vae_config = dict(
-    # exp_name='lunarlander_cont_td3_vae_lad6_wu0_rvuc11_upcv10',
-    exp_name='lunarlander_cont_td3_vae_lad6_wu0_rvuc11_upcv20',
-    # exp_name='lunarlander_cont_td3_vae_lad6_wu0_rvuc21_upcv40',
-    # exp_name='lunarlander_cont_td3_vae_lad6_wu0_rvuc3_upcv4',
+    # exp_name='lunarlander_cont_td3_vae_lad6_wu0_rvuc2_upcv4',  # worse
+    # exp_name='lunarlander_cont_td3_vae_lad6_wu0_rvuc10_upcv20',  # worse
+    # exp_name='lunarlander_cont_td3_vae_lad6_wu0_rvuc20_upcv40',  # TODO(pu)
+    # exp_name='lunarlander_cont_td3_vae_lad6_wu0_rvuc30_upcv60',  # worse
+    # exp_name='lunarlander_cont_td3_vae_lad6_wu1000_mu_rvuc20_upcv40',  # TODO(pu)
+    # exp_name='lunarlander_cont_td3_vae_lad6_wu1000_mu_rvuc20_upcv150',  # TODO(pu) eval reward_mean stays at -132.63
+    # exp_name='lunarlander_cont_td3_vae_lad6_wu1000_mu_rvuc20_upcr12_upcv40',  # TODO(pu)
+    # exp_name='lunarlander_cont_td3_vae_lad6_wu1000_z_rvuc20_upcr12_upcv40',  # TODO(pu)
+    exp_name='lunarlander_cont_td3_vae_lad6_wu1000_z_rvuc20_upcr12_upcv40_noisefalse',  # TODO(pu)
+    # exp_name='lunarlander_cont_td3_vae_lad6_wu1000_z_rvuc20_upcr12_upcv150',  # TODO(pu)
     env=dict(
         env_id='LunarLanderContinuous-v2',
@@ -21,8 +29,8 @@ lunarlander_td3vae_config = dict(
     policy=dict(
         cuda=False,
         priority=False,
-        # random_collect_size=1280,
-        random_collect_size=0,
+        random_collect_size=12800,
+        # random_collect_size=0,
         original_action_shape=2,
         model=dict(
             obs_shape=8,
@@ -31,12 +39,15 @@ lunarlander_td3vae_config = dict(
             actor_head_type='regression',
         ),
         learn=dict(
-            warm_up_update=0,
-            # warm_up_update=100,
-            rl_vae_update_circle=11,  # train rl 10 iter, vae 1 iter
+            # warm_up_update=0,
+            warm_up_update=1000,
             # vae_train_times_per_update=1,  # TODO(pu)
-            update_per_collect_rl=2,
-            update_per_collect_vae=20,
+            rl_vae_update_circle=20,  # train rl 20 iter, vae 1 iter
+            # update_per_collect_rl=2,
+            update_per_collect_rl=12,
+            update_per_collect_vae=40,
+            # update_per_collect_vae=150,
             batch_size=128,
             learning_rate_actor=3e-4,
@@ -44,7 +55,8 @@ lunarlander_td3vae_config = dict(
             learning_rate_vae=3e-4,
             ignore_done=False,  # TODO(pu)
             actor_update_freq=2,
-            noise=True,
+            # noise=True,
+            noise=False,  # TODO(pu)
             noise_sigma=0.1,
             noise_range=dict(
                 min=-0.5,
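Taken together, the new values mean roughly: 12800 randomly collected transitions and 1000 warm-up updates (the `wu1000` tag in the exp_name) before the main loop, then windows of 20 collect iterations with 12 RL updates each and 40 VAE updates on the final iteration, with `noise=False` in the learn config. A rough back-of-envelope sketch (not from the commit) of the per-window update budget under those assumptions, in plain Python with illustrative names:

    # Illustrative sketch (not part of the commit): per-window update budget implied
    # by the new lunarlander config, under the schedule in serial_entry_td3_vae.py.
    # The values come from the diff above; everything else here is an assumption.
    warm_up_update = 1000        # warm-up updates before the main loop
    rl_vae_update_circle = 20    # collect iterations per window
    update_per_collect_rl = 12   # RL updates per collect iteration
    update_per_collect_vae = 40  # VAE updates on the final iteration of the window
    batch_size = 128

    rl_updates_per_window = rl_vae_update_circle * update_per_collect_rl   # 240
    vae_updates_per_window = update_per_collect_vae                        # 40
    # Rough sample throughput, ignoring the extra new_data appended in the VAE phase.
    samples_per_window = (rl_updates_per_window + vae_updates_per_window) * batch_size  # 35840
    print(warm_up_update, rl_updates_per_window, vae_updates_per_window, samples_per_window)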