Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
OpenDILab开源决策智能平台
DI-engine
提交
4e833da2
D
DI-engine
项目概览
OpenDILab开源决策智能平台
/
DI-engine
上一次同步 2 年多
通知
56
Star
321
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
1
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-engine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
1
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
4e833da2
编写于
7月 29, 2021
作者:
N
niuyazhe
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
polish(nyz): polish cartpole ppo demo and related unittest
上级
3243c92d
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
154 addition
and
52 deletion
+154
-52
ding/config/config.py
ding/config/config.py
+2
-2
ding/entry/tests/test_application_entry.py
ding/entry/tests/test_application_entry.py
+8
-6
ding/entry/tests/test_serial_entry.py
ding/entry/tests/test_serial_entry.py
+2
-1
ding/entry/tests/test_serial_entry_algo.py
ding/entry/tests/test_serial_entry_algo.py
+3
-2
ding/entry/tests/test_serial_entry_il.py
ding/entry/tests/test_serial_entry_il.py
+7
-7
ding/entry/tests/test_serial_entry_reward_model.py
ding/entry/tests/test_serial_entry_reward_model.py
+3
-3
ding/policy/command_mode_policy_instance.py
ding/policy/command_mode_policy_instance.py
+1
-1
ding/policy/ppo.py
ding/policy/ppo.py
+3
-5
dizoo/classic_control/cartpole/config/__init__.py
dizoo/classic_control/cartpole/config/__init__.py
+1
-0
dizoo/classic_control/cartpole/config/cartpole_ppo_config.py
dizoo/classic_control/cartpole/config/cartpole_ppo_config.py
+4
-2
dizoo/classic_control/cartpole/config/cartpole_ppo_offpolicy_config.py
..._control/cartpole/config/cartpole_ppo_offpolicy_config.py
+49
-0
dizoo/classic_control/cartpole/config/cartpole_ppo_rnd_config.py
...lassic_control/cartpole/config/cartpole_ppo_rnd_config.py
+2
-1
dizoo/classic_control/cartpole/entry/cartpole_ppo_main.py
dizoo/classic_control/cartpole/entry/cartpole_ppo_main.py
+3
-17
dizoo/classic_control/cartpole/entry/cartpole_ppo_offpolicy_main.py
...sic_control/cartpole/entry/cartpole_ppo_offpolicy_main.py
+63
-0
dizoo/classic_control/cartpole/entry/cartpole_ppo_rnd_main.py
...o/classic_control/cartpole/entry/cartpole_ppo_rnd_main.py
+3
-5
未找到文件。
ding/config/config.py
浏览文件 @
4e833da2
...
...
@@ -13,7 +13,7 @@ from easydict import EasyDict
from
ding.utils
import
deep_merge_dicts
from
ding.envs
import
get_env_cls
,
get_env_manager_cls
from
ding.policy
import
get_policy_cls
from
ding.worker
import
BaseLearner
,
BaseSerialEvaluator
,
BaseSerialCommander
,
Coordinator
,
\
from
ding.worker
import
BaseLearner
,
BaseSerialEvaluator
,
BaseSerialCommander
,
Coordinator
,
AdvancedReplayBuffer
,
\
get_parallel_commander_cls
,
get_parallel_collector_cls
,
get_buffer_cls
,
get_serial_collector_cls
from
ding.reward_model
import
get_reward_model_cls
from
.utils
import
parallel_transform
,
parallel_transform_slurm
,
parallel_transform_k8s
,
save_config_formatted
...
...
@@ -309,7 +309,7 @@ def compile_config(
learner
:
type
=
BaseLearner
,
collector
:
type
=
None
,
evaluator
:
type
=
BaseSerialEvaluator
,
buffer
:
type
=
None
,
buffer
:
type
=
AdvancedReplayBuffer
,
env
:
type
=
None
,
reward_model
:
type
=
None
,
seed
:
int
=
0
,
...
...
ding/entry/tests/test_application_entry.py
浏览文件 @
4e833da2
...
...
@@ -3,7 +3,7 @@ import pytest
import
os
import
pickle
from
dizoo.classic_control.cartpole.config.cartpole_ppo_
config
import
cartpole_ppo_config
,
cartpole_ppo_create_config
from
dizoo.classic_control.cartpole.config.cartpole_ppo_
offpolicy_config
import
cartpole_ppo_offpolicy_config
,
cartpole_ppo_offpolicy_create_config
# noqa
from
dizoo.classic_control.cartpole.envs
import
CartPoleEnv
from
ding.entry
import
serial_pipeline
,
eval
,
collect_demo_data
from
ding.config
import
compile_config
...
...
@@ -11,7 +11,7 @@ from ding.config import compile_config
@
pytest
.
fixture
(
scope
=
'module'
)
def
setup_state_dict
():
config
=
deepcopy
(
cartpole_ppo_
config
),
deepcopy
(
cartpole_ppo
_create_config
)
config
=
deepcopy
(
cartpole_ppo_
offpolicy_config
),
deepcopy
(
cartpole_ppo_offpolicy
_create_config
)
try
:
policy
=
serial_pipeline
(
config
,
seed
=
0
)
except
Exception
:
...
...
@@ -27,12 +27,14 @@ def setup_state_dict():
class
TestApplication
:
def
test_eval
(
self
,
setup_state_dict
):
cfg_for_stop_value
=
compile_config
(
cartpole_ppo_config
,
auto
=
True
,
create_cfg
=
cartpole_ppo_create_config
)
cfg_for_stop_value
=
compile_config
(
cartpole_ppo_offpolicy_config
,
auto
=
True
,
create_cfg
=
cartpole_ppo_offpolicy_create_config
)
stop_value
=
cfg_for_stop_value
.
env
.
stop_value
config
=
deepcopy
(
cartpole_ppo_
config
),
deepcopy
(
cartpole_ppo
_create_config
)
config
=
deepcopy
(
cartpole_ppo_
offpolicy_config
),
deepcopy
(
cartpole_ppo_offpolicy
_create_config
)
eval_reward
=
eval
(
config
,
seed
=
0
,
state_dict
=
setup_state_dict
[
'eval'
])
assert
eval_reward
>=
stop_value
config
=
deepcopy
(
cartpole_ppo_
config
),
deepcopy
(
cartpole_ppo
_create_config
)
config
=
deepcopy
(
cartpole_ppo_
offpolicy_config
),
deepcopy
(
cartpole_ppo_offpolicy
_create_config
)
eval_reward
=
eval
(
config
,
seed
=
0
,
...
...
@@ -42,7 +44,7 @@ class TestApplication:
assert
eval_reward
>=
stop_value
def
test_collect_demo_data
(
self
,
setup_state_dict
):
config
=
deepcopy
(
cartpole_ppo_
config
),
deepcopy
(
cartpole_ppo
_create_config
)
config
=
deepcopy
(
cartpole_ppo_
offpolicy_config
),
deepcopy
(
cartpole_ppo_offpolicy
_create_config
)
collect_count
=
16
expert_data_path
=
'./expert.data'
collect_demo_data
(
...
...
ding/entry/tests/test_serial_entry.py
浏览文件 @
4e833da2
...
...
@@ -15,6 +15,7 @@ from dizoo.classic_control.cartpole.config.cartpole_qrdqn_config import cartpole
from
dizoo.classic_control.cartpole.config.cartpole_sqn_config
import
cartpole_sqn_config
,
cartpole_sqn_create_config
# noqa
from
dizoo.classic_control.cartpole.config.cartpole_ppg_config
import
cartpole_ppg_config
,
cartpole_ppg_create_config
# noqa
from
dizoo.classic_control.cartpole.entry.cartpole_ppg_main
import
main
as
ppg_main
from
dizoo.classic_control.cartpole.entry.cartpole_ppo_main
import
main
as
ppo_main
from
dizoo.classic_control.cartpole.config.cartpole_r2d2_config
import
cartpole_r2d2_config
,
cartpole_r2d2_create_config
# noqa
from
dizoo.classic_control.pendulum.config
import
pendulum_ddpg_config
,
pendulum_ddpg_create_config
from
dizoo.classic_control.pendulum.config
import
pendulum_td3_config
,
pendulum_td3_create_config
...
...
@@ -116,7 +117,7 @@ def test_ppo():
config
=
[
deepcopy
(
cartpole_ppo_config
),
deepcopy
(
cartpole_ppo_create_config
)]
config
[
0
].
policy
.
learn
.
update_per_collect
=
1
try
:
serial_pipeline
(
config
,
seed
=
0
,
max_iterations
=
1
)
ppo_main
(
config
[
0
]
,
seed
=
0
,
max_iterations
=
1
)
except
Exception
:
assert
False
,
"pipeline fail"
...
...
ding/entry/tests/test_serial_entry_algo.py
浏览文件 @
4e833da2
...
...
@@ -15,6 +15,7 @@ from dizoo.classic_control.cartpole.config.cartpole_qrdqn_config import cartpole
from
dizoo.classic_control.cartpole.config.cartpole_sqn_config
import
cartpole_sqn_config
,
cartpole_sqn_create_config
# noqa
from
dizoo.classic_control.cartpole.config.cartpole_ppg_config
import
cartpole_ppg_config
,
cartpole_ppg_create_config
# noqa
from
dizoo.classic_control.cartpole.entry.cartpole_ppg_main
import
main
as
ppg_main
from
dizoo.classic_control.cartpole.entry.cartpole_ppo_main
import
main
as
ppo_main
from
dizoo.classic_control.cartpole.config.cartpole_r2d2_config
import
cartpole_r2d2_config
,
cartpole_r2d2_create_config
# noqa
from
dizoo.classic_control.pendulum.config
import
pendulum_ddpg_config
,
pendulum_ddpg_create_config
from
dizoo.classic_control.pendulum.config
import
pendulum_td3_config
,
pendulum_td3_create_config
...
...
@@ -90,7 +91,7 @@ def test_rainbow():
def
test_ppo
():
config
=
[
deepcopy
(
cartpole_ppo_config
),
deepcopy
(
cartpole_ppo_create_config
)]
try
:
serial_pipeline
(
config
,
seed
=
0
)
ppo_main
(
config
[
0
]
,
seed
=
0
)
except
Exception
:
assert
False
,
"pipeline fail"
with
open
(
"./algo_record.log"
,
"a+"
)
as
f
:
...
...
@@ -260,4 +261,4 @@ def test_qrdqn():
except
Exception
:
assert
False
,
"pipeline fail"
with
open
(
"./algo_record.log"
,
"a+"
)
as
f
:
f
.
write
(
"2
0. s
qn
\n
"
)
f
.
write
(
"2
1. qrd
qn
\n
"
)
ding/entry/tests/test_serial_entry_il.py
浏览文件 @
4e833da2
...
...
@@ -6,21 +6,21 @@ import torch
from
collections
import
namedtuple
import
os
from
dizoo.classic_control.cartpole.config
import
cartpole_ppo_config
,
cartpole_ppo_create_config
,
\
cartpole_dqn_config
,
cartpole_dqn_create_config
from
ding.torch_utils
import
Adam
,
to_device
from
ding.config
import
compile_config
from
ding.model
import
model_wrap
from
ding.rl_utils
import
get_train_sample
,
get_nstep_return_data
from
ding.entry
import
serial_pipeline_il
,
collect_demo_data
,
serial_pipeline
from
ding.policy
import
PPOPolicy
,
ILPolicy
from
ding.policy
import
PPO
Off
Policy
,
ILPolicy
from
ding.policy.common_utils
import
default_preprocess_learn
from
ding.utils
import
POLICY_REGISTRY
from
ding.utils.data
import
default_collate
,
default_decollate
from
dizoo.classic_control.cartpole.config
import
cartpole_dqn_config
,
cartpole_dqn_create_config
,
\
cartpole_ppo_offpolicy_config
,
cartpole_ppo_offpolicy_create_config
@
POLICY_REGISTRY
.
register
(
'ppo_il'
)
class
PPOILPolicy
(
PPOPolicy
):
class
PPOILPolicy
(
PPO
Off
Policy
):
def
_forward_learn
(
self
,
data
:
dict
)
->
dict
:
data
=
default_preprocess_learn
(
data
,
ignore_done
=
self
.
_cfg
.
learn
.
get
(
'ignore_done'
,
False
),
use_nstep
=
False
)
...
...
@@ -46,20 +46,20 @@ class PPOILPolicy(PPOPolicy):
@
pytest
.
mark
.
unittest
def
test_serial_pipeline_il_ppo
():
# train expert policy
train_config
=
[
deepcopy
(
cartpole_ppo_
config
),
deepcopy
(
cartpole_ppo
_create_config
)]
train_config
=
[
deepcopy
(
cartpole_ppo_
offpolicy_config
),
deepcopy
(
cartpole_ppo_offpolicy
_create_config
)]
expert_policy
=
serial_pipeline
(
train_config
,
seed
=
0
)
# collect expert demo data
collect_count
=
10000
expert_data_path
=
'expert_data_ppo.pkl'
state_dict
=
expert_policy
.
collect_mode
.
state_dict
()
collect_config
=
[
deepcopy
(
cartpole_ppo_
config
),
deepcopy
(
cartpole_ppo
_create_config
)]
collect_config
=
[
deepcopy
(
cartpole_ppo_
offpolicy_config
),
deepcopy
(
cartpole_ppo_offpolicy
_create_config
)]
collect_demo_data
(
collect_config
,
seed
=
0
,
state_dict
=
state_dict
,
expert_data_path
=
expert_data_path
,
collect_count
=
collect_count
)
# il training 1
il_config
=
[
deepcopy
(
cartpole_ppo_
config
),
deepcopy
(
cartpole_ppo
_create_config
)]
il_config
=
[
deepcopy
(
cartpole_ppo_
offpolicy_config
),
deepcopy
(
cartpole_ppo_offpolicy
_create_config
)]
il_config
[
0
].
policy
.
learn
.
train_epoch
=
10
il_config
[
0
].
policy
.
type
=
'ppo_il'
_
,
converge_stop_flag
=
serial_pipeline_il
(
il_config
,
seed
=
314
,
data_path
=
expert_data_path
)
...
...
ding/entry/tests/test_serial_entry_reward_model.py
浏览文件 @
4e833da2
...
...
@@ -5,7 +5,7 @@ from easydict import EasyDict
from
copy
import
deepcopy
from
dizoo.classic_control.cartpole.config.cartpole_dqn_config
import
cartpole_dqn_config
,
cartpole_dqn_create_config
from
dizoo.classic_control.cartpole.config.cartpole_ppo_
config
import
cartpole_ppo_config
,
cartpole_ppo_create_config
from
dizoo.classic_control.cartpole.config.cartpole_ppo_
offpolicy_config
import
cartpole_ppo_offpolicy_config
,
cartpole_ppo_offpolicy_create_config
# noqa
from
dizoo.classic_control.cartpole.config.cartpole_ppo_rnd_config
import
cartpole_ppo_rnd_config
,
cartpole_ppo_rnd_create_config
# noqa
from
ding.entry
import
serial_pipeline
,
collect_demo_data
,
serial_pipeline_reward_model
...
...
@@ -42,13 +42,13 @@ cfg = [
@
pytest
.
mark
.
parametrize
(
'reward_model_config'
,
cfg
)
def
test_irl
(
reward_model_config
):
reward_model_config
=
EasyDict
(
reward_model_config
)
config
=
deepcopy
(
cartpole_ppo_
config
),
deepcopy
(
cartpole_ppo
_create_config
)
config
=
deepcopy
(
cartpole_ppo_
offpolicy_config
),
deepcopy
(
cartpole_ppo_offpolicy
_create_config
)
expert_policy
=
serial_pipeline
(
config
,
seed
=
0
,
max_iterations
=
2
)
# collect expert demo data
collect_count
=
10000
expert_data_path
=
'expert_data.pkl'
state_dict
=
expert_policy
.
collect_mode
.
state_dict
()
config
=
deepcopy
(
cartpole_ppo_
config
),
deepcopy
(
cartpole_ppo
_create_config
)
config
=
deepcopy
(
cartpole_ppo_
offpolicy_config
),
deepcopy
(
cartpole_ppo_offpolicy
_create_config
)
collect_demo_data
(
config
,
seed
=
0
,
state_dict
=
state_dict
,
expert_data_path
=
expert_data_path
,
collect_count
=
collect_count
)
...
...
ding/policy/command_mode_policy_instance.py
浏览文件 @
4e833da2
...
...
@@ -111,7 +111,7 @@ class PPOCommandModePolicy(PPOPolicy, DummyCommandModePolicy):
@
POLICY_REGISTRY
.
register
(
'ppo_offpolicy_command'
)
class
PPOCommandModePolicy
(
PPOOffPolicy
,
DummyCommandModePolicy
):
class
PPO
Off
CommandModePolicy
(
PPOOffPolicy
,
DummyCommandModePolicy
):
pass
...
...
ding/policy/ppo.py
浏览文件 @
4e833da2
...
...
@@ -220,7 +220,7 @@ class PPOPolicy(Policy):
'value_max'
:
output
[
'value'
].
max
().
item
(),
'approx_kl'
:
ppo_info
.
approx_kl
,
'clipfrac'
:
ppo_info
.
clipfrac
,
'act'
:
batch
[
'action'
].
mean
().
item
(),
'act'
:
batch
[
'action'
].
float
().
mean
().
item
(),
}
if
self
.
_continuous
:
return_info
.
update
(
...
...
@@ -326,7 +326,7 @@ class PPOPolicy(Policy):
else
:
with
torch
.
no_grad
():
last_value
=
self
.
_collect_model
.
forward
(
data
[
-
1
][
'next_obs'
].
unsqueeze
(
0
),
mode
=
'compute_critic'
data
[
-
1
][
'next_obs'
].
unsqueeze
(
0
),
mode
=
'compute_
actor_
critic'
)[
'value'
]
if
self
.
_value_norm
:
last_value
*=
self
.
_running_mean_std
.
std
...
...
@@ -458,9 +458,7 @@ class PPOOffPolicy(Policy):
gae_lambda
=
0.95
,
),
eval
=
dict
(),
# Although ppo is an on-policy algorithm, ding reuses the buffer mechanism, and clear buffer after update.
# Note replay_buffer_size must be greater than n_sample.
other
=
dict
(
replay_buffer
=
dict
(
replay_buffer_size
=
1000
,
),
),
other
=
dict
(
replay_buffer
=
dict
(
replay_buffer_size
=
10000
,
),
),
)
def
_init_learn
(
self
)
->
None
:
...
...
dizoo/classic_control/cartpole/config/__init__.py
浏览文件 @
4e833da2
...
...
@@ -2,6 +2,7 @@ from .cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config
from
.cartpole_a2c_config
import
cartpole_a2c_config
,
cartpole_a2c_create_config
from
.cartpole_impala_config
import
cartpole_impala_config
,
cartpole_impala_create_config
from
.cartpole_ppo_config
import
cartpole_ppo_config
,
cartpole_ppo_create_config
from
.cartpole_ppo_offpolicy_config
import
cartpole_ppo_offpolicy_config
,
cartpole_ppo_offpolicy_create_config
from
.cartpole_rainbow_config
import
cartpole_rainbow_config
,
cartpole_rainbow_create_config
from
.cartpole_iqn_config
import
cartpole_iqn_config
,
cartpole_iqn_create_config
from
.cartpole_c51_config
import
cartpole_c51_config
,
cartpole_c51_create_config
...
...
dizoo/classic_control/cartpole/config/cartpole_ppo_config.py
浏览文件 @
4e833da2
...
...
@@ -10,6 +10,8 @@ cartpole_ppo_config = dict(
),
policy
=
dict
(
cuda
=
False
,
on_policy
=
True
,
continuous
=
False
,
model
=
dict
(
obs_shape
=
4
,
action_shape
=
2
,
...
...
@@ -18,7 +20,7 @@ cartpole_ppo_config = dict(
actor_head_hidden_size
=
128
,
),
learn
=
dict
(
update_per_collect
=
6
,
epoch_per_collect
=
2
,
batch_size
=
64
,
learning_rate
=
0.001
,
value_weight
=
0.5
,
...
...
@@ -26,7 +28,7 @@ cartpole_ppo_config = dict(
clip_ratio
=
0.2
,
),
collect
=
dict
(
n_sample
=
128
,
n_sample
=
256
,
unroll_len
=
1
,
discount_factor
=
0.9
,
gae_lambda
=
0.95
,
...
...
dizoo/classic_control/cartpole/config/cartpole_ppo_offpolicy_config.py
0 → 100644
浏览文件 @
4e833da2
from
easydict
import
EasyDict
cartpole_ppo_offpolicy_config
=
dict
(
exp_name
=
'cartpole_ppo_offpolicy'
,
env
=
dict
(
collector_env_num
=
8
,
evaluator_env_num
=
5
,
n_evaluator_episode
=
5
,
stop_value
=
195
,
),
policy
=
dict
(
on_policy
=
False
,
cuda
=
False
,
model
=
dict
(
obs_shape
=
4
,
action_shape
=
2
,
encoder_hidden_size_list
=
[
64
,
64
,
128
],
critic_head_hidden_size
=
128
,
actor_head_hidden_size
=
128
,
),
learn
=
dict
(
update_per_collect
=
6
,
batch_size
=
64
,
learning_rate
=
0.001
,
value_weight
=
0.5
,
entropy_weight
=
0.01
,
clip_ratio
=
0.2
,
),
collect
=
dict
(
n_sample
=
128
,
unroll_len
=
1
,
discount_factor
=
0.9
,
gae_lambda
=
0.95
,
),
other
=
dict
(
replay_buffer
=
dict
(
replay_buffer_size
=
5000
))
),
)
cartpole_ppo_offpolicy_config
=
EasyDict
(
cartpole_ppo_offpolicy_config
)
main_config
=
cartpole_ppo_offpolicy_config
cartpole_ppo_offpolicy_create_config
=
dict
(
env
=
dict
(
type
=
'cartpole'
,
import_names
=
[
'dizoo.classic_control.cartpole.envs.cartpole_env'
],
),
env_manager
=
dict
(
type
=
'base'
),
policy
=
dict
(
type
=
'ppo_offpolicy'
),
)
cartpole_ppo_offpolicy_create_config
=
EasyDict
(
cartpole_ppo_offpolicy_create_config
)
create_config
=
cartpole_ppo_offpolicy_create_config
dizoo/classic_control/cartpole/config/cartpole_ppo_rnd_config.py
浏览文件 @
4e833da2
...
...
@@ -17,6 +17,7 @@ cartpole_ppo_rnd_config = dict(
),
policy
=
dict
(
cuda
=
False
,
on_policy
=
False
,
model
=
dict
(
obs_shape
=
4
,
action_shape
=
2
,
...
...
@@ -48,7 +49,7 @@ cartpole_ppo_rnd_create_config = dict(
import_names
=
[
'dizoo.classic_control.cartpole.envs.cartpole_env'
],
),
env_manager
=
dict
(
type
=
'base'
),
policy
=
dict
(
type
=
'ppo'
),
policy
=
dict
(
type
=
'ppo
_offpolicy
'
),
reward_model
=
dict
(
type
=
'rnd'
),
)
cartpole_ppo_rnd_create_config
=
EasyDict
(
cartpole_ppo_rnd_create_config
)
...
...
dizoo/classic_control/cartpole/entry/cartpole_ppo_main.py
浏览文件 @
4e833da2
...
...
@@ -3,7 +3,7 @@ import gym
from
tensorboardX
import
SummaryWriter
from
ding.config
import
compile_config
from
ding.worker
import
BaseLearner
,
SampleCollector
,
BaseSerialEvaluator
,
NaiveReplayBuffer
from
ding.worker
import
BaseLearner
,
SampleCollector
,
BaseSerialEvaluator
from
ding.envs
import
BaseEnvManager
,
DingEnvWrapper
from
ding.policy
import
PPOPolicy
from
ding.model
import
VAC
...
...
@@ -17,14 +17,7 @@ def wrapped_cartpole_env():
def
main
(
cfg
,
seed
=
0
,
max_iterations
=
int
(
1e10
)):
cfg
=
compile_config
(
cfg
,
BaseEnvManager
,
PPOPolicy
,
BaseLearner
,
SampleCollector
,
BaseSerialEvaluator
,
NaiveReplayBuffer
,
save_cfg
=
True
cfg
,
BaseEnvManager
,
PPOPolicy
,
BaseLearner
,
SampleCollector
,
BaseSerialEvaluator
,
save_cfg
=
True
)
collector_env_num
,
evaluator_env_num
=
cfg
.
env
.
collector_env_num
,
cfg
.
env
.
evaluator_env_num
collector_env
=
BaseEnvManager
(
env_fn
=
[
wrapped_cartpole_env
for
_
in
range
(
collector_env_num
)],
cfg
=
cfg
.
env
.
manager
)
...
...
@@ -44,7 +37,6 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
evaluator
=
BaseSerialEvaluator
(
cfg
.
policy
.
eval
.
evaluator
,
evaluator_env
,
policy
.
eval_mode
,
tb_logger
,
exp_name
=
cfg
.
exp_name
)
replay_buffer
=
NaiveReplayBuffer
(
cfg
.
policy
.
other
.
replay_buffer
,
exp_name
=
cfg
.
exp_name
)
for
_
in
range
(
max_iterations
):
if
evaluator
.
should_eval
(
learner
.
train_iter
):
...
...
@@ -52,13 +44,7 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
if
stop
:
break
new_data
=
collector
.
collect
(
train_iter
=
learner
.
train_iter
)
assert
all
([
len
(
c
)
==
0
for
c
in
collector
.
_traj_buffer
.
values
()])
replay_buffer
.
push
(
new_data
,
cur_collector_envstep
=
collector
.
envstep
)
for
i
in
range
(
cfg
.
policy
.
learn
.
update_per_collect
):
train_data
=
replay_buffer
.
sample
(
learner
.
policy
.
get_attribute
(
'batch_size'
),
learner
.
train_iter
)
if
train_data
is
not
None
:
learner
.
train
(
train_data
,
collector
.
envstep
)
replay_buffer
.
clear
()
learner
.
train
(
new_data
,
collector
.
envstep
)
if
__name__
==
"__main__"
:
...
...
dizoo/classic_control/cartpole/entry/cartpole_ppo_offpolicy_main.py
0 → 100644
浏览文件 @
4e833da2
import
os
import
gym
from
tensorboardX
import
SummaryWriter
from
ding.config
import
compile_config
from
ding.worker
import
BaseLearner
,
SampleCollector
,
BaseSerialEvaluator
,
NaiveReplayBuffer
from
ding.envs
import
BaseEnvManager
,
DingEnvWrapper
from
ding.policy
import
PPOOffPolicy
from
ding.model
import
VAC
from
ding.utils
import
set_pkg_seed
,
deep_merge_dicts
from
dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config
import
cartpole_ppo_offpolicy_config
def
wrapped_cartpole_env
():
return
DingEnvWrapper
(
gym
.
make
(
'CartPole-v0'
))
def
main
(
cfg
,
seed
=
0
,
max_iterations
=
int
(
1e10
)):
cfg
=
compile_config
(
cfg
,
BaseEnvManager
,
PPOOffPolicy
,
BaseLearner
,
SampleCollector
,
BaseSerialEvaluator
,
NaiveReplayBuffer
,
save_cfg
=
True
)
collector_env_num
,
evaluator_env_num
=
cfg
.
env
.
collector_env_num
,
cfg
.
env
.
evaluator_env_num
collector_env
=
BaseEnvManager
(
env_fn
=
[
wrapped_cartpole_env
for
_
in
range
(
collector_env_num
)],
cfg
=
cfg
.
env
.
manager
)
evaluator_env
=
BaseEnvManager
(
env_fn
=
[
wrapped_cartpole_env
for
_
in
range
(
evaluator_env_num
)],
cfg
=
cfg
.
env
.
manager
)
collector_env
.
seed
(
seed
)
evaluator_env
.
seed
(
seed
,
dynamic_seed
=
False
)
set_pkg_seed
(
seed
,
use_cuda
=
cfg
.
policy
.
cuda
)
model
=
VAC
(
**
cfg
.
policy
.
model
)
policy
=
PPOOffPolicy
(
cfg
.
policy
,
model
=
model
)
tb_logger
=
SummaryWriter
(
os
.
path
.
join
(
'./{}/log/'
.
format
(
cfg
.
exp_name
),
'serial'
))
learner
=
BaseLearner
(
cfg
.
policy
.
learn
.
learner
,
policy
.
learn_mode
,
tb_logger
,
exp_name
=
cfg
.
exp_name
)
collector
=
SampleCollector
(
cfg
.
policy
.
collect
.
collector
,
collector_env
,
policy
.
collect_mode
,
tb_logger
,
exp_name
=
cfg
.
exp_name
)
evaluator
=
BaseSerialEvaluator
(
cfg
.
policy
.
eval
.
evaluator
,
evaluator_env
,
policy
.
eval_mode
,
tb_logger
,
exp_name
=
cfg
.
exp_name
)
replay_buffer
=
NaiveReplayBuffer
(
cfg
.
policy
.
other
.
replay_buffer
,
exp_name
=
cfg
.
exp_name
)
for
_
in
range
(
max_iterations
):
if
evaluator
.
should_eval
(
learner
.
train_iter
):
stop
,
reward
=
evaluator
.
eval
(
learner
.
save_checkpoint
,
learner
.
train_iter
,
collector
.
envstep
)
if
stop
:
break
new_data
=
collector
.
collect
(
train_iter
=
learner
.
train_iter
)
replay_buffer
.
push
(
new_data
,
cur_collector_envstep
=
collector
.
envstep
)
for
i
in
range
(
cfg
.
policy
.
learn
.
update_per_collect
):
train_data
=
replay_buffer
.
sample
(
learner
.
policy
.
get_attribute
(
'batch_size'
),
learner
.
train_iter
)
if
train_data
is
not
None
:
learner
.
train
(
train_data
,
collector
.
envstep
)
if
__name__
==
"__main__"
:
main
(
cartpole_ppo_offpolicy_config
)
dizoo/classic_control/cartpole/entry/cartpole_ppo_rnd_main.py
浏览文件 @
4e833da2
...
...
@@ -5,7 +5,7 @@ from tensorboardX import SummaryWriter
from
ding.config
import
compile_config
from
ding.worker
import
BaseLearner
,
SampleCollector
,
BaseSerialEvaluator
,
NaiveReplayBuffer
from
ding.envs
import
BaseEnvManager
,
DingEnvWrapper
from
ding.policy
import
PPOPolicy
from
ding.policy
import
PPO
Off
Policy
from
ding.model
import
VAC
from
ding.utils
import
set_pkg_seed
,
deep_merge_dicts
from
ding.reward_model
import
RndRewardModel
...
...
@@ -20,7 +20,7 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
cfg
=
compile_config
(
cfg
,
BaseEnvManager
,
PPOPolicy
,
PPO
Off
Policy
,
BaseLearner
,
SampleCollector
,
BaseSerialEvaluator
,
...
...
@@ -37,7 +37,7 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
set_pkg_seed
(
seed
,
use_cuda
=
cfg
.
policy
.
cuda
)
model
=
VAC
(
**
cfg
.
policy
.
model
)
policy
=
PPOPolicy
(
cfg
.
policy
,
model
=
model
)
policy
=
PPO
Off
Policy
(
cfg
.
policy
,
model
=
model
)
tb_logger
=
SummaryWriter
(
os
.
path
.
join
(
'./{}/log/'
.
format
(
cfg
.
exp_name
),
'serial'
))
learner
=
BaseLearner
(
cfg
.
policy
.
learn
.
learner
,
policy
.
learn_mode
,
tb_logger
,
exp_name
=
cfg
.
exp_name
)
collector
=
SampleCollector
(
...
...
@@ -55,7 +55,6 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
if
stop
:
break
new_data
=
collector
.
collect
(
train_iter
=
learner
.
train_iter
)
assert
all
([
len
(
c
)
==
0
for
c
in
collector
.
_traj_buffer
.
values
()])
replay_buffer
.
push
(
new_data
,
cur_collector_envstep
=
collector
.
envstep
)
reward_model
.
collect_data
(
new_data
)
reward_model
.
train
()
...
...
@@ -65,7 +64,6 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
reward_model
.
estimate
(
train_data
)
if
train_data
is
not
None
:
learner
.
train
(
train_data
,
collector
.
envstep
)
replay_buffer
.
clear
()
if
__name__
==
"__main__"
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录