Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
OpenDILab开源决策智能平台
DI-engine
提交
dd6205b2
D
DI-engine
项目概览
OpenDILab开源决策智能平台
/
DI-engine
上一次同步 2 年多
通知
61
Star
322
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
1
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-engine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
1
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
dd6205b2
编写于
7月 16, 2021
作者:
N
niuyazhe
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(nyz): fix parallel test exp_name bug
上级
79587654
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
20 addition
and
4 deletion
+20
-4
ding/config/config.py
ding/config/config.py
+1
-1
ding/config/tests/test_config_formatted.py
ding/config/tests/test_config_formatted.py
+1
-1
ding/config/utils.py
ding/config/utils.py
+1
-0
ding/worker/collector/tests/fake_cpong_dqn_config.py
ding/worker/collector/tests/fake_cpong_dqn_config.py
+1
-0
ding/worker/coordinator/base_parallel_commander.py
ding/worker/coordinator/base_parallel_commander.py
+3
-0
ding/worker/coordinator/tests/conftest.py
ding/worker/coordinator/tests/conftest.py
+1
-0
ding/worker/replay_buffer/tests/test_advanced_buffer.py
ding/worker/replay_buffer/tests/test_advanced_buffer.py
+10
-2
dizoo/classic_control/cartpole/config/cartpole_a2c_config.py
dizoo/classic_control/cartpole/config/cartpole_a2c_config.py
+1
-0
dizoo/classic_control/cartpole/config/cartpole_c51_config.py
dizoo/classic_control/cartpole/config/cartpole_c51_config.py
+1
-0
未找到文件。
ding/config/config.py
浏览文件 @
dd6205b2
...
...
@@ -153,7 +153,7 @@ def save_config_py(config_: dict, path: str) -> NoReturn:
config_string
,
_
=
FormatCode
(
config_string
)
config_string
=
config_string
.
replace
(
'inf'
,
'float("inf")'
)
with
open
(
path
,
"w"
)
as
f
:
f
.
write
(
'exp_config
=
'
+
config_string
)
f
.
write
(
'exp_config
=
'
+
config_string
)
def
read_config_directly
(
path
:
str
)
->
dict
:
...
...
ding/config/tests/test_config_formatted.py
浏览文件 @
dd6205b2
...
...
@@ -24,7 +24,7 @@ def test_config_formatted(config_path, name):
main_config
,
seed
=
0
,
auto
=
True
,
create_cfg
=
create_config
,
save_cfg
=
True
,
save_path
=
'{}_config.py'
.
format
(
name
)
)
module
=
importlib
.
import_module
(
'
formatted_{}_config'
.
format
(
name
))
module
=
importlib
.
import_module
(
'
cartpole_{}.formatted_{}_config'
.
format
(
name
,
name
))
main_config
,
create_config
=
module
.
main_config
,
module
.
create_config
cfg_test
=
compile_config
(
main_config
,
seed
=
0
,
auto
=
True
,
create_cfg
=
create_config
,
save_cfg
=
False
)
assert
cfg
==
cfg_test
,
'cfg_formatted_failed'
ding/config/utils.py
浏览文件 @
dd6205b2
...
...
@@ -251,6 +251,7 @@ def save_config_formatted(config_: dict, path: str = 'formatted_total_config.py'
with
open
(
path
,
"w"
)
as
f
:
f
.
write
(
'from easydict import EasyDict
\n\n
'
)
f
.
write
(
'main_config = dict(
\n
'
)
f
.
write
(
" exp_name='{}',
\n
"
.
format
(
config_
.
exp_name
))
for
k
,
v
in
config_
.
items
():
if
(
k
==
'env'
):
f
.
write
(
' env=dict(
\n
'
)
...
...
ding/worker/collector/tests/fake_cpong_dqn_config.py
浏览文件 @
dd6205b2
...
...
@@ -2,6 +2,7 @@ from easydict import EasyDict
from
ding.config
import
parallel_transform
fake_cpong_dqn_config
=
dict
(
exp_name
=
'fake_cpong_dqn'
,
env
=
dict
(
collector_env_num
=
16
,
collector_episode_num
=
2
,
...
...
ding/worker/coordinator/base_parallel_commander.py
浏览文件 @
dd6205b2
...
...
@@ -61,6 +61,7 @@ class NaiveCommander(BaseCommander):
"collector_task_space" and "learner_task_space".
"""
self
.
_cfg
=
cfg
self
.
_exp_name
=
cfg
.
exp_name
commander_cfg
=
self
.
_cfg
.
policy
.
other
.
commander
self
.
_collector_task_space
=
LimitedSpaceContainer
(
0
,
commander_cfg
.
collector_task_space
)
self
.
_learner_task_space
=
LimitedSpaceContainer
(
0
,
commander_cfg
.
learner_task_space
)
...
...
@@ -91,6 +92,7 @@ class NaiveCommander(BaseCommander):
collector_cfg
.
policy
=
copy
.
deepcopy
(
self
.
_cfg
.
policy
)
collector_cfg
.
policy_update_path
=
'test.pth'
collector_cfg
.
env
=
self
.
_collector_env_cfg
collector_cfg
.
exp_name
=
self
.
_exp_name
return
{
'task_id'
:
'collector_task_id{}'
.
format
(
self
.
_collector_task_count
),
'buffer_id'
:
'test'
,
...
...
@@ -109,6 +111,7 @@ class NaiveCommander(BaseCommander):
if
self
.
_learner_task_space
.
acquire_space
():
self
.
_learner_task_count
+=
1
learner_cfg
=
copy
.
deepcopy
(
self
.
_cfg
.
policy
.
learn
.
learner
)
learner_cfg
.
exp_name
=
self
.
_exp_name
return
{
'task_id'
:
'learner_task_id{}'
.
format
(
self
.
_learner_task_count
),
'policy_id'
:
'test.pth'
,
...
...
ding/worker/coordinator/tests/conftest.py
浏览文件 @
dd6205b2
...
...
@@ -10,6 +10,7 @@ def setup_1v1commander():
nstep
=
1
eval_interval
=
5
main_config
=
dict
(
exp_name
=
'one_vs_one_test'
,
env
=
dict
(
collector_env_num
=
8
,
collector_episode_num
=
2
,
...
...
ding/worker/replay_buffer/tests/test_advanced_buffer.py
浏览文件 @
dd6205b2
...
...
@@ -247,8 +247,16 @@ class TestDemonstrationBuffer:
def
test_naive
(
self
,
setup_demo_buffer_factory
):
setup_demo_buffer
=
next
(
setup_demo_buffer_factory
)
naive_demo_buffer
=
next
(
setup_demo_buffer_factory
)
with
open
(
demo_data_path
,
'rb+'
)
as
f
:
data
=
pickle
.
load
(
f
)
while
True
:
with
open
(
demo_data_path
,
'rb+'
)
as
f
:
data
=
pickle
.
load
(
f
)
if
len
(
data
)
!=
0
:
break
else
:
# for the stability of dist-test
demo_data
=
{
'data'
:
generate_data_list
(
10
)}
with
open
(
demo_data_path
,
"wb"
)
as
f
:
pickle
.
dump
(
demo_data
,
f
)
setup_demo_buffer
.
load_state_dict
(
data
)
assert
setup_demo_buffer
.
count
()
==
len
(
data
[
'data'
])
# assert buffer not empty
samples
=
setup_demo_buffer
.
sample
(
3
,
0
)
...
...
dizoo/classic_control/cartpole/config/cartpole_a2c_config.py
浏览文件 @
dd6205b2
from
easydict
import
EasyDict
cartpole_a2c_config
=
dict
(
exp_name
=
'cartpole_a2c'
,
env
=
dict
(
collector_env_num
=
8
,
evaluator_env_num
=
5
,
...
...
dizoo/classic_control/cartpole/config/cartpole_c51_config.py
浏览文件 @
dd6205b2
from
easydict
import
EasyDict
cartpole_c51_config
=
dict
(
exp_name
=
'cartpole_c51'
,
env
=
dict
(
collector_env_num
=
8
,
evaluator_env_num
=
5
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录