Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
OpenDILab开源决策智能平台
DI-engine
提交
1438e760
D
DI-engine
项目概览
OpenDILab开源决策智能平台
/
DI-engine
上一次同步 2 年多
通知
60
Star
321
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
1
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-engine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
1
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
1438e760
编写于
11月 13, 2021
作者:
N
niuyazhe
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feature(nyz): add deque buffer compatibility wrapper and demo
上级
5c6df8b3
变更
3
显示空白变更内容
内联
并排
Showing
3 changed files
with
115 additions
and
0 deletions
+115
-0
ding/worker/buffer/__init__.py
ding/worker/buffer/__init__.py
+1
-0
ding/worker/buffer/deque_buffer_wrapper.py
ding/worker/buffer/deque_buffer_wrapper.py
+34
-0
dizoo/classic_control/cartpole/entry/cartpole_dqn_buffer_main.py
...lassic_control/cartpole/entry/cartpole_dqn_buffer_main.py
+80
-0
未找到文件。
ding/worker/buffer/__init__.py
浏览文件 @
1438e760
from
.buffer
import
Buffer
,
apply_middleware
,
BufferedData
from
.deque_buffer
import
DequeBuffer
from
.deque_buffer_wrapper
import
DequeBufferWrapper
ding/worker/buffer/deque_buffer_wrapper.py
0 → 100644
浏览文件 @
1438e760
from
typing
import
Optional
import
copy
from
easydict
import
EasyDict
from
ding.worker.buffer
import
DequeBuffer
from
ding.utils
import
BUFFER_REGISTRY
@BUFFER_REGISTRY.register('deque')
class DequeBufferWrapper(object):
    """Compatibility adapter that lets the new ``DequeBuffer`` stand in for the
    legacy replay-buffer interface (``sample``/``push``) used by the serial
    training entry points.
    """

    config = dict(
        replay_buffer_size=10000,
    )

    @classmethod
    def default_config(cls: type) -> EasyDict:
        """Return a fresh, mutable copy of the class-level default config."""
        # Deep-copy so callers can mutate the returned config without touching
        # the shared class attribute.
        cfg = EasyDict(copy.deepcopy(cls.config))
        cfg.cfg_type = cls.__name__ + 'Dict'
        return cfg

    def __init__(
            self,
            cfg: EasyDict,
            tb_logger: Optional[object] = None,
            exp_name: str = 'default_experiement'
    ) -> None:
        # Only the buffer size is consumed here; ``tb_logger`` and ``exp_name``
        # are accepted purely for signature compatibility with the old buffer
        # classes and are currently unused.
        self.buffer = DequeBuffer(size=cfg.replay_buffer_size)

    def sample(self, size: int, train_iter: int):
        """Sample ``size`` items; return a list of raw data, or None when the
        buffer cannot provide any samples.
        """
        sampled = self.buffer.sample(size=size, ignore_insufficient=True)
        return [entry.data for entry in sampled] if sampled else None

    def push(self, data, cur_collector_envstep: int = -1) -> None:
        """Insert each element of ``data`` into the underlying buffer.

        ``cur_collector_envstep`` is accepted for interface compatibility and
        is not used.
        """
        for item in data:
            self.buffer.push(item)
dizoo/classic_control/cartpole/entry/cartpole_dqn_buffer_main.py
0 → 100644
浏览文件 @
1438e760
import
os
import
gym
from
tensorboardX
import
SummaryWriter
from
ding.config
import
compile_config
from
ding.worker
import
BaseLearner
,
SampleSerialCollector
,
InteractionSerialEvaluator
,
DequeBufferWrapper
from
ding.envs
import
BaseEnvManager
,
DingEnvWrapper
from
ding.policy
import
DQNPolicy
from
ding.model
import
DQN
from
ding.utils
import
set_pkg_seed
from
ding.rl_utils
import
get_epsilon_greedy_fn
from
dizoo.classic_control.cartpole.config.cartpole_dqn_config
import
cartpole_dqn_config
def wrapped_cartpole_env():
    """Create a CartPole-v0 gym environment wrapped in DI-engine's env API."""
    raw_env = gym.make('CartPole-v0')
    return DingEnvWrapper(raw_env)
def main(cfg, seed: int = 0) -> None:
    """Run the cartpole DQN demo using the deque replay-buffer wrapper.

    Sets up environments, policy, learner, collector, evaluator and the
    replay buffer from ``cfg``, then alternates collection and training
    until the evaluator reports the stop condition.
    """
    # Compile the full pipeline config from defaults of every component.
    cfg = compile_config(
        cfg,
        BaseEnvManager,
        DQNPolicy,
        BaseLearner,
        SampleSerialCollector,
        InteractionSerialEvaluator,
        DequeBufferWrapper,
        save_cfg=True
    )
    collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
    collector_env = BaseEnvManager(env_fn=[wrapped_cartpole_env for _ in range(collector_env_num)], cfg=cfg.env.manager)
    evaluator_env = BaseEnvManager(env_fn=[wrapped_cartpole_env for _ in range(evaluator_env_num)], cfg=cfg.env.manager)

    # Set random seed for all package and instance
    collector_env.seed(seed)
    # Evaluator env uses a fixed seed so evaluation episodes are reproducible.
    evaluator_env.seed(seed, dynamic_seed=False)
    set_pkg_seed(seed, use_cuda=cfg.policy.cuda)

    # Set up RL Policy
    model = DQN(**cfg.policy.model)
    policy = DQNPolicy(cfg.policy, model=model)

    # Set up collection, training and evaluation utilities
    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
    collector = SampleSerialCollector(
        cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
    )
    evaluator = InteractionSerialEvaluator(
        cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
    )
    # The deque-based buffer wrapper introduced by this commit.
    replay_buffer = DequeBufferWrapper(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name)

    # Set up other modules, etc. epsilon greedy
    eps_cfg = cfg.policy.other.eps
    epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)

    # Training & Evaluation loop
    while True:
        # Evaluating at the beginning and with specific frequency
        if evaluator.should_eval(learner.train_iter):
            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
            if stop:
                break
        # Update other modules
        eps = epsilon_greedy(collector.envstep)
        # Sampling data from environments
        new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs={'eps': eps})
        replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
        # Training
        for i in range(cfg.policy.learn.update_per_collect):
            train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
            # Sample returns None when the buffer has too little data; wait
            # for the next collection round instead of training.
            if train_data is None:
                break
            learner.train(train_data, collector.envstep)
# Run the demo with the default cartpole DQN config when executed as a script.
if __name__ == "__main__":
    main(cartpole_dqn_config)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录