OpenDILab open-source decision intelligence platform / DI-engine
Commit 7e51de4f
Authored on Nov 22, 2021 by niuyazhe

fix(nyz): simplify onppo with traj_flag

Parent 0b46dd24
Changes: 4 changed files with 15 additions and 149 deletions (+15 −149)
ding/entry/serial_entry_onpolicy.py  +2 −28
ding/policy/ppo.py  +10 −115
dizoo/classic_control/cartpole/config/cartpole_ppo_config.py  +1 −4
dizoo/classic_control/pendulum/config/pendulum_ppo_config.py  +2 −2
ding/entry/serial_entry_onpolicy.py

```diff
@@ -69,9 +69,8 @@ def serial_pipeline_onpolicy(
     evaluator = InteractionSerialEvaluator(
         cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
     )
-    replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
     commander = BaseSerialCommander(
-        cfg.policy.other.commander, learner, collector, evaluator, replay_buffer, policy.command_mode
+        cfg.policy.other.commander, learner, collector, evaluator, None, policy.command_mode
     )
     # ==========
@@ -80,15 +79,6 @@ def serial_pipeline_onpolicy(
     # Learner's before_run hook.
     learner.call_hook('before_run')

-    # Accumulate plenty of data at the beginning of training.
-    if cfg.policy.get('random_collect_size', 0) > 0:
-        action_space = collector_env.env_info().act_space
-        random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
-        collector.reset_policy(random_policy)
-        collect_kwargs = commander.step()
-        new_data = collector.collect(n_sample=cfg.policy.random_collect_size, policy_kwargs=collect_kwargs)
-        replay_buffer.push(new_data, cur_collector_envstep=0)
-        collector.reset_policy(policy.collect_mode)
     for _ in range(max_iterations):
         collect_kwargs = commander.step()
         # Evaluate policy performance
@@ -100,23 +90,7 @@ def serial_pipeline_onpolicy(
         new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
         # Learn policy from collected data
-        for i in range(cfg.policy.learn.update_per_collect):
-            # update_per_collect=1, for onppo
-            # Learner will train ``update_per_collect`` times in one iteration.
-            train_data = new_data
-            if train_data is None:
-                # It is possible that replay buffer's data count is too few to train ``update_per_collect`` times
-                logging.warning(
-                    "Replay buffer's data can only train for {} steps. ".format(i) +
-                    "You can modify data collect config, e.g. increasing n_sample, n_episode."
-                )
-                break
-            learner.train(train_data, collector.envstep)
-            if learner.policy.get_attribute('priority'):
-                replay_buffer.update(learner.priority_info)
-        if cfg.policy.on_policy:
-            # On-policy algorithm must clear the replay buffer.
-            replay_buffer.clear()
+        learner.train(new_data, collector.envstep)

     # Learner's after_run hook.
     learner.call_hook('after_run')
```
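With the replay buffer, the random-data warm-up, and the `update_per_collect` loop all removed, each iteration now collects one fresh on-policy batch and trains on it exactly once. A minimal usage sketch of the simplified entry (the entry signature and the `*_create_config` name follow the usual DI-engine conventions and are assumptions, not part of this commit):

```python
# Usage sketch: running the simplified on-policy pipeline end to end.
from ding.entry import serial_pipeline_onpolicy
from dizoo.classic_control.cartpole.config.cartpole_ppo_config import \
    cartpole_ppo_config, cartpole_ppo_create_config

if __name__ == "__main__":
    # Each iteration: collect n_sample on-policy transitions, then call
    # learner.train() once on that fresh batch -- no replay buffer.
    serial_pipeline_onpolicy((cartpole_ppo_config, cartpole_ppo_create_config), seed=0)
```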
ding/policy/ppo.py

```diff
@@ -17,91 +17,6 @@ from .common_utils import default_preprocess_learn
 from ding.utils import dicts_to_lists, lists_to_dicts


-def compute_adv(data, last_value, cfg):
-    # last_value could be the real last value of the last timestep in the whole traj,
-    # or the next_value sequence for each timestep.
-    data = get_gae(data, last_value, gamma=cfg.collect.discount_factor, gae_lambda=cfg.collect.gae_lambda, cuda=False)
-    # data: list (T timesteps, 1 batch) [{'value': ..., 'reward': ..., 'adv': ...}, ...]
-    return get_nstep_return_data(data, cfg.nstep) if cfg.nstep_return else get_train_sample(data, cfg.collect.unroll_len)
-
-
-def dict_data_split_traj_and_compute_adv(data, next_value, cfg):
-    # get_gae needs trajectory data from a single episode, not mixed episodes, so we split the data
-    # into trajs according to the key 'done' (and 'traj_flag' if present), bounded by
-    # max_traj_length = cfg.collect.n_sample // cfg.collect.collector_env_num.
-    # data shape: dict of torch.FloatTensor of transitions,
-    # e.g. {'obs': [torch.FloatTensor], ..., 'reward': [torch.FloatTensor], ...}
-    # A traj is a run of consecutive transitions from one episode: the whole episode, a truncated
-    # episode, or a consecutive slice of one episode (due to the max_traj_len restriction).
-    processed_data = []
-    start_index = 0
-    timesteps = 0
-    for i in range(data['reward'].shape[0]):
-        timesteps += 1
-        traj_data = []
-        if 'traj_flag' in data.keys():
-            # for compatibility in mujoco: when done is ignored, split the data according to traj_flag
-            traj_flag = data['traj_flag'][i]
-        else:
-            traj_flag = data['done'][i]
-        if traj_flag:
-            # data['done'][i]: torch.tensor(1.) or True
-            for k in range(start_index, i + 1):
-                # transform to shape like this:
-                # traj_data.append({'value': data['value'][k], 'reward': data['reward'][k], 'adv': data['adv'][k]})
-                # if discrete action: traj_data.append({key: data[key][k] for key in data.keys()})
-                # if continuous action: data['logit'] is a list (torch.tensor(3200, 6)); data['weight'] is a list
-                traj_data.append(
-                    {
-                        key: [data[key][logit_index][k] for logit_index in range(len(data[key]))]
-                        if isinstance(data[key], list) and key == 'logit' else data[key][k]
-                        for key in data.keys()
-                    }
-                )
-            if data['done'][i]:
-                # if done
-                next_value[i] = torch.zeros(1)[0].to(data['obs'][0].device)
-            processed_data.extend(traj_data)
-            start_index = i + 1
-            timesteps = 0
-            continue
-        if timesteps == cfg.collect.n_sample // cfg.collect.collector_env_num:
-            # equals self._traj_len, e.g. 64
-            for k in range(start_index, i + 1):
-                traj_data.append(
-                    {
-                        key: [data[key][logit_index][k] for logit_index in range(len(data[key]))]
-                        if isinstance(data[key], list) and key == 'logit' else data[key][k]
-                        for key in data.keys()
-                    }
-                )
-            # traj_data = compute_adv(traj_data, next_value[i], cfg)
-            if data['done'][i]:
-                # if done
-                next_value[i] = torch.zeros(1)[0].to(data['obs'][0].device)
-            processed_data.extend(traj_data)
-            start_index = i + 1
-            timesteps = 0
-            continue
-    remaining_traj_data = []
-    for k in range(start_index, i + 1):
-        remaining_traj_data.append(
-            {
-                key: [data[key][logit_index][k] for logit_index in range(len(data[key]))]
-                if isinstance(data[key], list) and key == 'logit' else data[key][k]
-                for key in data.keys()
-            }
-        )
-    if data['done'][i]:
-        # if done
-        next_value[i] = torch.zeros(1)[0].to(data['obs'][0].device)
-    # add the remaining data; return shape: list of dict
-    data = processed_data + remaining_traj_data
-    return compute_adv(data, next_value, cfg)
-
-
 @POLICY_REGISTRY.register('ppo')
 class PPOPolicy(Policy):
     r"""
```
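The two helpers deleted above existed only to slice a flat batch into per-episode trajectories before calling `get_gae`. The replacement (next hunks) keeps the batch flat and instead lets the GAE recursion reset wherever `traj_flag` is set. A minimal sketch of flag-aware GAE, written independently of DI-engine's `gae`/`gae_data` utilities (1-D tensors over time are an assumption):

```python
import torch

def gae_with_traj_flag(value, next_value, reward, done, traj_flag,
                       gamma=0.99, lambda_=0.95):
    # Generalized Advantage Estimation over a flat batch of T timesteps.
    # done zeroes the bootstrap value at true terminal states; traj_flag
    # additionally stops the backward recursion at every trajectory
    # boundary (episode end or collection cutoff), so advantages never
    # leak across trajectories that happen to be adjacent in the batch.
    T = reward.shape[0]
    adv = torch.zeros_like(reward)
    gae_acc = 0.0
    for t in reversed(range(T)):
        bootstrap = next_value[t] * (1.0 - done[t].float())
        delta = reward[t] + gamma * bootstrap - value[t]
        gae_acc = delta + gamma * lambda_ * (1.0 - traj_flag[t].float()) * gae_acc
        adv[t] = gae_acc
    return adv
```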
```diff
@@ -232,11 +147,6 @@ class PPOPolicy(Policy):
                 Including current lr, total_loss, policy_loss, value_loss, entropy_loss, \
                 adv_abs_max, approx_kl, clipfrac
         """
-        # for transition in data:
-        #     # for compatibility in mujoco: when done is ignored, split the data according to traj_flag
-        #     if 'traj_flag' not in transition.keys():
-        #         transition['traj_flag'] = copy.deepcopy(transition['done'])
         data = default_preprocess_learn(data, ignore_done=self._cfg.learn.ignore_done, use_nstep=False)
         if self._cuda:
             data = to_device(data, self._device)
```
```diff
@@ -257,39 +167,26 @@ class PPOPolicy(Policy):
         for epoch in range(self._cfg.learn.epoch_per_collect):
             if self._recompute_adv:
                 # new v network compute new value
                 with torch.no_grad():
-                    # obs = torch.cat([data['obs'], data['next_obs'][-1:]])
                     value = self._learn_model.forward(data['obs'], mode='compute_critic')['value']
                     next_value = self._learn_model.forward(data['next_obs'], mode='compute_critic')['value']
                     if self._value_norm:
                         value *= self._running_mean_std.std
                         next_value *= self._running_mean_std.std
                     data['value'] = value
-                    data['weight'] = [None for i in range(data['reward'].shape[0])]
-                    processed_data = dict_data_split_traj_and_compute_adv(data, next_value.to(self._device), self._cfg)
-                    processed_data = lists_to_dicts(processed_data)
-                    for k, v in processed_data.items():
-                        if isinstance(v[0], torch.Tensor):
-                            processed_data[k] = torch.stack(v, dim=0)
-                    processed_data['weight'] = None
-                    unnormalized_returns = processed_data['value'] + processed_data['adv']
+                    compute_adv_data = gae_data(value, next_value, data['reward'], data['done'], data['traj_flag'])
+                    data['adv'] = gae(compute_adv_data, self._gamma, self._gae_lambda)
+                    unnormalized_returns = value + data['adv']
                     if self._value_norm:
-                        processed_data['value'] = processed_data['value'] / self._running_mean_std.std
-                        processed_data['return'] = unnormalized_returns / self._running_mean_std.std
+                        data['value'] = data['value'] / self._running_mean_std.std
+                        data['return'] = unnormalized_returns / self._running_mean_std.std
                         self._running_mean_std.update(unnormalized_returns.cpu().numpy())
                     else:
-                        processed_data['value'] = processed_data['value']
-                        processed_data['return'] = unnormalized_returns
+                        data['value'] = data['value']
+                        data['return'] = unnormalized_returns
-            else:
-                processed_data = data
-            for batch in split_data_generator(processed_data, self._cfg.learn.batch_size, shuffle=True):
+            for batch in split_data_generator(data, self._cfg.learn.batch_size, shuffle=True):
                 output = self._learn_model.forward(batch['obs'], mode='compute_actor_critic')
                 adv = batch['adv']
                 if self._adv_norm:
```
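Condensed, the advantage-recomputation path now reads roughly as follows each epoch. This is a sketch of the data flow in the hunk above, not the verbatim policy code; `gae_with_traj_flag` from the earlier sketch stands in for DI-engine's `gae`/`gae_data` pair:

```python
import torch

def recompute_advantage(learn_model, data, running_std, gamma, gae_lambda,
                        value_norm=True):
    # Sketch: re-estimate values with the current critic, then redo GAE
    # on the flat batch; traj_flag (set in _get_train_sample) marks the
    # trajectory boundaries.
    with torch.no_grad():
        value = learn_model.forward(data['obs'], mode='compute_critic')['value']
        next_value = learn_model.forward(data['next_obs'], mode='compute_critic')['value']
        if value_norm:
            # the critic learns in normalized space; un-normalize before GAE
            value = value * running_std
            next_value = next_value * running_std
        data['adv'] = gae_with_traj_flag(value, next_value, data['reward'],
                                         data['done'], data['traj_flag'],
                                         gamma, gae_lambda)
        unnormalized_returns = value + data['adv']
        if value_norm:
            data['value'] = value / running_std
            data['return'] = unnormalized_returns / running_std
        else:
            data['value'] = value
            data['return'] = unnormalized_returns
    return data
```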
```diff
@@ -429,11 +326,9 @@ class PPOPolicy(Policy):
         """
         data = to_device(data, self._device)
         for transition in data:
-            # for compatibility in mujoco: when done is ignored, split the data according to traj_flag
-            if 'traj_flag' not in transition.keys():
-                transition['traj_flag'] = copy.deepcopy(transition['done'])
+            transition['traj_flag'] = copy.deepcopy(transition['done'])
+        data[-1]['traj_flag'] = True
         # adder is defined in _init_collect
         if self._cfg.learn.ignore_done:
             data[-1]['done'] = False
```
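`_get_train_sample` now tags transitions unconditionally: `traj_flag` starts as a copy of `done`, and the last transition of each collected slice is forced to `True`, so advantage estimation resets at collection boundaries even when `learn.ignore_done` later clears `done`. A toy illustration (field names come from the hunk above; the 4-step batch is hypothetical):

```python
import copy

# Hypothetical 4-step slice from one collector env; only the fields
# relevant to trajectory tagging are shown.
data = [
    {'done': False}, {'done': True},   # an episode terminates at step 1
    {'done': False}, {'done': False},  # next episode, cut off by collection
]
for transition in data:
    transition['traj_flag'] = copy.deepcopy(transition['done'])
data[-1]['traj_flag'] = True  # a collection boundary always ends a traj

print([t['traj_flag'] for t in data])  # [False, True, False, True]

# With learn.ignore_done=True (as pendulum_ppo_config now sets), the
# artificial cutoff is not treated as a true terminal state:
data[-1]['done'] = False
```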
dizoo/classic_control/cartpole/config/cartpole_ppo_config.py

```diff
@@ -34,10 +34,7 @@ cartpole_ppo_config = dict(
         ),
         eval=dict(
             evaluator=dict(
-                eval_freq=1000,
-                cfg_type='InteractionSerialEvaluatorDict',
-                stop_value=195,
-                n_episode=5,
+                eval_freq=100,
             ),
         ),
     ),
```
dizoo/classic_control/pendulum/config/pendulum_ppo_config.py

```diff
@@ -11,7 +11,7 @@ pendulum_ppo_config = dict(
     policy=dict(
         cuda=False,
         continuous=True,
-        recompute_adv=False,
+        recompute_adv=True,
         model=dict(
             obs_shape=3,
             action_shape=1,
@@ -31,7 +31,7 @@ pendulum_ppo_config = dict(
             clip_ratio=0.2,
             adv_norm=False,
             value_norm=True,
-            ignore_done=False,
+            ignore_done=True,
         ),
         collect=dict(
             n_sample=200,
```
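Pendulum only ever ends via a time limit, so `ignore_done=True` lets the critic bootstrap through the cutoff while `traj_flag` still stops the advantage recursion there, and `recompute_adv=True` refreshes advantages with the updated critic each epoch. A toy check of that interplay, reusing `gae_with_traj_flag` from the sketch above (that `ignore_done` zeroes `done` before GAE is an assumption inferred from the `ppo.py` hunks):

```python
import torch

# With ignore_done, 'done' is all zeros, but 'traj_flag' still marks the
# time-limit steps, so GAE both bootstraps through each cutoff (done=0)
# and stops the recursion there (traj_flag=1).
T = 6
value = torch.zeros(T)
next_value = torch.zeros(T)
reward = torch.ones(T)
done = torch.zeros(T)  # ignore_done=True: no step counts as terminal
traj_flag = torch.tensor([0., 0., 1., 0., 0., 1.])  # cutoffs at t=2, t=5

adv = gae_with_traj_flag(value, next_value, reward, done, traj_flag,
                         gamma=1.0, lambda_=1.0)
print(adv.tolist())  # [3.0, 2.0, 1.0, 3.0, 2.0, 1.0] -- resets at each cutoff
```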