Commit 4fca50f4 (OpenDILab / DI-engine)
Authored by puyuan1996 on Dec 13, 2021; committed by niuyazhe on Dec 15, 2021

polish(pu): vae and rl update alternately

Parent: 4d240686
Showing 3 changed files with 155 additions and 128 deletions (+155, -128)

ding/model/template/vae.py                                            +1    -1
ding/policy/td3_vae.py                                                +138  -120
dizoo/box2d/lunarlander/config/lunarlander_cont_td3_vae_config.py     +16   -7
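For orientation before the diffs: the alternation between VAE and RL training named in the commit message is driven by the iteration counter together with the vae_update_freq and rl_update_freq fields added to the config below. A minimal standalone sketch of that schedule, assuming both frequencies are 10 as in the LunarLander config of this commit (scheduled_phase is a hypothetical helper, not part of the commit):

# Hypothetical helper illustrating the schedule used in TD3VAEPolicy._forward_learn:
# (cnt + 1) % vae_update_freq in [0, 1, 2, 3, 4] -> train the VAE this iteration
# (cnt + 1) % rl_update_freq  in [5, 6, 7, 8, 9] -> run the TD3 (RL) update instead
def scheduled_phase(cnt: int, vae_update_freq: int = 10, rl_update_freq: int = 10) -> str:
    if (cnt + 1) % vae_update_freq in [0, 1, 2, 3, 4]:
        return 'vae'
    if (cnt + 1) % rl_update_freq in [5, 6, 7, 8, 9]:
        return 'rl'
    return 'none'

# With the defaults above, ten consecutive iterations give
# ['vae', 'vae', 'vae', 'vae', 'rl', 'rl', 'rl', 'rl', 'rl', 'vae'],
# i.e. blocks of VAE updates alternating with blocks of RL updates.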
ding/model/template/vae.py

@@ -46,9 +46,9 @@ class VanillaVAE(BaseVAE):
            **kwargs
    ) -> None:
        super(VanillaVAE, self).__init__()
        self.latent_dim = latent_dim
        self.action_dim = in_channels_1
        self.obs_dim = in_channels_2
        self.latent_dim = latent_dim
        self.hidden_dims = hidden_dims

        modules = []
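As a hedged reading of this hunk (the full constructor signature is not shown in the diff), the first two positional arguments of VanillaVAE are treated as the action and observation dimensions, which is how the hard-coded call in ding/policy/td3_vae.py below passes them:

# Illustrative only: the values (2, 8, 64, [256, 256, 256]) come from the call this
# commit replaces; the parameter names and import path are taken from the diff and
# file path, and the exact keyword order is an assumption.
from ding.model.template.vae import VanillaVAE

vae = VanillaVAE(
    2,                # in_channels_1 -> self.action_dim (env action dimension)
    8,                # in_channels_2 -> self.obs_dim (observation dimension)
    64,               # latent_dim    -> self.latent_dim (latent action dimension)
    [256, 256, 256],  # hidden_dims   -> self.hidden_dims
)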
ding/policy/td3_vae.py

@@ -217,14 +217,14 @@ class TD3VAEPolicy(DDPGPolicy):
        self._target_model.reset()

        self._forward_learn_cnt = 0  # count iterations
        self._vae_model = VanillaVAE(2, 8, 64, [256, 256, 256])  # action_shape, latent_dim, hidden_size_list
        # action_shape, obs_shape, action_latent_dim, hidden_size_list
        self._vae_model = VanillaVAE(
            self._cfg.original_action_shape, self._cfg.model.obs_shape, self._cfg.model.action_shape, [256, 256, 256]
        )
        # self._vae_model = VanillaVAE(2, 8, 2, [256, 256, 256])
        self._optimizer_vae = Adam(
            self._vae_model.parameters(),
            lr=self._cfg.learn.learning_rate_vae,
        )
        # self.vae_model = VanillaVAE(self._cfg.original_action_shape, self._cfg.obs_shape, self._cfg.model.action_shape, [256, 256])
        # action_shape, self.state_dim latent_dim, hidden_size_list

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        r"""
        ...
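The hard-coded VanillaVAE(2, 8, 64, [256, 256, 256]) corresponds to LunarLanderContinuous-v2 (action dim 2, obs dim 8); the rewritten call reads the same quantities from the policy config instead. A minimal sketch of the config fields this code assumes, with values taken from the LunarLander config changed in this same commit, except learning_rate_vae, whose value is not shown in the diff and is illustrative only:

# Fields referenced by the VanillaVAE construction and the VAE Adam optimizer above.
policy_cfg_sketch = dict(
    original_action_shape=2,     # real environment action dimension
    model=dict(
        obs_shape=8,             # observation dimension
        action_shape=2,          # latent action dimension (64 before this commit, 2 after)
    ),
    learn=dict(
        learning_rate_vae=1e-4,  # illustrative value; only the key name appears in the code
    ),
)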
@@ -235,6 +235,7 @@ class TD3VAEPolicy(DDPGPolicy):
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Including at least actor and critic lr, different losses.
        """
        # warmup phase
        if 'warm_up' in data[0].keys() and data[0]['warm_up'] is True:
            loss_dict = {}
            data = default_preprocess_learn(
                ...
@@ -295,7 +296,9 @@ class TD3VAEPolicy(DDPGPolicy):
                **q_value_dict,
            }
        else:
            self._forward_learn_cnt += 1
            loss_dict = {}
            q_value_dict = {}
            data = default_preprocess_learn(
                data,
                use_priority=self._cfg.priority,
                ...
@@ -303,124 +306,139 @@ class TD3VAEPolicy(DDPGPolicy):
                ignore_done=self._cfg.learn.ignore_done,
                use_nstep=False
            )
            if self._cuda:
                data = to_device(data, self._device)

            # ====================
            # train vae
            # ====================
            result = self._vae_model({'action': data['action'], 'obs': data['obs']})
            # [self.decode(z)[0], self.decode(z)[1], input, mu, log_var, z]
            data['latent_action'] = result[5].detach()  # TODO(pu): update latent_action
            result.pop(-1)  # remove z
            result[2] = data['action']
            true_residual = data['next_obs'] - data['obs']
            result = result + [true_residual]
            vae_loss = self._vae_model.loss_function(*result, kld_weight=0.5, predict_weight=10)  # TODO(pu): weight
            # recons = args[0]
            # prediction_residual = args[1]
            # input_action = args[2]
            # mu = args[3]
            # log_var = args[4]
            # true_residual = args[5]
            # print(vae_loss)
            loss_dict['vae_loss'] = vae_loss['loss']
            loss_dict['reconstruction_loss'] = vae_loss['reconstruction_loss']
            loss_dict['kld_loss'] = vae_loss['kld_loss']
            # vae update
            self._optimizer_vae.zero_grad()
            vae_loss['loss'].backward()
            self._optimizer_vae.step()

            # ====================
            # critic learn forward
            # ====================
            self._learn_model.train()
            self._target_model.train()
            next_obs = data['next_obs']
            reward = data['reward']
            if self._reward_batch_norm:
                reward = (reward - reward.mean()) / (reward.std() + 1e-8)
            # current q value
            q_value = self._learn_model.forward(
                {'obs': data['obs'], 'action': data['latent_action']}, mode='compute_critic'
            )['q_value']
            q_value_dict = {}
            if self._twin_critic:
                q_value_dict['q_value'] = q_value[0].mean()
                q_value_dict['q_value_twin'] = q_value[1].mean()
            else:
                q_value_dict['q_value'] = q_value.mean()
            # target q value.
            with torch.no_grad():
                next_actor_data = self._target_model.forward(next_obs, mode='compute_actor')  # latent action
                next_actor_data['obs'] = next_obs
                target_q_value = self._target_model.forward(next_actor_data, mode='compute_critic')['q_value']
            if self._twin_critic:
                # TD3: two critic networks, find min one as target q value
                target_q_value = torch.min(target_q_value[0], target_q_value[1])
                # critic network1
                td_data = v_1step_td_data(q_value[0], target_q_value, reward, data['done'], data['weight'])
                critic_loss, td_error_per_sample1 = v_1step_td_error(td_data, self._gamma)
                loss_dict['critic_loss'] = critic_loss
                # critic network2 (twin network)
                td_data_twin = v_1step_td_data(q_value[1], target_q_value, reward, data['done'], data['weight'])
                critic_twin_loss, td_error_per_sample2 = v_1step_td_error(td_data_twin, self._gamma)
                loss_dict['critic_twin_loss'] = critic_twin_loss
                td_error_per_sample = (td_error_per_sample1 + td_error_per_sample2) / 2
            else:
                # DDPG: single critic network
                td_data = v_1step_td_data(q_value, target_q_value, reward, data['done'], data['weight'])
                critic_loss, td_error_per_sample = v_1step_td_error(td_data, self._gamma)
                loss_dict['critic_loss'] = critic_loss
            # ================
            # critic update
            # ================
            self._optimizer_critic.zero_grad()
            for k in loss_dict:
                if 'critic' in k:
                    loss_dict[k].backward()
            self._optimizer_critic.step()
            # ===============================
            # actor learn forward and update
            # ===============================
            # actor updates every ``self._actor_update_freq`` iters
            if (self._forward_learn_cnt + 1) % self._actor_update_freq == 0:
                actor_data = self._learn_model.forward(data['obs'], mode='compute_actor')  # latent action
                actor_data['obs'] = data['obs']
            if (self._forward_learn_cnt + 1) % self._cfg.learn.vae_update_freq in [0, 1, 2, 3, 4]:
                for i in range(self._cfg.learn.train_vae_times_per_update):
                    if self._cuda:
                        data = to_device(data, self._device)
                    # ====================
                    # train vae
                    # ====================
                    result = self._vae_model({'action': data['action'], 'obs': data['obs']})
                    # [self.decode(z)[0], self.decode(z)[1], input, mu, log_var, z]
                    data['latent_action'] = result[5].detach()  # TODO(pu): update latent_action
                    result.pop(-1)  # remove z
                    result[2] = data['action']
                    true_residual = data['next_obs'] - data['obs']
                    result = result + [true_residual]
                    vae_loss = self._vae_model.loss_function(*result, kld_weight=0.5, predict_weight=10)  # TODO(pu): weight
                    # recons = args[0]
                    # prediction_residual = args[1]
                    # input_action = args[2]
                    # mu = args[3]
                    # log_var = args[4]
                    # true_residual = args[5]
                    # print(vae_loss)
                    loss_dict['vae_loss'] = vae_loss['loss']
                    loss_dict['reconstruction_loss'] = vae_loss['reconstruction_loss']
                    loss_dict['kld_loss'] = vae_loss['kld_loss']
                    # vae update
                    self._optimizer_vae.zero_grad()
                    vae_loss['loss'].backward()
                    self._optimizer_vae.step()
                return {
                    'cur_lr_actor': self._optimizer_actor.defaults['lr'],
                    'cur_lr_critic': self._optimizer_critic.defaults['lr'],
                    # 'q_value': np.array(q_value).mean(),
                    'action': torch.Tensor([0]).item(),
                    'priority': torch.Tensor([0]).item(),
                    'td_error': torch.Tensor([0]).item(),
                    **loss_dict,
                    **q_value_dict,
                }
            if (self._forward_learn_cnt + 1) % self._cfg.learn.rl_update_freq in [5, 6, 7, 8, 9]:
                # ====================
                # critic learn forward
                # ====================
                self._learn_model.train()
                self._target_model.train()
                next_obs = data['next_obs']
                reward = data['reward']
                if self._reward_batch_norm:
                    reward = (reward - reward.mean()) / (reward.std() + 1e-8)
                # current q value
                q_value = self._learn_model.forward(
                    {'obs': data['obs'], 'action': data['latent_action']}, mode='compute_critic'
                )['q_value']
                q_value_dict = {}
                if self._twin_critic:
                    actor_loss = -self._learn_model.forward(actor_data, mode='compute_critic')['q_value'][0].mean()
                    q_value_dict['q_value'] = q_value[0].mean()
                    q_value_dict['q_value_twin'] = q_value[1].mean()
                else:
                    actor_loss = -self._learn_model.forward(actor_data, mode='compute_critic')['q_value'].mean()
                loss_dict['actor_loss'] = actor_loss
                # actor update
                self._optimizer_actor.zero_grad()
                actor_loss.backward()
                self._optimizer_actor.step()
                # =============
                # after update
                # =============
                loss_dict['total_loss'] = sum(loss_dict.values())
                self._forward_learn_cnt += 1
                self._target_model.update(self._learn_model.state_dict())
                if self._cfg.action_space == 'hybrid':
                    action_log_value = -1.  # TODO(nyz) better way to viz hybrid action
                else:
                    action_log_value = data['action'].mean()
                return {
                    'cur_lr_actor': self._optimizer_actor.defaults['lr'],
                    'cur_lr_critic': self._optimizer_critic.defaults['lr'],
                    # 'q_value': np.array(q_value).mean(),
                    'action': action_log_value,
                    'priority': td_error_per_sample.abs().tolist(),
                    'td_error': td_error_per_sample.abs().mean(),
                    **loss_dict,
                    **q_value_dict,
                }
            q_value_dict['q_value'] = q_value.mean()
            # target q value.
            with torch.no_grad():
                next_actor_data = self._target_model.forward(next_obs, mode='compute_actor')  # latent action
                next_actor_data['obs'] = next_obs
                target_q_value = self._target_model.forward(next_actor_data, mode='compute_critic')['q_value']
            if self._twin_critic:
                # TD3: two critic networks, find min one as target q value
                target_q_value = torch.min(target_q_value[0], target_q_value[1])
                # critic network1
                td_data = v_1step_td_data(q_value[0], target_q_value, reward, data['done'], data['weight'])
                critic_loss, td_error_per_sample1 = v_1step_td_error(td_data, self._gamma)
                loss_dict['critic_loss'] = critic_loss
                # critic network2 (twin network)
                td_data_twin = v_1step_td_data(q_value[1], target_q_value, reward, data['done'], data['weight'])
                critic_twin_loss, td_error_per_sample2 = v_1step_td_error(td_data_twin, self._gamma)
                loss_dict['critic_twin_loss'] = critic_twin_loss
                td_error_per_sample = (td_error_per_sample1 + td_error_per_sample2) / 2
            else:
                # DDPG: single critic network
                td_data = v_1step_td_data(q_value, target_q_value, reward, data['done'], data['weight'])
                critic_loss, td_error_per_sample = v_1step_td_error(td_data, self._gamma)
                loss_dict['critic_loss'] = critic_loss
            # ================
            # critic update
            # ================
            self._optimizer_critic.zero_grad()
            for k in loss_dict:
                if 'critic' in k:
                    loss_dict[k].backward()
            self._optimizer_critic.step()
            # ===============================
            # actor learn forward and update
            # ===============================
            # actor updates every ``self._actor_update_freq`` iters
            if (self._forward_learn_cnt + 1) % self._actor_update_freq == 0:
                actor_data = self._learn_model.forward(data['obs'], mode='compute_actor')  # latent action
                actor_data['obs'] = data['obs']
                if self._twin_critic:
                    actor_loss = -self._learn_model.forward(actor_data, mode='compute_critic')['q_value'][0].mean()
                else:
                    actor_loss = -self._learn_model.forward(actor_data, mode='compute_critic')['q_value'].mean()
                loss_dict['actor_loss'] = actor_loss
                # actor update
                self._optimizer_actor.zero_grad()
                actor_loss.backward()
                self._optimizer_actor.step()
            # =============
            # after update
            # =============
            loss_dict['total_loss'] = sum(loss_dict.values())
            # self._forward_learn_cnt += 1
            self._target_model.update(self._learn_model.state_dict())
            if self._cfg.action_space == 'hybrid':
                action_log_value = -1.  # TODO(nyz) better way to viz hybrid action
            else:
                action_log_value = data['action'].mean()
            return {
                'cur_lr_actor': self._optimizer_actor.defaults['lr'],
                'cur_lr_critic': self._optimizer_critic.defaults['lr'],
                # 'q_value': np.array(q_value).mean(),
                'action': action_log_value,
                'priority': td_error_per_sample.abs().tolist(),
                'td_error': td_error_per_sample.abs().mean(),
                **loss_dict,
                **q_value_dict,
            }

    def _state_dict_learn(self) -> Dict[str, Any]:
        return {
            ...
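For readers without the VAE source at hand, here is a hedged sketch of what loss_function appears to compute, inferred from the argument-order comments and the kld_weight / predict_weight keywords above; the real implementation lives in ding/model/template/vae.py and may differ, and the 'predict_loss' key in particular is an assumption:

# Assumed shape of the loss returned by self._vae_model.loss_function(*result, ...).
import torch
import torch.nn.functional as F

def vae_loss_sketch(recons, prediction_residual, input_action, mu, log_var, true_residual,
                    kld_weight=0.5, predict_weight=10):
    # action reconstruction term
    recons_loss = F.mse_loss(recons, input_action)
    # predicted vs. true observation residual (next_obs - obs)
    predict_loss = F.mse_loss(prediction_residual, true_residual)
    # standard Gaussian KL term of a VAE
    kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
    loss = recons_loss + kld_weight * kld_loss + predict_weight * predict_loss
    return {
        'loss': loss,
        'reconstruction_loss': recons_loss,
        'kld_loss': kld_loss,
        'predict_loss': predict_loss,  # assumption: not read by the policy code above
    }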
dizoo/box2d/lunarlander/config/lunarlander_cont_td3_vae_config.py

@@ -2,7 +2,8 @@ from easydict import EasyDict
from ding.entry import serial_pipeline_td3_vae

lunarlander_td3vae_config = dict(
    exp_name='lunarlander_cont_td3_vae',
    # exp_name='lunarlander_cont_td3_vae_wu0_vae5rl5_tvtpc1',
    exp_name='lunarlander_cont_td3_vae_wu0_vae5rl5_tvtpc5',
    env=dict(
        env_id='LunarLanderContinuous-v2',
        # collector_env_num=8,
        ...

@@ -17,18 +18,24 @@ lunarlander_td3vae_config = dict(
    policy=dict(
        cuda=False,
        priority=False,
        random_collect_size=12800,
        # random_collect_size=1280,
        random_collect_size=0,
        original_action_shape=2,
        model=dict(
            obs_shape=8,
            action_shape=64,  # latent_action_shape
            action_shape=2,  # 64, # action_latent_shape
            twin_critic=True,
            actor_head_type='regression',
        ),
        learn=dict(
            # warm_up_update=1,
            warm_up_update=1000,
            update_per_collect=2,
            warm_up_update=0,
            # warm_up_update=100,
            vae_update_freq=10,  # TODO(pu)
            rl_update_freq=10,
            train_vae_times_per_update=5,  # TODO(pu)
            update_per_collect=10,  # train vae 5 times, rl 5 times
            batch_size=128,
            learning_rate_actor=0.001,
            learning_rate_critic=0.001,
            ...

@@ -44,7 +51,9 @@ lunarlander_td3vae_config = dict(
        ),
        collect=dict(
            # n_sample=48,
            each_iter_n_sample=48,
            # each_iter_n_sample=48,
            # each_iter_n_sample=128,
            each_iter_n_sample=256,
            noise_sigma=0.1,
            collector=dict(collect_print_freq=1000, ),
        ),
        ...
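The config imports serial_pipeline_td3_vae from ding.entry, so it is presumably launched the way other DI-engine serial configs are. A hedged usage sketch, assuming the usual DI-engine convention; the create-config variable and the exact call signature are assumptions and are not shown in this diff:

# Hypothetical launch stub in the usual DI-engine style, placed at the end of the config file.
if __name__ == '__main__':
    # lunarlander_td3vae_create_config is assumed to be defined alongside the main config.
    serial_pipeline_td3_vae((lunarlander_td3vae_config, lunarlander_td3vae_create_config), seed=0)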