Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
OpenDILab开源决策智能平台
DI-engine
提交
f4440650
D
DI-engine
项目概览
OpenDILab开源决策智能平台
/
DI-engine
上一次同步 2 年多
通知
56
Star
321
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
1
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-engine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
1
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
f4440650
编写于
7月 23, 2021
作者:
Z
zhangyinmin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
modify the unittest for the gae; format code.
上级
e30a3d3c
变更
18
隐藏空白更改
内联
并排
Showing
18 changed file
with
55 addition
and
101 deletion
+55
-101
ding/model/common/head.py
ding/model/common/head.py
+10
-10
ding/model/template/vac.py
ding/model/template/vac.py
+14
-14
ding/policy/ppo.py
ding/policy/ppo.py
+5
-5
ding/rl_utils/ppo.py
ding/rl_utils/ppo.py
+1
-1
ding/rl_utils/tests/test_gae.py
ding/rl_utils/tests/test_gae.py
+4
-2
dizoo/mujoco/config/ant_ddpg_default_config.py
dizoo/mujoco/config/ant_ddpg_default_config.py
+2
-6
dizoo/mujoco/config/ant_sac_default_config.py
dizoo/mujoco/config/ant_sac_default_config.py
+1
-5
dizoo/mujoco/config/ant_td3_default_config.py
dizoo/mujoco/config/ant_td3_default_config.py
+2
-6
dizoo/mujoco/config/halfcheetah_ddpg_default_config.py
dizoo/mujoco/config/halfcheetah_ddpg_default_config.py
+2
-6
dizoo/mujoco/config/halfcheetah_sac_default_config.py
dizoo/mujoco/config/halfcheetah_sac_default_config.py
+1
-5
dizoo/mujoco/config/halfcheetah_td3_default_config.py
dizoo/mujoco/config/halfcheetah_td3_default_config.py
+2
-6
dizoo/mujoco/config/hopper_ddpg_default_config.py
dizoo/mujoco/config/hopper_ddpg_default_config.py
+2
-6
dizoo/mujoco/config/hopper_ppo_default_config.py
dizoo/mujoco/config/hopper_ppo_default_config.py
+1
-1
dizoo/mujoco/config/hopper_sac_default_config.py
dizoo/mujoco/config/hopper_sac_default_config.py
+1
-5
dizoo/mujoco/config/hopper_td3_default_config.py
dizoo/mujoco/config/hopper_td3_default_config.py
+2
-6
dizoo/mujoco/config/walker2d_ddpg_default_config.py
dizoo/mujoco/config/walker2d_ddpg_default_config.py
+2
-6
dizoo/mujoco/config/walker2d_sac_default_config.py
dizoo/mujoco/config/walker2d_sac_default_config.py
+1
-5
dizoo/mujoco/config/walker2d_td3_default_config.py
dizoo/mujoco/config/walker2d_td3_default_config.py
+2
-6
未找到文件。
ding/model/common/head.py
浏览文件 @
f4440650
...
...
@@ -603,15 +603,15 @@ class ReparameterizationHead(nn.Module):
default_bound_type
=
[
'tanh'
,
None
]
def
__init__
(
self
,
hidden_size
:
int
,
output_size
:
int
,
layer_num
:
int
=
2
,
sigma_type
:
Optional
[
str
]
=
None
,
fixed_sigma_value
:
Optional
[
float
]
=
1.0
,
activation
:
Optional
[
nn
.
Module
]
=
nn
.
ReLU
(),
norm_type
:
Optional
[
str
]
=
None
,
bound_type
:
Optional
[
str
]
=
None
,
self
,
hidden_size
:
int
,
output_size
:
int
,
layer_num
:
int
=
2
,
sigma_type
:
Optional
[
str
]
=
None
,
fixed_sigma_value
:
Optional
[
float
]
=
1.0
,
activation
:
Optional
[
nn
.
Module
]
=
nn
.
ReLU
(),
norm_type
:
Optional
[
str
]
=
None
,
bound_type
:
Optional
[
str
]
=
None
,
)
->
None
:
r
"""
Overview:
...
...
@@ -672,7 +672,7 @@ class ReparameterizationHead(nn.Module):
"""
x
=
self
.
main
(
x
)
mu
=
self
.
mu
(
x
)
if
self
.
bound_type
==
'tanh'
:
if
self
.
bound_type
==
'tanh'
:
mu
=
torch
.
tanh
(
mu
)
if
self
.
sigma_type
==
'fixed'
:
sigma
=
self
.
sigma
.
to
(
mu
.
device
)
+
torch
.
zeros_like
(
mu
)
# addition aims to broadcast shape
...
...
ding/model/template/vac.py
浏览文件 @
f4440650
...
...
@@ -18,20 +18,20 @@ class VAC(nn.Module):
mode
=
[
'compute_actor'
,
'compute_critic'
,
'compute_actor_critic'
]
def
__init__
(
self
,
obs_shape
:
Union
[
int
,
SequenceType
],
action_shape
:
Union
[
int
,
SequenceType
],
share_encoder
:
bool
=
True
,
continuous
:
bool
=
False
,
encoder_hidden_size_list
:
SequenceType
=
[
128
,
128
,
64
],
actor_head_hidden_size
:
int
=
64
,
actor_head_layer_num
:
int
=
1
,
critic_head_hidden_size
:
int
=
64
,
critic_head_layer_num
:
int
=
1
,
activation
:
Optional
[
nn
.
Module
]
=
nn
.
ReLU
(),
norm_type
:
Optional
[
str
]
=
None
,
sigma_type
:
Optional
[
str
]
=
'independent'
,
bound_type
:
Optional
[
str
]
=
None
,
self
,
obs_shape
:
Union
[
int
,
SequenceType
],
action_shape
:
Union
[
int
,
SequenceType
],
share_encoder
:
bool
=
True
,
continuous
:
bool
=
False
,
encoder_hidden_size_list
:
SequenceType
=
[
128
,
128
,
64
],
actor_head_hidden_size
:
int
=
64
,
actor_head_layer_num
:
int
=
1
,
critic_head_hidden_size
:
int
=
64
,
critic_head_layer_num
:
int
=
1
,
activation
:
Optional
[
nn
.
Module
]
=
nn
.
ReLU
(),
norm_type
:
Optional
[
str
]
=
None
,
sigma_type
:
Optional
[
str
]
=
'independent'
,
bound_type
:
Optional
[
str
]
=
None
,
)
->
None
:
r
"""
Overview:
...
...
ding/policy/ppo.py
浏览文件 @
f4440650
...
...
@@ -131,7 +131,6 @@ class PPOPolicy(Policy):
# Main model
self
.
_learn_model
.
reset
()
def
_forward_learn
(
self
,
data
:
Dict
[
str
,
Any
])
->
Dict
[
str
,
Any
]:
r
"""
Overview:
...
...
@@ -192,8 +191,8 @@ class PPOPolicy(Policy):
# Calculate ppo error
if
self
.
_continuous
:
ppo_batch
=
ppo_data
(
output
[
'logit'
],
batch
[
'logit'
],
batch
[
'action'
],
output
[
'value'
],
batch
[
'value'
],
adv
,
batch
[
'return'
],
batch
[
'weight'
]
output
[
'logit'
],
batch
[
'logit'
],
batch
[
'action'
],
output
[
'value'
],
batch
[
'value'
],
adv
,
batch
[
'return'
],
batch
[
'weight'
]
)
ppo_loss
,
ppo_info
=
ppo_error_continuous
(
ppo_batch
,
self
.
_clip_ratio
)
else
:
...
...
@@ -333,8 +332,9 @@ class PPOPolicy(Policy):
last_value
*=
self
.
_running_mean_std
.
std
for
i
in
range
(
len
(
data
)):
data
[
i
][
'value'
]
*=
self
.
_running_mean_std
.
std
data
=
get_gae
(
data
,
to_device
(
last_value
,
self
.
_device
),
gamma
=
self
.
_gamma
,
gae_lambda
=
self
.
_gae_lambda
,
cuda
=
self
.
_cuda
)
data
=
get_gae
(
data
,
to_device
(
last_value
,
self
.
_device
),
gamma
=
self
.
_gamma
,
gae_lambda
=
self
.
_gae_lambda
,
cuda
=
self
.
_cuda
)
if
self
.
_value_norm
:
for
i
in
range
(
len
(
data
)):
data
[
i
][
'value'
]
/=
self
.
_running_mean_std
.
std
...
...
ding/rl_utils/ppo.py
浏览文件 @
f4440650
...
...
@@ -171,7 +171,7 @@ def ppo_error_continuous(
weight
=
torch
.
ones_like
(
adv
)
dist_new
=
Independent
(
Normal
(
mu_sigma_new
[
0
],
mu_sigma_new
[
1
]),
1
)
if
len
(
mu_sigma_old
[
0
].
shape
)
==
1
:
if
len
(
mu_sigma_old
[
0
].
shape
)
==
1
:
dist_old
=
Independent
(
Normal
(
mu_sigma_old
[
0
].
unsqueeze
(
-
1
),
mu_sigma_old
[
1
].
unsqueeze
(
-
1
)),
1
)
else
:
dist_old
=
Independent
(
Normal
(
mu_sigma_old
[
0
],
mu_sigma_old
[
1
]),
1
)
...
...
ding/rl_utils/tests/test_gae.py
浏览文件 @
f4440650
...
...
@@ -6,8 +6,10 @@ from ding.rl_utils import gae_data, gae
@
pytest
.
mark
.
unittest
def
test_gae
():
T
,
B
=
32
,
4
value
=
torch
.
randn
(
T
+
1
,
B
)
value
=
torch
.
randn
(
T
,
B
)
next_value
=
torch
.
randn
(
T
,
B
)
reward
=
torch
.
randn
(
T
,
B
)
data
=
gae_data
(
value
,
reward
)
done
=
torch
.
zeros
((
T
,
B
))
data
=
gae_data
(
value
,
next_value
,
reward
,
done
)
adv
=
gae
(
data
)
assert
adv
.
shape
==
(
T
,
B
)
dizoo/mujoco/config/ant_ddpg_default_config.py
浏览文件 @
f4440650
...
...
@@ -39,11 +39,7 @@ ant_ddpg_default_config = dict(
unroll_len
=
1
,
noise_sigma
=
0.1
,
),
other
=
dict
(
replay_buffer
=
dict
(
replay_buffer_size
=
1000000
,
),
),
other
=
dict
(
replay_buffer
=
dict
(
replay_buffer_size
=
1000000
,
),
),
)
)
ant_ddpg_default_config
=
EasyDict
(
ant_ddpg_default_config
)
...
...
@@ -59,7 +55,7 @@ ant_ddpg_default_create_config = dict(
type
=
'ddpg'
,
import_names
=
[
'ding.policy.ddpg'
],
),
replay_buffer
=
dict
(
type
=
'naive'
,),
replay_buffer
=
dict
(
type
=
'naive'
,
),
)
ant_ddpg_default_create_config
=
EasyDict
(
ant_ddpg_default_create_config
)
create_config
=
ant_ddpg_default_create_config
dizoo/mujoco/config/ant_sac_default_config.py
浏览文件 @
f4440650
...
...
@@ -42,11 +42,7 @@ ant_sac_default_config = dict(
),
command
=
dict
(),
eval
=
dict
(),
other
=
dict
(
replay_buffer
=
dict
(
replay_buffer_size
=
1000000
,
),
),
other
=
dict
(
replay_buffer
=
dict
(
replay_buffer_size
=
1000000
,
),
),
),
)
...
...
dizoo/mujoco/config/ant_td3_default_config.py
浏览文件 @
f4440650
...
...
@@ -44,11 +44,7 @@ ant_td3_default_config = dict(
unroll_len
=
1
,
noise_sigma
=
0.1
,
),
other
=
dict
(
replay_buffer
=
dict
(
replay_buffer_size
=
1000000
,
),
),
other
=
dict
(
replay_buffer
=
dict
(
replay_buffer_size
=
1000000
,
),
),
)
)
...
...
@@ -65,7 +61,7 @@ ant_td3_default_create_config = dict(
policy_type
=
'td3'
,
import_names
=
[
'ding.policy.td3'
],
),
replay_buffer
=
dict
(
type
=
'naive'
,),
replay_buffer
=
dict
(
type
=
'naive'
,
),
)
ant_td3_default_create_config
=
EasyDict
(
ant_td3_default_create_config
)
create_config
=
ant_td3_default_create_config
dizoo/mujoco/config/halfcheetah_ddpg_default_config.py
浏览文件 @
f4440650
...
...
@@ -39,11 +39,7 @@ halfcheetah_ddpg_default_config = dict(
unroll_len
=
1
,
noise_sigma
=
0.1
,
),
other
=
dict
(
replay_buffer
=
dict
(
replay_buffer_size
=
1000000
,
),
),
other
=
dict
(
replay_buffer
=
dict
(
replay_buffer_size
=
1000000
,
),
),
)
)
halfcheetah_ddpg_default_config
=
EasyDict
(
halfcheetah_ddpg_default_config
)
...
...
@@ -59,7 +55,7 @@ halfcheetah_ddpg_default_create_config = dict(
type
=
'ddpg'
,
import_names
=
[
'ding.policy.ddpg'
],
),
replay_buffer
=
dict
(
type
=
'naive'
,),
replay_buffer
=
dict
(
type
=
'naive'
,
),
)
halfcheetah_ddpg_default_create_config
=
EasyDict
(
halfcheetah_ddpg_default_create_config
)
create_config
=
halfcheetah_ddpg_default_create_config
dizoo/mujoco/config/halfcheetah_sac_default_config.py
浏览文件 @
f4440650
...
...
@@ -42,11 +42,7 @@ halfcheetah_sac_default_config = dict(
),
command
=
dict
(),
eval
=
dict
(),
other
=
dict
(
replay_buffer
=
dict
(
replay_buffer_size
=
1000000
,
),
),
other
=
dict
(
replay_buffer
=
dict
(
replay_buffer_size
=
1000000
,
),
),
),
)
...
...
dizoo/mujoco/config/halfcheetah_td3_default_config.py
浏览文件 @
f4440650
...
...
@@ -44,11 +44,7 @@ halfcheetah_td3_default_config = dict(
unroll_len
=
1
,
noise_sigma
=
0.1
,
),
other
=
dict
(
replay_buffer
=
dict
(
replay_buffer_size
=
1000000
,
),
),
other
=
dict
(
replay_buffer
=
dict
(
replay_buffer_size
=
1000000
,
),
),
)
)
...
...
@@ -65,7 +61,7 @@ halfcheetah_td3_default_create_config = dict(
policy_type
=
'td3'
,
import_names
=
[
'ding.policy.td3'
],
),
replay_buffer
=
dict
(
type
=
'naive'
,),
replay_buffer
=
dict
(
type
=
'naive'
,
),
)
halfcheetah_td3_default_create_config
=
EasyDict
(
halfcheetah_td3_default_create_config
)
create_config
=
halfcheetah_td3_default_create_config
dizoo/mujoco/config/hopper_ddpg_default_config.py
浏览文件 @
f4440650
...
...
@@ -39,11 +39,7 @@ hopper_ddpg_default_config = dict(
unroll_len
=
1
,
noise_sigma
=
0.1
,
),
other
=
dict
(
replay_buffer
=
dict
(
replay_buffer_size
=
1000000
,
),
),
other
=
dict
(
replay_buffer
=
dict
(
replay_buffer_size
=
1000000
,
),
),
)
)
hopper_ddpg_default_config
=
EasyDict
(
hopper_ddpg_default_config
)
...
...
@@ -59,7 +55,7 @@ hopper_ddpg_default_create_config = dict(
type
=
'ddpg'
,
import_names
=
[
'ding.policy.ddpg'
],
),
replay_buffer
=
dict
(
type
=
'naive'
,),
replay_buffer
=
dict
(
type
=
'naive'
,
),
)
hopper_ddpg_default_create_config
=
EasyDict
(
hopper_ddpg_default_create_config
)
create_config
=
hopper_ddpg_default_create_config
dizoo/mujoco/config/hopper_ppo_default_config.py
浏览文件 @
f4440650
...
...
@@ -57,7 +57,7 @@ hopper_ppo_create_default_config = dict(
type
=
'ppo'
,
import_names
=
[
'ding.policy.ppo'
],
),
replay_buffer
=
dict
(
type
=
'naive'
,),
replay_buffer
=
dict
(
type
=
'naive'
,
),
)
hopper_ppo_create_default_config
=
EasyDict
(
hopper_ppo_create_default_config
)
create_config
=
hopper_ppo_create_default_config
dizoo/mujoco/config/hopper_sac_default_config.py
浏览文件 @
f4440650
...
...
@@ -42,11 +42,7 @@ hopper_sac_default_config = dict(
),
command
=
dict
(),
eval
=
dict
(),
other
=
dict
(
replay_buffer
=
dict
(
replay_buffer_size
=
1000000
,
),
),
other
=
dict
(
replay_buffer
=
dict
(
replay_buffer_size
=
1000000
,
),
),
),
)
...
...
dizoo/mujoco/config/hopper_td3_default_config.py
浏览文件 @
f4440650
...
...
@@ -44,11 +44,7 @@ hopper_td3_default_config = dict(
unroll_len
=
1
,
noise_sigma
=
0.1
,
),
other
=
dict
(
replay_buffer
=
dict
(
replay_buffer_size
=
1000000
,
),
),
other
=
dict
(
replay_buffer
=
dict
(
replay_buffer_size
=
1000000
,
),
),
)
)
...
...
@@ -65,7 +61,7 @@ hopper_td3_default_create_config = dict(
policy_type
=
'td3'
,
import_names
=
[
'ding.policy.td3'
],
),
replay_buffer
=
dict
(
type
=
'naive'
,),
replay_buffer
=
dict
(
type
=
'naive'
,
),
)
hopper_td3_default_create_config
=
EasyDict
(
hopper_td3_default_create_config
)
create_config
=
hopper_td3_default_create_config
dizoo/mujoco/config/walker2d_ddpg_default_config.py
浏览文件 @
f4440650
...
...
@@ -39,11 +39,7 @@ walker2d_ddpg_default_config = dict(
unroll_len
=
1
,
noise_sigma
=
0.1
,
),
other
=
dict
(
replay_buffer
=
dict
(
replay_buffer_size
=
1000000
,
),
),
other
=
dict
(
replay_buffer
=
dict
(
replay_buffer_size
=
1000000
,
),
),
)
)
walker2d_ddpg_default_config
=
EasyDict
(
walker2d_ddpg_default_config
)
...
...
@@ -59,7 +55,7 @@ walker2d_ddpg_default_create_config = dict(
type
=
'ddpg'
,
import_names
=
[
'ding.policy.ddpg'
],
),
replay_buffer
=
dict
(
type
=
'naive'
,),
replay_buffer
=
dict
(
type
=
'naive'
,
),
)
walker2d_ddpg_default_create_config
=
EasyDict
(
walker2d_ddpg_default_create_config
)
create_config
=
walker2d_ddpg_default_create_config
dizoo/mujoco/config/walker2d_sac_default_config.py
浏览文件 @
f4440650
...
...
@@ -42,11 +42,7 @@ walker2d_sac_default_config = dict(
),
command
=
dict
(),
eval
=
dict
(),
other
=
dict
(
replay_buffer
=
dict
(
replay_buffer_size
=
1000000
,
),
),
other
=
dict
(
replay_buffer
=
dict
(
replay_buffer_size
=
1000000
,
),
),
),
)
...
...
dizoo/mujoco/config/walker2d_td3_default_config.py
浏览文件 @
f4440650
...
...
@@ -44,11 +44,7 @@ walker2d_td3_default_config = dict(
unroll_len
=
1
,
noise_sigma
=
0.1
,
),
other
=
dict
(
replay_buffer
=
dict
(
replay_buffer_size
=
1000000
,
),
),
other
=
dict
(
replay_buffer
=
dict
(
replay_buffer_size
=
1000000
,
),
),
)
)
...
...
@@ -65,7 +61,7 @@ walker2d_td3_default_create_config = dict(
policy_type
=
'td3'
,
import_names
=
[
'ding.policy.td3'
],
),
replay_buffer
=
dict
(
type
=
'naive'
,),
replay_buffer
=
dict
(
type
=
'naive'
,
),
)
walker2d_td3_default_create_config
=
EasyDict
(
walker2d_td3_default_create_config
)
create_config
=
walker2d_td3_default_create_config
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录