OpenDILab (open-source decision intelligence platform) / DI-engine
Commit 293e5c36
Authored Dec 27, 2021 by puyuan1996
polish(pu): update the current best config
Parent: 70328aab
Showing 2 changed files with 11 additions and 11 deletions
ding/model/template/vae.py  +6 -7
dizoo/box2d/lunarlander/config/lunarlander_cont_td3_vae_config.py  +5 -4
ding/model/template/vae.py
@@ -56,18 +56,17 @@ class VanillaVAE(BaseVAE):
         # obs
         self.obs_head = nn.Sequential(nn.Linear(self.obs_dim, hidden_dims[0]), nn.ReLU())
-        self.encoder = nn.Sequential(nn.Linear(hidden_dims[0], hidden_dims[0]), nn.ReLU())
+        self.encoder_common = nn.Sequential(nn.Linear(hidden_dims[0], hidden_dims[0]), nn.ReLU())
         self.mu_head = nn.Linear(hidden_dims[0], latent_dim)
-        self.var_head = nn.Linear(hidden_dims[0], latent_dim)
+        self.log_var_head = nn.Linear(hidden_dims[0], latent_dim)
         # Build Decoder
         self.condition_obs = nn.Sequential(nn.Linear(self.obs_dim, hidden_dims[0]), nn.ReLU())
         self.decoder_action = nn.Sequential(nn.Linear(latent_dim, hidden_dims[0]), nn.ReLU())
         self.decoder_common = nn.Sequential(nn.Linear(hidden_dims[0], hidden_dims[0]), nn.ReLU())
         # TODO(pu): tanh
-        self.reconstruction_layer = nn.Sequential(nn.Linear(hidden_dims[0], self.action_dim), nn.Tanh())
-        # self.reconstruction_layer = nn.Linear(hidden_dims[0], self.action_dim)
+        self.reconstruction_head = nn.Sequential(nn.Linear(hidden_dims[0], self.action_dim), nn.Tanh())
+        # self.reconstruction_head = nn.Linear(hidden_dims[0], self.action_dim)
         # residual prediction
         self.prediction_head_1 = nn.Sequential(nn.Linear(hidden_dims[0], hidden_dims[0]), nn.ReLU())
...
...
@@ -91,13 +90,13 @@ class VanillaVAE(BaseVAE):
         # input = obs_encoding + action_encoding  # TODO(pu): what about add, cat?
         input = obs_encoding * action_encoding
-        result = self.encoder(input)
+        result = self.encoder_common(input)
         result = torch.flatten(result, start_dim=1)
         # Split the result into mu and var components
         # of the latent Gaussian distribution
         mu = self.mu_head(result)
-        log_var = self.var_head(result)
+        log_var = self.log_var_head(result)
         return [mu, log_var]
...
...
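The rename from `var_head` to `log_var_head` matches what the head is used for: its output is assigned to `log_var`, so it is a log-variance and not a variance. A minimal sketch of how `mu` and `log_var` are commonly consumed in a VAE follows, assuming the standard Gaussian reparameterization and KL term (the diff does not show the rest of `VanillaVAE`, so this is not taken from the repository). The function names `reparameterize` and `kl_to_standard_normal` are hypothetical, and plain Python floats stand in for torch tensors.

```python
import math
import random


def reparameterize(mu, log_var, rng):
    """Sample z = mu + std * eps with std = exp(0.5 * log_var), eps ~ N(0, 1)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]


def kl_to_standard_normal(mu, log_var):
    """KL(N(mu, diag(exp(log_var))) || N(0, I)) for a single latent vector."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, log_var))


rng = random.Random(0)
mu = [0.0, 0.0]
log_var = [0.0, 0.0]  # a log-variance of 0 corresponds to unit variance
z = reparameterize(mu, log_var, rng)
kl = kl_to_standard_normal(mu, log_var)
```

Predicting the log-variance keeps the implied variance positive for any real-valued output of the linear layer, which is why the head needs no activation.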
dizoo/box2d/lunarlander/config/lunarlander_cont_td3_vae_config.py
@@ -2,10 +2,10 @@ from easydict import EasyDict
 from ding.entry import serial_pipeline_td3_vae
 lunarlander_td3vae_config = dict(
-    exp_name='lunarlander_cont_td3_vae_lad6_rcs1e4_wu1e4_ns256_bs128_auf2_targetnoise_collectoriginalnoise_rbs1e5_rsc_lsc_rvuc3_upcr256_upcv100_kw0.01_pw0.01_dot_tanh',
+    exp_name='lunarlander_cont_td3_vae_lad6_rcs1e4_wu1e4_ns256_bs128_auf2_targetnoise_collectoriginalnoise_rbs1e5_rsc_lsc_rvuc1_upcr256_upcv10_kw0.01_pw0.01_dot_tanh',
     env=dict(
         env_id='LunarLanderContinuous-v2',
-        collector_env_num=8,
+        collector_env_num=1,
         evaluator_env_num=5,
         # (bool) Scale output action into legal range.
         act_scale=True,
...
...
@@ -25,9 +25,10 @@ lunarlander_td3vae_config = dict(
         ),
         learn=dict(
             warm_up_update=int(1e4),
-            rl_vae_update_circle=3,  # train rl 3 iter, vae 1 iter
+            rl_vae_update_circle=1,  # train rl 1 iter, vae 1 iter
             update_per_collect_rl=256,
-            update_per_collect_vae=100,
+            update_per_collect_vae=10,
             batch_size=128,
             learning_rate_actor=3e-4,
             learning_rate_critic=3e-4,
...
...
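The inline comment on `rl_vae_update_circle` ("train rl 1 iter, vae 1 iter") suggests an alternating schedule: N RL training iterations followed by one VAE training iteration. The sketch below illustrates that reading only. How `serial_pipeline_td3_vae` actually consumes this key is not shown in the diff, and the helper `phase_schedule` is hypothetical.

```python
def phase_schedule(num_iters, rl_vae_update_circle):
    """Label each iteration 'rl' or 'vae': N 'rl' iterations, then one 'vae' iteration.

    Assumption: one cycle has rl_vae_update_circle RL iterations plus 1 VAE iteration.
    """
    period = rl_vae_update_circle + 1
    return ['rl' if i % period < rl_vae_update_circle else 'vae' for i in range(num_iters)]


old_schedule = phase_schedule(8, 3)  # previous value: rl_vae_update_circle=3
new_schedule = phase_schedule(8, 1)  # value after this commit: rl_vae_update_circle=1
```

Under this reading, the commit moves from one VAE phase every four iterations to one every two, while `update_per_collect_vae` drops from 100 to 10 updates per phase.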