DI-engine (OpenDILab open-source decision intelligence platform)
Commit 8dae34cb, authored Nov 25, 2021 by niuyazhe
refactor(nyz): add basic new sil entry(ci skip)
Parent: 2f079e02

Showing 6 changed files with 125 additions and 64 deletions (+125 -64)
ding/entry/serial_entry_sil.py                                  +43  -35
ding/model/wrapper/model_wrappers.py                             +1   -1
ding/policy/a2c.py                                               +2   -2
ding/policy/sil.py                                               +7   -5
dizoo/box2d/lunarlander/config/lunarlander_a2c_config.py        +54   -0
dizoo/box2d/lunarlander/config/lunarlander_a2c_sil_config.py    +18  -21
ding/entry/serial_entry_sil.py (+43 -35)

@@ -6,7 +6,7 @@ from functools import partial
 from tensorboardX import SummaryWriter
 from ding.envs import get_vec_env_setting, create_env_manager
-from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
+from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
     create_serial_collector
 from ding.config import read_config, compile_config
 from ding.policy import create_policy, PolicyFactory, create_sil
@@ -58,7 +58,12 @@ def serial_pipeline_sil(
     # Create worker components: learner, collector, evaluator, replay buffer, commander.
     tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
-    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+    base_learner = BaseLearner(
+        cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name, instance_name='base_learner'
+    )
+    sil_learner = BaseLearner(
+        cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name, instance_name='sil_learner'
+    )
     collector = create_serial_collector(
         cfg.policy.collect.collector,
         env=collector_env,
@@ -71,64 +76,67 @@ def serial_pipeline_sil(
     )
     replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
     commander = BaseSerialCommander(
-        cfg.policy.other.commander, learner, collector, evaluator, replay_buffer, policy.command_mode
+        cfg.policy.other.commander, base_learner, collector, evaluator, replay_buffer, policy.command_mode
     )
     # ==========
     # Main loop
     # ==========
     # Learner's before_run hook.
-    learner.call_hook('before_run')
-    new_ptr = old_ptr = 0
+    base_learner.call_hook('before_run')
 
     # Accumulate plenty of data at the beginning of training.
     if cfg.policy.get('random_collect_size', 0) > 0:
-        action_space = collector_env.env_info().act_space
-        random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
-        collector.reset_policy(policy.collect_mode)
+        if cfg.policy.get('transition_with_policy_data', False):
+            collector.reset_policy(policy.collect_mode)
+        else:
+            action_space = collector_env.env_info().act_space
+            random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
+            collector.reset_policy(random_policy)
         collect_kwargs = commander.step()
-        new_data = collector.collect(n_episode=cfg.policy.random_collect_size, policy_kwargs=collect_kwargs)
+        new_data = collector.collect(n_sample=cfg.policy.random_collect_size, policy_kwargs=collect_kwargs)
         replay_buffer.push(new_data, cur_collector_envstep=0)
         collector.reset_policy(policy.collect_mode)
-        new_ptr += len(new_data)
     for _ in range(max_iterations):
         collect_kwargs = commander.step()
         # Evaluate policy performance
-        if evaluator.should_eval(learner.train_iter):
-            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+        if evaluator.should_eval(base_learner.train_iter):
+            stop, reward = evaluator.eval(base_learner.save_checkpoint, base_learner.train_iter, collector.envstep)
             if stop:
                 break
         # Collect data by default config n_sample/n_episode
-        new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
+        new_data = collector.collect(train_iter=base_learner.train_iter, policy_kwargs=collect_kwargs)
         replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
-        new_ptr += len(new_data)
         # Learn policy from collected data
-        for i in range(cfg.policy.learn.update_per_collect):
-            # Learner will train ``update_per_collect`` times in one iteration.
-            if cfg.policy.on_policy:
-                train_data_base_policy = replay_buffer.sample(
-                    learner.policy.get_attribute('batch_size'), learner.train_iter, sample_range=slice(old_ptr - new_ptr, None)
-                )
-            else:
-                train_data_base_policy = replay_buffer.sample(
-                    learner.policy.get_attribute('batch_size'), learner.train_iter
-                )
-            train_data_sil = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
-            if train_data_base_policy is None or train_data_sil is None:
+        if cfg.policy.on_policy:
+            train_data_base_policy = replay_buffer.sample(
+                base_learner.policy.get_attribute('batch_size'),
+                base_learner.train_iter,
+                sample_range=slice(-len(new_data), None)
+            )
+            base_learner.train(train_data_base_policy)
+        else:
+            for i in range(cfg.policy.learn.update_per_collect):
+                # Learner will train ``update_per_collect`` times in one iteration.
+                train_data_base_policy = replay_buffer.sample(
+                    base_learner.policy.get_attribute('batch_size'), base_learner.train_iter
+                )
+                base_learner.train(train_data_base_policy)
+                if base_learner.policy.get_attribute('priority'):
+                    replay_buffer.update(base_learner.priority_info)
+        for i in range(cfg.policy.other.sil.update_per_collect):
+            train_data_sil = replay_buffer.sample(
+                cfg.policy.other.sil.n_episode_per_train, sil_learner.train_iter, groupby='episode'
+            )
+            train_data_sil = policy.process_sil_data(train_data_sil)
+            if train_data_sil is None:
                 # It is possible that replay buffer's data count is too few to train ``update_per_collect`` times
                 logging.warning(
-                    "Replay buffer's data can only train for {} steps. ".format(i) +
+                    "Replay buffer's data can only train for sil {} steps. ".format(i) +
                     "You can modify data collect config, e.g. increasing n_sample, n_episode."
                 )
                 break
-            learner.train({'base_policy': train_data_base_policy, 'sil': train_data_sil}, collector.envstep)
-            if learner.policy.get_attribute('priority'):
-                replay_buffer.update(learner.priority_info)
-        if cfg.policy.on_policy:
-            # On-policy algorithm must clear the replay buffer.
-            # replay_buffer.clear()
-            old_ptr = new_ptr
+            sil_learner.train(train_data_sil)
 
     # Learner's after_run hook.
-    learner.call_hook('after_run')
+    base_learner.call_hook('after_run')
     return policy
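The refactor splits the single learner into a base_learner for the on-policy A2C update and a sil_learner that replays whole episodes from the shared buffer for self-imitation. A minimal invocation sketch of the new entry, reusing the LunarLander SIL config touched in this commit (mirrors that config file's own __main__ block):

from ding.entry.serial_entry_sil import serial_pipeline_sil
from dizoo.box2d.lunarlander.config.lunarlander_a2c_sil_config import main_config, create_config

# Runs the two-learner pipeline: base_learner trains A2C on freshly collected
# samples, sil_learner trains on episodes re-sampled from the same replay buffer.
policy = serial_pipeline_sil((main_config, create_config), seed=0)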
ding/model/wrapper/model_wrappers.py (+1 -1)

@@ -320,7 +320,7 @@ class EpsGreedyMultinomialSampleWrapper(IModelWrapper):
     def forward(self, *args, **kwargs):
         eps = kwargs.pop('eps')
-        alpha = kwargs.pop('alpha')
+        alpha = kwargs.pop('alpha', 1)
         output = self._model.forward(*args, **kwargs)
         assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output))
         logit = output['logit']
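The only change here is giving alpha a default, so callers that do not forward the keyword no longer raise KeyError. A standalone sketch of the difference:

kwargs = {'eps': 0.05}

# Old behaviour: kwargs.pop('alpha') raised KeyError here, forcing every caller
# of the wrapper to pass alpha explicitly.
alpha = kwargs.pop('alpha', 1)  # new behaviour: falls back to 1

assert alpha == 1 and 'alpha' not in kwargs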
ding/policy/a2c.py (+2 -2)

@@ -170,7 +170,7 @@ class A2CPolicy(Policy):
         self._gamma = self._cfg.collect.discount_factor
         self._gae_lambda = self._cfg.collect.gae_lambda
 
-    def _forward_collect(self, data: dict) -> dict:
+    def _forward_collect(self, data: dict, **kwargs) -> dict:
         r"""
         Overview:
             Forward function of collect mode.
@@ -188,7 +188,7 @@ class A2CPolicy(Policy):
             data = to_device(data, self._device)
         self._collect_model.eval()
         with torch.no_grad():
-            output = self._collect_model.forward(data, mode='compute_actor_critic')
+            output = self._collect_model.forward(data, mode='compute_actor_critic', **kwargs)
         if self._cuda:
             output = to_device(output, 'cpu')
         output = default_decollate(output)
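The two A2C changes thread arbitrary keyword arguments from _forward_collect into the wrapped collect model, so wrapper-consumed keys such as eps or alpha need no further signature changes. A self-contained sketch of the pass-through pattern (toy names, not DI-engine API):

def _forward_collect(data, model, **kwargs):
    # Extra keywords (e.g. eps=0.05) are forwarded untouched to the model, where
    # a sample wrapper such as EpsGreedyMultinomialSampleWrapper pops them.
    return model(data, mode='compute_actor_critic', **kwargs)

def toy_model(data, mode, eps=None, alpha=1):
    return {'mode': mode, 'eps': eps, 'alpha': alpha}

print(_forward_collect([0.1, 0.2], toy_model, eps=0.05))
# -> {'mode': 'compute_actor_critic', 'eps': 0.05, 'alpha': 1}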
ding/policy/sil.py (+7 -5)

@@ -17,12 +17,14 @@ def create_sil(policy: Policy, cfg):
     return sil_policy
 
 
-class SIL(Policy):
+class SILPolicy(Policy):
     r"""
     Overview:
         Policy class of SIL algorithm.
     """
-    sil_config = dict(
+    config = dict(
+        update_per_collect=10,
+        n_episode_per_train=4,
         value_weight=0.5,
         learning_rate=0.001,
         betas=(0.9, 0.999),
@@ -32,8 +34,8 @@ class SIL(Policy):
     def __init__(self, policy: Policy, cfg):
         self.base_policy = policy
         self._model = policy._model
-        cfg.policy.other.sil = deep_merge_dicts(self.sil_config, cfg.policy.other.sil)
-        super(SIL, self).__init__(cfg.policy, model=policy._model, enable_field=policy._enable_field)
+        cfg.policy.other.sil = deep_merge_dicts(self.config, cfg.policy.other.sil)
+        super(SILPolicy, self).__init__(cfg.policy, model=policy._model, enable_field=policy._enable_field)
 
     def _init_learn(self) -> None:
         r"""
@@ -184,5 +186,5 @@ class SIL(Policy):
     ]
 
 
-class SILCommand(SIL, DummyCommandModePolicy):
+class SILCommand(SILPolicy, DummyCommandModePolicy):
     pass
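Renaming sil_config to config matches the DI-engine convention of exposing per-policy defaults through a config class attribute, which __init__ merges into the user config under cfg.policy.other.sil. A simplified sketch of the merge semantics (a shallow stand-in; the real deep_merge_dicts merges nested dicts recursively):

class_defaults = dict(update_per_collect=10, n_episode_per_train=4, value_weight=0.5)
user_cfg = dict(n_episode_per_train=8)  # e.g. from lunarlander_a2c_sil_config.py

merged = {**class_defaults, **user_cfg}  # user values override class defaults
print(merged)
# -> {'update_per_collect': 10, 'n_episode_per_train': 8, 'value_weight': 0.5}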
dizoo/box2d/lunarlander/config/lunarlander_a2c_config.py (new file, 0 → 100644, +54 -0)

from ding.entry.serial_entry_onpolicy import serial_pipeline_onpolicy
from easydict import EasyDict

lunarlander_a2c_config = dict(
    exp_name='lunarlander_a2c_seed0',
    env=dict(
        collector_env_num=4,
        evaluator_env_num=5,
        n_evaluator_episode=5,
        stop_value=200,
    ),
    policy=dict(
        cuda=False,
        model=dict(
            obs_shape=8,
            action_shape=4,
            encoder_hidden_size_list=[128, 64],
            share_encoder=False,
        ),
        learn=dict(
            batch_size=64,
            # (bool) Whether to normalize advantage. Default to False.
            adv_norm=False,
            learning_rate=0.001,
            # (float) loss weight of the value network, the weight of policy network is set to 1
            value_weight=0.1,
            # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
            entropy_weight=0.00001,
        ),
        collect=dict(
            # (int) collect n_sample data, train model n_iteration times
            n_sample=64,
            # (float) the trade-off factor lambda to balance 1step td and mc
            gae_lambda=0.95,
            discount_factor=0.995,
        ),
    ),
)
lunarlander_a2c_config = EasyDict(lunarlander_a2c_config)
main_config = lunarlander_a2c_config
lunarlander_a2c_create_config = dict(
    env=dict(
        type='lunarlander',
        import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
    ),
    env_manager=dict(type='subprocess'),
    policy=dict(type='a2c'),
)
lunarlander_a2c_create_config = EasyDict(lunarlander_a2c_create_config)
create_config = lunarlander_a2c_create_config

if __name__ == '__main__':
    serial_pipeline_onpolicy((main_config, create_config), seed=0)
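Because the module wraps both dicts in EasyDict, downstream code reads fields by attribute access; a quick sketch of how the pipeline-facing values in this new file are reached:

from easydict import EasyDict
from dizoo.box2d.lunarlander.config.lunarlander_a2c_config import main_config

# EasyDict turns the nested dicts into attribute access, which is how the
# pipeline reads fields such as cfg.policy.collect.n_sample.
assert isinstance(main_config, EasyDict)
assert main_config.policy.collect.n_sample == 64
assert main_config.policy.learn.value_weight == 0.1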
dizoo/box2d/lunarlander/config/lunarlander_a2c_sil_config.py (+18 -21)

@@ -2,47 +2,45 @@ from ding.entry.serial_entry_sil import serial_pipeline_sil
 from easydict import EasyDict
 
 lunarlander_a2c_config = dict(
-    exp_name='lunarlander_a2c',
+    exp_name='lunarlander_a2c_sil_seed0',
     env=dict(
-        collector_env_num=8,
+        collector_env_num=4,
         evaluator_env_num=5,
         n_evaluator_episode=5,
-        stop_value=195,
+        stop_value=200,
     ),
     policy=dict(
+        on_policy=True,
         cuda=False,
-        # (bool) whether use on-policy training pipeline(behaviour policy and training policy are the same)
         model=dict(
             obs_shape=8,
             action_shape=4,
-            encoder_hidden_size_list=[512, 64],
+            encoder_hidden_size_list=[128, 64],
             share_encoder=False,
         ),
         learn=dict(
             batch_size=64,
-            unroll_len=1,
-            normalize_advantage=False,
+            # (bool) Whether to normalize advantage. Default to False.
+            adv_norm=False,
             learning_rate=0.001,
             # (float) loss weight of the value network, the weight of policy network is set to 1
-            value_weight=0.5,
+            value_weight=0.1,
             # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
-            entropy_weight=1.0,
+            entropy_weight=0.00001,
         ),
         collect=dict(
-            collector=dict(
-                type='episode',
-                get_train_sample=True,
-            ),
             # (int) collect n_sample data, train model n_iteration times
-            n_episode=8,
+            n_sample=64,
             # (float) the trade-off factor lambda to balance 1step td and mc
             gae_lambda=0.95,
             discount_factor=0.995,
         ),
-        other=dict(
-            sil=dict(
-                value_weight=0.5,
-                learning_rate=0.0001,
-            ),
-            replay_buffer=dict(replay_buffer_size=200000, )),
+        other=dict(
+            replay_buffer=dict(replay_buffer_size=100000, ),
+            sil=dict(
+                update_per_collect=10,
+                n_episode_per_train=8,
+            ),
+        ),
     ),
 )
 lunarlander_a2c_config = EasyDict(lunarlander_a2c_config)
@@ -55,11 +53,10 @@ lunarlander_a2c_create_config = dict(
     ),
     env_manager=dict(type='subprocess'),
     policy=dict(type='a2c'),
-    replay_buffer=dict(type='deque'),
 )
 lunarlander_a2c_create_config = EasyDict(lunarlander_a2c_create_config)
 create_config = lunarlander_a2c_create_config
 
 if __name__ == '__main__':
     from ding.entry import serial_entry_sil
     serial_pipeline_sil((main_config, create_config), seed=0)
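After SILPolicy.__init__ merges the class defaults shown in the ding/policy/sil.py diff into this file's other.sil section, the effective SIL settings would be approximately the following (hand-derived sketch, assuming no other merge step touches these keys):

effective_sil = dict(
    update_per_collect=10,    # set here, equal to the class default
    n_episode_per_train=8,    # overrides the class default of 4
    value_weight=0.5,         # class default; the old per-config line was dropped
    learning_rate=0.001,      # class default; the old per-config 0.0001 was removed
    betas=(0.9, 0.999),       # class default
)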