Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
OpenDILab开源决策智能平台
DI-engine
提交
234de26b
D
DI-engine
项目概览
OpenDILab开源决策智能平台
/
DI-engine
上一次同步 2 年多
通知
56
Star
321
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
1
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-engine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
1
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
234de26b
编写于
12月 08, 2021
作者:
N
niuyazhe
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(nyz): fix trex unittest bugs
上级
63105fef
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
12 addition
and
10 deletion
+12
-10
ding/entry/application_entry.py
ding/entry/application_entry.py
+8
-4
ding/entry/tests/test_application_entry_trex_collect_data.py
ding/entry/tests/test_application_entry_trex_collect_data.py
+2
-4
ding/entry/tests/test_serial_entry_il.py
ding/entry/tests/test_serial_entry_il.py
+1
-1
ding/entry/tests/test_serial_entry_trex_onpolicy.py
ding/entry/tests/test_serial_entry_trex_onpolicy.py
+1
-1
未找到文件。
ding/entry/application_entry.py
浏览文件 @
234de26b
...
...
@@ -144,8 +144,10 @@ def collect_demo_data(
policy
.
collect_mode
.
load_state_dict
(
state_dict
)
collector
=
SampleSerialCollector
(
cfg
.
policy
.
collect
.
collector
,
collector_env
,
collect_demo_policy
)
policy_kwargs
=
None
if
not
hasattr
(
cfg
.
policy
.
other
,
'eps'
)
\
else
{
'eps'
:
cfg
.
policy
.
other
.
eps
.
get
(
'collect'
,
0.2
)}
if
hasattr
(
cfg
.
policy
.
other
,
'eps'
):
policy_kwargs
=
{
'eps'
:
0.
}
else
:
policy_kwargs
=
None
# Let's collect some expert demonstrations
exp_data
=
collector
.
collect
(
n_sample
=
collect_count
,
policy_kwargs
=
policy_kwargs
)
...
...
@@ -215,8 +217,10 @@ def collect_episodic_demo_data(
policy
.
collect_mode
.
load_state_dict
(
state_dict
)
collector
=
EpisodeSerialCollector
(
cfg
.
policy
.
collect
.
collector
,
collector_env
,
collect_demo_policy
)
policy_kwargs
=
None
if
not
hasattr
(
cfg
.
policy
.
other
,
'eps'
)
\
else
{
'eps'
:
cfg
.
policy
.
other
.
eps
.
get
(
'collect'
,
0.2
)}
if
hasattr
(
cfg
.
policy
.
other
,
'eps'
):
policy_kwargs
=
{
'eps'
:
0.
}
else
:
policy_kwargs
=
None
# Let's collect some expert demostrations
exp_data
=
collector
.
collect
(
n_episode
=
collect_count
,
policy_kwargs
=
policy_kwargs
)
...
...
ding/entry/tests/test_application_entry_trex_collect_data.py
浏览文件 @
234de26b
...
...
@@ -17,7 +17,6 @@ from ding.entry import serial_pipeline
@
pytest
.
mark
.
unittest
def
test_collect_episodic_demo_data_for_trex
():
expert_policy_state_dict_path
=
'./expert_policy.pth'
expert_policy_state_dict_path
=
os
.
path
.
abspath
(
'ding/entry/expert_policy.pth'
)
config
=
[
deepcopy
(
cartpole_ppo_offpolicy_config
),
deepcopy
(
cartpole_ppo_offpolicy_create_config
)]
expert_policy
=
serial_pipeline
(
config
,
seed
=
0
)
torch
.
save
(
expert_policy
.
collect_mode
.
state_dict
(),
expert_policy_state_dict_path
)
...
...
@@ -40,7 +39,7 @@ def test_collect_episodic_demo_data_for_trex():
os
.
popen
(
'rm -rf {}'
.
format
(
expert_policy_state_dict_path
))
@
pytest
.
mark
.
unittest
#
@pytest.mark.unittest
def
test_trex_collecting_data
():
expert_policy_state_dict_path
=
'./cartpole_ppo_offpolicy'
expert_policy_state_dict_path
=
os
.
path
.
abspath
(
expert_policy_state_dict_path
)
...
...
@@ -55,10 +54,9 @@ def test_trex_collecting_data():
'device'
:
'cpu'
}
)
args
.
cfg
[
0
].
reward_model
.
offline_data_path
=
'
dizoo/classic_control/cartpole/config/cartpole_trex_offppo
'
args
.
cfg
[
0
].
reward_model
.
offline_data_path
=
'
cartpole_trex_offppo_offline_data
'
args
.
cfg
[
0
].
reward_model
.
offline_data_path
=
os
.
path
.
abspath
(
args
.
cfg
[
0
].
reward_model
.
offline_data_path
)
args
.
cfg
[
0
].
reward_model
.
reward_model_path
=
args
.
cfg
[
0
].
reward_model
.
offline_data_path
+
'/cartpole.params'
args
.
cfg
[
0
].
reward_model
.
expert_model_path
=
'./cartpole_ppo_offpolicy'
args
.
cfg
[
0
].
reward_model
.
expert_model_path
=
os
.
path
.
abspath
(
args
.
cfg
[
0
].
reward_model
.
expert_model_path
)
trex_collecting_data
(
args
=
args
)
os
.
popen
(
'rm -rf {}'
.
format
(
expert_policy_state_dict_path
))
...
...
ding/entry/tests/test_serial_entry_il.py
浏览文件 @
234de26b
...
...
@@ -86,7 +86,7 @@ class DQNILPolicy(ILPolicy):
self
.
_collect_model
=
model_wrap
(
self
.
_model
,
wrapper_name
=
'argmax_sample'
)
self
.
_collect_model
.
reset
()
def
_forward_collect
(
self
,
data
:
dict
):
def
_forward_collect
(
self
,
data
:
dict
,
eps
:
float
):
data_id
=
list
(
data
.
keys
())
data
=
default_collate
(
list
(
data
.
values
()))
if
self
.
_cuda
:
...
...
ding/entry/tests/test_serial_entry_trex_onpolicy.py
浏览文件 @
234de26b
...
...
@@ -12,7 +12,7 @@ from dizoo.mujoco.config import hopper_trex_ppo_default_config, hopper_trex_ppo_
from
ding.entry.application_entry_trex_collect_data
import
trex_collecting_data
@
pytest
.
mark
.
unittest
#
@pytest.mark.unittest
def
test_serial_pipeline_reward_model_trex
():
config
=
[
deepcopy
(
hopper_ppo_default_config
),
deepcopy
(
hopper_ppo_create_default_config
)]
expert_policy
=
serial_pipeline_onpolicy
(
config
,
seed
=
0
,
max_iterations
=
90
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录