Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
OpenDILab开源决策智能平台
DI-engine
提交
28930a86
D
DI-engine
项目概览
OpenDILab开源决策智能平台
/
DI-engine
上一次同步 2 年多
通知
60
Star
321
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
1
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-engine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
1
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
28930a86
编写于
11月 01, 2021
作者:
N
niuyazhe
1
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(nyz): fix r2d2 and dqtd error unittest bug
上级
286ea243
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
6 addition
and
3 deletion
+6
-3
ding/entry/tests/test_serial_entry.py
ding/entry/tests/test_serial_entry.py
+1
-1
ding/policy/r2d2.py
ding/policy/r2d2.py
+1
-1
ding/rl_utils/tests/test_adder.py
ding/rl_utils/tests/test_adder.py
+2
-0
ding/rl_utils/tests/test_td.py
ding/rl_utils/tests/test_td.py
+1
-1
dizoo/classic_control/cartpole/config/cartpole_r2d2_config.py
...o/classic_control/cartpole/config/cartpole_r2d2_config.py
+1
-0
未找到文件。
ding/entry/tests/test_serial_entry.py
浏览文件 @
28930a86
...
...
@@ -188,7 +188,7 @@ def test_r2d2():
config
=
[
deepcopy
(
cartpole_r2d2_config
),
deepcopy
(
cartpole_r2d2_create_config
)]
config
[
0
].
policy
.
learn
.
update_per_collect
=
1
try
:
serial_pipeline
(
config
,
seed
=
0
,
max_iterations
=
1
)
serial_pipeline
(
config
,
seed
=
0
,
max_iterations
=
5
)
except
Exception
:
assert
False
,
"pipeline fail"
...
...
ding/policy/r2d2.py
浏览文件 @
28930a86
...
...
@@ -223,7 +223,7 @@ class R2D2Policy(Policy):
else
:
data
[
'value_gamma'
]
=
data
[
'value_gamma'
][
bs
:]
if
'weight'
not
in
data
:
if
'weight'
not
in
data
or
data
[
'weight'
]
is
None
:
data
[
'weight'
]
=
[
None
for
_
in
range
(
self
.
_unroll_len_add_burnin_step
-
bs
)]
else
:
data
[
'weight'
]
=
data
[
'weight'
]
*
torch
.
ones_like
(
data
[
'done'
])
...
...
ding/rl_utils/tests/test_adder.py
浏览文件 @
28930a86
...
...
@@ -13,6 +13,7 @@ class TestAdder:
return
{
'value'
:
torch
.
randn
(
1
),
'reward'
:
torch
.
rand
(
1
),
'action'
:
torch
.
rand
(
3
),
'other'
:
np
.
random
.
randint
(
0
,
10
,
size
=
(
4
,
)),
'obs'
:
torch
.
randn
(
3
),
'done'
:
False
...
...
@@ -22,6 +23,7 @@ class TestAdder:
return
{
'value'
:
torch
.
randn
(
1
,
8
),
'reward'
:
torch
.
rand
(
1
,
1
),
'action'
:
torch
.
rand
(
3
),
'other'
:
np
.
random
.
randint
(
0
,
10
,
size
=
(
4
,
)),
'obs'
:
torch
.
randn
(
3
),
'done'
:
False
...
...
ding/rl_utils/tests/test_td.py
浏览文件 @
28930a86
...
...
@@ -234,7 +234,7 @@ def test_dqfd_nstep_td():
q
,
next_q
,
action
,
next_action
,
reward
,
done
,
done_1
,
None
,
next_q_one_step
,
next_action_one_step
,
is_expert
)
loss
,
td_error_per_sample
=
dqfd_nstep_td_error
(
data
,
0.95
,
lambda
1
=
1
,
lambda2
=
1
,
margin_function
=
0.8
,
nstep
=
nstep
data
,
0.95
,
lambda
_n_step_td
=
1
,
lambda_supervised_loss
=
1
,
margin_function
=
0.8
,
nstep
=
nstep
)
assert
td_error_per_sample
.
shape
==
(
batch_size
,
)
assert
loss
.
shape
==
()
...
...
dizoo/classic_control/cartpole/config/cartpole_r2d2_config.py
浏览文件 @
28930a86
...
...
@@ -14,6 +14,7 @@ cartpole_r2d2_config = dict(
policy
=
dict
(
cuda
=
False
,
priority
=
False
,
priority_IS_weight
=
False
,
model
=
dict
(
obs_shape
=
4
,
action_shape
=
2
,
...
...
OpenDILab开源决策智能平台
@m0_55289267
mentioned in commit
707ea71c
·
11月 02, 2021
mentioned in commit
707ea71c
mentioned in commit 707ea71cc1835bbc596ea91c9ad16b913041ca24
开关提交列表
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录