Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
OpenDILab开源决策智能平台
DI-engine
提交
7a4de160
D
DI-engine
项目概览
OpenDILab开源决策智能平台
/
DI-engine
上一次同步 2 年多
通知
60
Star
321
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
1
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-engine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
1
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
7a4de160
编写于
10月 19, 2021
作者:
N
niuyazhe
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feature(nyz): add naive offpolicy demo(ci skip)
上级
ad394fc5
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
112 addition
and
0 deletion
+112
-0
ding/entry/offpolicy_demo.py
ding/entry/offpolicy_demo.py
+112
-0
未找到文件。
ding/entry/offpolicy_demo.py
0 → 100644
浏览文件 @
7a4de160
from
easydict
import
EasyDict
import
time
import
torch
import
treetensor.torch
as
ttorch
import
ding
from
ding
import
compile_config
,
set_pkg_seed
,
Policy
,
Pool
,
BaseLearner
,
SampleSerialCollector
,
InteractionSerialEvaluator
,
MemoryBuffer
,
VariableManager
,
Pipeline
from
.model
import
CustomizedModel
from
.env
import
CustomizedEnv
class DQNPolicy(Policy):
    """Skeleton DQN policy for the naive off-policy demo.

    Only the three ``forward_*`` entry points carry logic; the ``_init_*``
    hooks required by the ``Policy`` base class are intentionally empty.
    """

    def _init_learn(self):
        # No learn-time setup needed in this demo skeleton.
        pass

    def _init_collect(self):
        # No collect-time setup needed in this demo skeleton.
        pass

    def _init_eval(self):
        # No eval-time setup needed in this demo skeleton.
        pass

    def forward_learn(self, learner):
        """Consume one training batch from ``learner`` and return log info.

        The actual optimization step is still unimplemented; the batch is
        read and an empty log dict is returned.
        """
        data = learner.train_data
        log_info = {}
        # TODO
        return log_info

    def forward_collect(self, collector):
        """Run one epsilon-greedy environment step for data collection.

        Returns ``(train_timestep, timestep, log_info)``.
        """
        # TODO pass train_iter
        observation = ttorch.tensor(collector.env.obs)
        if self._cuda:
            observation = observation.cuda()
        # Inference only: eval mode plus no_grad to avoid autograd overhead.
        self._collect_model.eval()
        with torch.no_grad():
            epsilon = self._eps_fn(collector.env_step)
            model_output = self._collect_model.forward(observation, eps=epsilon)
        if self._cuda:
            model_output = model_output.cpu()
        timestep = collector.env.step(model_output.action)
        # TODO s_t+1 problem
        train_timestep = self._process_train_timestep(observation, model_output, timestep)
        # TODO async case
        self._reset_collect(timestep.env_id, timestep.done)
        log_info = {}
        # TODO
        return train_timestep, timestep, log_info

    def forward_eval(self, evaluator):
        """Run one greedy environment step for evaluation.

        Returns ``(timestep, log_info)``.
        """
        observation = ttorch.tensor(evaluator.env.obs)
        if self._cuda:
            observation = observation.cuda()
        # Inference only: eval mode plus no_grad to avoid autograd overhead.
        self._eval_model.eval()
        with torch.no_grad():
            model_output = self._eval_model.forward(observation)
        if self._cuda:
            model_output = model_output.cpu()
        timestep = evaluator.env.step(model_output.action)
        self._reset_eval(timestep.env_id, timestep.done)
        log_info = {}
        # TODO
        return timestep, log_info
class OffPolicyTrainPipeline(Pipeline):
    """Collect -> store -> sample -> learn loop for off-policy training."""

    def __init__(self, cfg: EasyDict, env, policy):
        super(OffPolicyTrainPipeline, self).__init__()
        self.env = env
        self.policy = policy
        # Wire the policy's entry points into collector, learner and buffer.
        self.collector = SampleSerialCollector(
            cfg.collector, self.env, forward_fn=self.policy.forward_collect
        )
        self.learner = BaseLearner(
            cfg.learner, self.env, forward_fn=self.policy.forward_learn
        )
        self.buffer = MemoryBuffer(
            cfg.buffer, strategy_fn=self.policy.buffer_strategy
        )

    def run(self):
        """Alternate collection and learning until both components stop."""
        while True:
            if self.learner.stop() and self.collector.stop():
                break
            self.collector.collect(self.buffer)
            batch = self.buffer.sample()
            self.learner.learn(batch)
class EvalPipeline(Pipeline):
    """Periodic evaluation loop driven by the evaluator's own schedule."""

    def __init__(self, cfg: EasyDict, env, policy):
        super(EvalPipeline, self).__init__()
        self.env = env
        self.policy = policy
        self.evaluator = InteractionSerialEvaluator(
            cfg.evaluator, self.env, forward_fn=self.policy.forward_eval
        )

    def run(self):
        """Poll once per second; evaluate when due; exit on the stop flag."""
        while True:
            if self.evaluator.should_eval():
                stop_flag, reward = self.evaluator.eval()
                if stop_flag:
                    break
            # TODO trigger save ckpt
            # TODO shutdown the whole program
            time.sleep(1)
def main(config_path: str, seed: int):
    """Assemble the demo: compile config, build policy/envs, run pipelines.

    :param config_path: path of the user config handed to ``compile_config``.
    :param seed: random seed applied package-wide and to both environments.
    """
    config = compile_config(config_path, seed)
    set_pkg_seed(seed)
    # NOTE(review): `model` is built here but never handed to the policy —
    # presumably DQNPolicy should receive it; confirm the intended wiring.
    model = CustomizedModel(config.policy.model)
    policy = DQNPolicy(config.policy)
    train_env = CustomizedEnv(config.env, seed)
    eval_env = CustomizedEnv(config.env, seed)
    # train_env = CustomizedEnv(cfg.env, seed).clone(8)
    trainer = OffPolicyTrainPipeline(config, train_env, policy)
    evaluator = EvalPipeline(config, eval_env, policy)
    ding.run([trainer, evaluator])
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录