DI-engine (OpenDILab open-source decision intelligence platform)

Commit 4b7e50c4
Authored Oct 15, 2021 by niuyazhe

    polish(nyz): remove torch in env and correct dizoo yapf format

Parent commit: f537adf0

Showing 56 changed files with 172 additions and 197 deletions (+172 -197):
dizoo/atari/config/serial/pong/pong_qrdqn_generation_data_config.py  +1 -1
dizoo/atari/config/serial/pong/pong_sqil_config.py  +4 -9
dizoo/atari/config/serial/pong/pong_sql_config.py  +3 -9
dizoo/atari/config/serial/qbert/qbert_qrdqn_generation_data_config.py  +1 -1
dizoo/atari/config/serial/qbert/qbert_sqil_config.py  +3 -2
dizoo/atari/config/serial/qbert/qbert_sql_config.py  +2 -2
dizoo/atari/config/serial/spaceinvaders/spaceinvaders_sqil_config.py  +4 -9
dizoo/atari/config/serial/spaceinvaders/spaceinvaders_sql_config.py  +3 -9
dizoo/atari/entry/atari_ppg_main.py  +6 -2
dizoo/atari/entry/pong_cql_main.py  +7 -2
dizoo/atari/entry/qbert_cql_main.py  +7 -2
dizoo/atari/envs/atari_env.py  +0 -1
dizoo/box2d/bipedalwalker/envs/bipedalwalker_env.py  +3 -3
dizoo/box2d/lunarlander/config/lunarlander_r2d2_config.py  +8 -8
dizoo/box2d/lunarlander/envs/lunarlander_env.py  +3 -3
dizoo/bsuite/envs/bsuite_env.py  +2 -2
dizoo/classic_control/cartpole/config/cartpole_cql_config.py  +1 -2
dizoo/classic_control/cartpole/config/cartpole_qrdqn_generation_data_config.py  +2 -1
dizoo/classic_control/cartpole/entry/cartpole_cql_main.py  +8 -2
dizoo/classic_control/cartpole/envs/cartpole_env.py  +4 -5
dizoo/classic_control/pendulum/config/pendulum_sac_data_generation_default_config.py  +1 -1
dizoo/classic_control/pendulum/entry/pendulum_cql_main.py  +7 -2
dizoo/classic_control/pendulum/envs/pendulum_env.py  +3 -4
dizoo/common/policy/md_ppo.py  +1 -1
dizoo/competitive_rl/envs/competitive_rl_env.py  +2 -3
dizoo/d4rl/envs/d4rl_env.py  +7 -7
dizoo/gym_hybrid/envs/gym_hybrid_env.py  +3 -4
dizoo/image_classification/entry/imagenet_res18_config.py  +1 -10
dizoo/league_demo/league_demo_ppo_config.py  +0 -1
dizoo/league_demo/league_demo_ppo_main.py  +1 -0
dizoo/minigrid/config/minigrid_r2d2_config.py  +3 -3
dizoo/minigrid/envs/minigrid_env.py  +3 -3
dizoo/mujoco/config/hopper_sac_data_generation_default_config.py  +1 -1
dizoo/mujoco/entry/mujoco_cql_generation_main.py  +7 -2
dizoo/mujoco/envs/mujoco_env.py  +4 -5
dizoo/multiagent_particle/envs/particle_env.py  +4 -5
dizoo/multiagent_particle/envs/test_particle_env.py  +8 -11
dizoo/overcooked/config/__init__.py  +1 -1
dizoo/overcooked/config/overcooked_demo_ppo_config.py  +1 -1
dizoo/overcooked/entry/overcooked_selfplay_ppo_main.py  +1 -0
dizoo/overcooked/envs/__init__.py  +1 -1
dizoo/overcooked/envs/test_overcooked_env.py  +1 -2
dizoo/pomdp/envs/atari_env.py  +5 -5
dizoo/procgen/coinrun/entry/coinrun_dqn_config.py  +3 -5
dizoo/procgen/coinrun/entry/coinrun_ppo_config.py  +3 -5
dizoo/procgen/coinrun/envs/coinrun_env.py  +4 -7
dizoo/procgen/maze/entry/maze_dqn_config.py  +4 -6
dizoo/procgen/maze/entry/maze_ppo_config.py  +3 -6
dizoo/procgen/maze/envs/maze_env.py  +5 -6
dizoo/pybullet/envs/pybullet_env.py  +5 -6
dizoo/pybullet/envs/pybullet_wrappers.py  +0 -1
dizoo/slime_volley/envs/slime_volley_env.py  +3 -3
dizoo/smac/config/smac_3s5z_wqmix_config.py  +1 -1
dizoo/smac/config/smac_5m6m_wqmix_config.py  +1 -1
dizoo/smac/config/smac_MMM2_wqmix_config.py  +1 -1
dizoo/smac/config/smac_MMM_wqmix_config.py  +1 -1
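Nearly all of the env-side hunks below repeat one pattern: torch is dropped from the environment layer, reset()/step() are annotated to return numpy arrays, and scalar rewards are wrapped as shape-(1,) arrays. A minimal standalone sketch of the post-commit convention (hypothetical ToyEnv, with plain tuples standing in for ding's BaseEnvTimestep):

import numpy as np

class ToyEnv:
    """Illustrates the post-commit convention: envs traffic in numpy only."""

    def reset(self) -> np.ndarray:  # was annotated -> torch.Tensor before this commit
        self._final_eval_reward = 0.0
        return np.zeros(4, dtype=np.float32)

    def step(self, action: np.ndarray):
        assert isinstance(action, np.ndarray), type(action)
        if action.shape == (1, ):
            action = action.squeeze()  # 0-dim array, no longer a 0-dim tensor
        obs = np.random.rand(4).astype(np.float32)
        rew = float(np.random.rand())
        self._final_eval_reward += rew
        # reward wrapped as a shape-(1,) array, as in the rew = to_ndarray([rew]) lines below
        return obs, np.array([rew], dtype=np.float32), False, {}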
dizoo/atari/config/serial/pong/pong_qrdqn_generation_data_config.py

@@ -28,7 +28,7 @@ pong_qrdqn_config = dict(
             batch_size=32,
             learning_rate=0.0001,
             target_update_freq=500,
-            learner = dict(
+            learner=dict(
                 load_path='./expert/ckpt/ckpt_best.pth.tar',
                 hook=dict(
                     load_ckpt_before_run='./expert/ckpt/ckpt_best.pth.tar',
dizoo/atari/config/serial/pong/pong_sqil_config.py

@@ -23,14 +23,9 @@ pong_sqil_config = dict(
         ),
         nstep=3,
         discount_factor=0.99,
-        learn=dict(
-            update_per_collect=10,
-            batch_size=32,
-            learning_rate=0.0001,
-            target_update_freq=500,
-            alpha=0.12),
-        collect=dict(n_sample=96,
-                     demonstration_info_path='path'),  #Users should add their own path here (path should lead to a well-trained model)
+        learn=dict(update_per_collect=10, batch_size=32, learning_rate=0.0001, target_update_freq=500, alpha=0.12),
+        collect=dict(n_sample=96, demonstration_info_path='path'),
+        #Users should add their own path here (path should lead to a well-trained model)
         other=dict(
             eps=dict(
                 type='exp',
@@ -49,7 +44,7 @@ pong_sqil_create_config = dict(
         type='atari',
         import_names=['dizoo.atari.envs.atari_env'],
     ),
-    env_manager = dict(type='base', force_reproducibility=True),
+    env_manager=dict(type='base', force_reproducibility=True),
     policy=dict(type='sql'),
 )
 pong_sqil_create_config = EasyDict(pong_sqil_create_config)
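The 'path' value above is a deliberate placeholder that the user must replace. A sketch of filling it in before launching, assuming the module exports both dicts as shown and using ding's serial_pipeline entry (the checkpoint path is illustrative only):

from ding.entry import serial_pipeline
from dizoo.atari.config.serial.pong.pong_sqil_config import pong_sqil_config, pong_sqil_create_config

# Point the collector at a well-trained expert model before running SQIL.
pong_sqil_config.policy.collect.demonstration_info_path = './expert/ckpt/ckpt_best.pth.tar'
serial_pipeline([pong_sqil_config, pong_sqil_create_config], seed=0)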
dizoo/atari/config/serial/pong/pong_sql_config.py

@@ -23,14 +23,8 @@ pong_sql_config = dict(
         ),
         nstep=3,
         discount_factor=0.99,
-        learn=dict(
-            update_per_collect=10,
-            batch_size=32,
-            learning_rate=0.0001,
-            target_update_freq=500,
-            alpha=0.12),
-        collect=dict(n_sample=96,
-                     demonstration_info_path=None),
+        learn=dict(update_per_collect=10, batch_size=32, learning_rate=0.0001, target_update_freq=500, alpha=0.12),
+        collect=dict(n_sample=96, demonstration_info_path=None),
         other=dict(
             eps=dict(
                 type='exp',
@@ -49,7 +43,7 @@ pong_sql_create_config = dict(
         type='atari',
         import_names=['dizoo.atari.envs.atari_env'],
     ),
-    env_manager = dict(type='base', force_reproducibility=True),
+    env_manager=dict(type='base', force_reproducibility=True),
     policy=dict(type='sql'),
 )
 pong_sql_create_config = EasyDict(pong_sql_create_config)
dizoo/atari/config/serial/qbert/qbert_qrdqn_generation_data_config.py

@@ -28,7 +28,7 @@ qbert_qrdqn_config = dict(
             batch_size=32,
             learning_rate=0.0001,
             target_update_freq=500,
-            learner = dict(
+            learner=dict(
                 load_path='./expert/ckpt/ckpt_best.pth.tar',
                 hook=dict(
                     load_ckpt_before_run='./expert/ckpt/ckpt_best.pth.tar',
dizoo/atari/config/serial/qbert/qbert_sqil_config.py

@@ -28,7 +28,8 @@ qbert_dqn_config = dict(
             learning_rate=0.0001,
             target_update_freq=500,
         ),
-        collect=dict(n_sample=100, demonstration_info_path='path'),  #Users should add their own path here (path should lead to a well-trained model)
+        collect=dict(n_sample=100, demonstration_info_path='path'),
+        #Users should add their own path here (path should lead to a well-trained model)
         eval=dict(evaluator=dict(eval_freq=4000, )),
         other=dict(
             eps=dict(
@@ -48,7 +49,7 @@ qbert_dqn_create_config = dict(
         type='atari',
         import_names=['dizoo.atari.envs.atari_env'],
     ),
-    env_manager = dict(type='subprocess', force_reproducibility=True),
+    env_manager=dict(type='subprocess', force_reproducibility=True),
     policy=dict(type='dqn'),
 )
 qbert_dqn_create_config = EasyDict(qbert_dqn_create_config)
dizoo/atari/config/serial/qbert/qbert_sql_config.py

@@ -28,7 +28,7 @@ qbert_dqn_config = dict(
             learning_rate=0.0001,
             target_update_freq=500,
         ),
-        collect = dict(n_sample=100, demonstration_info_path=None),
+        collect=dict(n_sample=100, demonstration_info_path=None),
         eval=dict(evaluator=dict(eval_freq=4000, )),
         other=dict(
             eps=dict(
@@ -48,7 +48,7 @@ qbert_dqn_create_config = dict(
         type='atari',
         import_names=['dizoo.atari.envs.atari_env'],
     ),
-    env_manager = dict(type='subprocess', force_reproducibility=True),
+    env_manager=dict(type='subprocess', force_reproducibility=True),
     policy=dict(type='sql'),
 )
 qbert_dqn_create_config = EasyDict(qbert_dqn_create_config)
dizoo/atari/config/serial/spaceinvaders/spaceinvaders_sqil_config.py

@@ -23,14 +23,9 @@ space_invaders_sqil_config = dict(
         ),
         nstep=3,
         discount_factor=0.99,
-        learn=dict(
-            update_per_collect=10,
-            batch_size=32,
-            learning_rate=0.0001,
-            target_update_freq=500,
-            alpha=0.1),
-        collect=dict(n_sample=100,
-                     demonstration_info_path='path'),  #Users should add their own path here (path should lead to a well-trained model)
+        learn=dict(update_per_collect=10, batch_size=32, learning_rate=0.0001, target_update_freq=500, alpha=0.1),
+        collect=dict(n_sample=100, demonstration_info_path='path'),
+        #Users should add their own path here (path should lead to a well-trained model)
         eval=dict(evaluator=dict(eval_freq=4000, )),
         other=dict(
             eps=dict(
@@ -50,7 +45,7 @@ space_invaders_sqil_create_config = dict(
         type='atari',
         import_names=['dizoo.atari.envs.atari_env'],
     ),
-    env_manager = dict(type='base', force_reproducibility=True),
+    env_manager=dict(type='base', force_reproducibility=True),
     policy=dict(type='sql'),
 )
 space_invaders_sqil_create_config = EasyDict(space_invaders_sqil_create_config)
dizoo/atari/config/serial/spaceinvaders/spaceinvaders_sql_config.py

@@ -23,14 +23,8 @@ space_invaders_sql_config = dict(
         ),
         nstep=3,
         discount_factor=0.99,
-        learn=dict(
-            update_per_collect=10,
-            batch_size=32,
-            learning_rate=0.0001,
-            target_update_freq=500,
-            alpha=0.1),
-        collect=dict(n_sample=100,
-                     demonstration_info_path=None),
+        learn=dict(update_per_collect=10, batch_size=32, learning_rate=0.0001, target_update_freq=500, alpha=0.1),
+        collect=dict(n_sample=100, demonstration_info_path=None),
         eval=dict(evaluator=dict(eval_freq=4000, )),
         other=dict(
             eps=dict(
@@ -50,7 +44,7 @@ space_invaders_sql_create_config = dict(
         type='atari',
         import_names=['dizoo.atari.envs.atari_env'],
     ),
-    env_manager = dict(type='base', force_reproducibility=True),
+    env_manager=dict(type='base', force_reproducibility=True),
     policy=dict(type='sql'),
 )
 space_invaders_sql_create_config = EasyDict(space_invaders_sql_create_config)
dizoo/atari/entry/atari_ppg_main.py

@@ -31,8 +31,12 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
     )
     collector_env_cfg = AtariEnv.create_collector_env_cfg(cfg.env)
     evaluator_env_cfg = AtariEnv.create_evaluator_env_cfg(cfg.env)
-    collector_env = SyncSubprocessEnvManager(env_fn=[partial(AtariEnv, cfg=c) for c in collector_env_cfg], cfg=cfg.env.manager)
-    evaluator_env = SyncSubprocessEnvManager(env_fn=[partial(AtariEnv, cfg=c) for c in evaluator_env_cfg], cfg=cfg.env.manager)
+    collector_env = SyncSubprocessEnvManager(
+        env_fn=[partial(AtariEnv, cfg=c) for c in collector_env_cfg], cfg=cfg.env.manager
+    )
+    evaluator_env = SyncSubprocessEnvManager(
+        env_fn=[partial(AtariEnv, cfg=c) for c in evaluator_env_cfg], cfg=cfg.env.manager
+    )
     collector_env.seed(seed)
     evaluator_env.seed(seed, dynamic_seed=False)
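SyncSubprocessEnvManager takes a list of zero-argument factories, which is why each per-env config is bound with functools.partial. A toy illustration of that factory pattern (DummyEnv is a hypothetical stand-in; the real manager invokes the factories in worker processes):

from functools import partial

class DummyEnv:
    def __init__(self, cfg):
        self.cfg = cfg

env_cfgs = [{'env_id': i} for i in range(4)]
env_fns = [partial(DummyEnv, cfg=c) for c in env_cfgs]  # one factory per env config
envs = [fn() for fn in env_fns]  # the manager calls each factory itself
assert envs[2].cfg == {'env_id': 2}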
dizoo/atari/entry/pong_cql_main.py

@@ -28,8 +28,13 @@ def generate(args):
     main_config.policy.collect.save_path = './pong/expert.pkl'
     config = deepcopy([main_config, create_config])
     state_dict = torch.load(main_config.policy.learn.learner.load_path, map_location='cpu')
-    collect_demo_data(config, collect_count=main_config.policy.other.replay_buffer.replay_buffer_size,
-                      seed=args.seed, expert_data_path=main_config.policy.collect.save_path, state_dict=state_dict)
+    collect_demo_data(
+        config,
+        collect_count=main_config.policy.other.replay_buffer.replay_buffer_size,
+        seed=args.seed,
+        expert_data_path=main_config.policy.collect.save_path,
+        state_dict=state_dict
+    )


 def train_expert(args):
dizoo/atari/entry/qbert_cql_main.py

@@ -28,8 +28,13 @@ def generate(args):
     main_config.policy.collect.save_path = './qbert/expert.pkl'
     config = deepcopy([main_config, create_config])
     state_dict = torch.load(main_config.policy.learn.learner.load_path, map_location='cpu')
-    collect_demo_data(config, collect_count=main_config.policy.other.replay_buffer.replay_buffer_size,
-                      seed=args.seed, expert_data_path=main_config.policy.collect.save_path, state_dict=state_dict)
+    collect_demo_data(
+        config,
+        collect_count=main_config.policy.other.replay_buffer.replay_buffer_size,
+        seed=args.seed,
+        expert_data_path=main_config.policy.collect.save_path,
+        state_dict=state_dict
+    )


 def train_expert(args):
dizoo/atari/envs/atari_env.py

 from typing import Any, List, Union, Sequence
 import copy
-import torch
 import numpy as np
 from ding.envs import BaseEnv, BaseEnvTimestep, BaseEnvInfo, update_shape
 from ding.envs.common.env_element import EnvElement, EnvElementInfo
dizoo/box2d/bipedalwalker/envs/bipedalwalker_env.py

@@ -4,7 +4,7 @@ import gym
 import numpy as np
 from ding.envs import BaseEnv, BaseEnvTimestep, BaseEnvInfo, FrameStack
 from ding.envs.common.env_element import EnvElement, EnvElementInfo
-from ding.torch_utils import to_tensor, to_ndarray, to_list
+from ding.torch_utils import to_ndarray, to_list
 from ding.envs.common.common_function import affine_transform
 from ding.utils import ENV_REGISTRY
@@ -53,7 +53,7 @@ class BipedalWalkerEnv(BaseEnv):
     def step(self, action: np.ndarray) -> BaseEnvTimestep:
         assert isinstance(action, np.ndarray), type(action)
         if action.shape == (1, ):
-            action = action.squeeze()  # 0-dim tensor
+            action = action.squeeze()  # 0-dim array
         if self._act_scale:
             action_range = self.info().act_space.value
             action = affine_transform(action, min_val=action_range['min'], max_val=action_range['max'])
@@ -67,7 +67,7 @@ class BipedalWalkerEnv(BaseEnv):
         if done:
             info['final_eval_reward'] = self._final_eval_reward
         obs = to_ndarray(obs).astype(np.float32)
-        rew = to_ndarray([rew])  # wrapped to be transfered to a Tensor with shape (1,)
+        rew = to_ndarray([rew])  # wrapped to be transfered to a array with shape (1,)
         return BaseEnvTimestep(obs, rew, done, info)

     def info(self) -> BaseEnvInfo:
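act_scale maps a policy output in [-1, 1] onto the env's action bounds. A sketch of what affine_transform computes here, with the linear map written out (the actual ding implementation may differ in detail):

import numpy as np

def affine_transform(action: np.ndarray, min_val: float, max_val: float) -> np.ndarray:
    # Linearly map [-1, 1] onto [min_val, max_val].
    return min_val + (action + 1.0) * (max_val - min_val) / 2.0

print(affine_transform(np.array([-1.0, 0.0, 1.0]), min_val=-2.0, max_val=2.0))  # [-2.  0.  2.]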
dizoo/box2d/lunarlander/config/lunarlander_r2d2_config.py

@@ -11,31 +11,31 @@ lunarlander_r2d2_config = dict(
         stop_value=195,
     ),
     policy=dict(
-        cuda = False,
+        cuda=False,
         on_policy=False,
-        priority = False,
+        priority=False,
         model=dict(
             obs_shape=8,
             action_shape=4,
             encoder_hidden_size_list=[128, 128, 64],
         ),
-        discount_factor = 0.997,
-        burnin_step = 20,
+        discount_factor=0.997,
+        burnin_step=20,
         nstep=5,
         # (int) the whole sequence length to unroll the RNN network minus
         # the timesteps of burnin part,
         # i.e., <the whole sequence length> = <burnin_step> + <unroll_len>
-        unroll_len = 80,
+        unroll_len=80,
         learn=dict(
             # according to the R2D2 paper, actor parameter update interval is 400
             # environment timesteps, and in per collect phase, we collect 32 sequence
             # samples, the length of each samlpe sequence is <burnin_step> + <unroll_len>,
             # which is 100 in our seeting, 32*100/400=8, so we set update_per_collect=8
             # in most environments
-            update_per_collect = 8,
+            update_per_collect=8,
             batch_size=64,
             learning_rate=0.0005,
-            target_update_freq = 2500,
+            target_update_freq=2500,
         ),
         collect=dict(
             n_sample=32,
@@ -48,7 +48,7 @@ lunarlander_r2d2_config = dict(
                 start=0.95,
                 end=0.05,
                 decay=10000,
             ),
-            replay_buffer = dict(replay_buffer_size=100000, )
+            replay_buffer=dict(replay_buffer_size=100000, )
         ),
     ),
 )
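The update_per_collect comment above compresses a small calculation; spelled out, using the values from this config:

# Reproducing the update_per_collect reasoning from the R2D2 config comment.
burnin_step, unroll_len = 20, 80
sequence_len = burnin_step + unroll_len       # 100 timesteps per sample sequence
n_sample = 32                                 # sequences collected per collect phase
actor_update_interval = 400                   # timesteps, per the R2D2 paper
update_per_collect = n_sample * sequence_len // actor_update_interval
assert update_per_collect == 8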
dizoo/box2d/lunarlander/envs/lunarlander_env.py

@@ -4,7 +4,7 @@ import gym
 import numpy as np
 from ding.envs import BaseEnv, BaseEnvTimestep, BaseEnvInfo
 from ding.envs.common.env_element import EnvElement, EnvElementInfo
-from ding.torch_utils import to_tensor, to_ndarray, to_list
+from ding.torch_utils import to_ndarray, to_list
 from ding.utils import ENV_REGISTRY
@@ -45,14 +45,14 @@ class LunarLanderEnv(BaseEnv):
     def step(self, action: np.ndarray) -> BaseEnvTimestep:
         assert isinstance(action, np.ndarray), type(action)
         if action.shape == (1, ):
-            action = action.squeeze()  # 0-dim tensor
+            action = action.squeeze()  # 0-dim array
         obs, rew, done, info = self._env.step(action)
         rew = float(rew)
         self._final_eval_reward += rew
         if done:
             info['final_eval_reward'] = self._final_eval_reward
         obs = to_ndarray(obs).astype(np.float32)
-        rew = to_ndarray([rew])  # wrapped to be transfered to a Tensor with shape (1,)
+        rew = to_ndarray([rew])  # wrapped to be transfered to a array with shape (1,)
         return BaseEnvTimestep(obs, rew, done, info)

     def info(self) -> BaseEnvInfo:
dizoo/bsuite/envs/bsuite_env.py

@@ -5,7 +5,7 @@ import gym
 import numpy as np
 from ding.envs import BaseEnv, BaseEnvTimestep, BaseEnvInfo
 from ding.envs.common.env_element import EnvElement, EnvElementInfo
-from ding.torch_utils import to_tensor, to_ndarray, to_list
+from ding.torch_utils import to_ndarray, to_list
 from ding.utils import ENV_REGISTRY
 import bsuite
@@ -147,7 +147,7 @@ class BSuiteEnv(BaseEnv):
         if obs.shape[0] == 1:
             obs = obs[0]
         obs = to_ndarray(obs)
-        rew = to_ndarray([rew])  # wrapped to be transfered to a Tensor with shape (1,)
+        rew = to_ndarray([rew])  # wrapped to be transfered to a array with shape (1,)
         return BaseEnvTimestep(obs, rew, done, info)

     def info(self) -> BaseEnvInfo:
dizoo/classic_control/cartpole/config/cartpole_cql_config.py

@@ -41,8 +41,7 @@ cartpole_discrete_cql_config = dict(
                 start=0.95,
                 end=0.1,
                 decay=10000,
             ),
-            replay_buffer=dict(
-                replay_buffer_size=20000, )
+            replay_buffer=dict(replay_buffer_size=20000, )
         ),
     ),
 )
dizoo/classic_control/cartpole/config/cartpole_qrdqn_generation_data_config.py

@@ -47,7 +47,8 @@ cartpole_qrdqn_generation_data_config = dict(
                 end=0.1,
                 decay=10000,
                 collect=0.2,
             ),
-            replay_buffer=dict(replay_buffer_size=100000, )),
+            replay_buffer=dict(replay_buffer_size=100000, )
+        ),
     ),
 )
dizoo/classic_control/cartpole/entry/cartpole_cql_main.py

@@ -29,8 +29,14 @@ def generate(args):
     main_config.policy.collect.data_type = 'hdf5'
     config = deepcopy([main_config, create_config])
     state_dict = torch.load(main_config.policy.learn.learner.load_path, map_location='cpu')
-    collect_demo_data(config, collect_count=main_config.policy.other.replay_buffer.replay_buffer_size,
-                      seed=args.seed, expert_data_path=main_config.policy.collect.save_path, state_dict=state_dict)
+    collect_demo_data(
+        config,
+        collect_count=main_config.policy.other.replay_buffer.replay_buffer_size,
+        seed=args.seed,
+        expert_data_path=main_config.policy.collect.save_path,
+        state_dict=state_dict
+    )


 def train_expert(args):
     from dizoo.classic_control.cartpole.config.cartpole_qrdqn_config import main_config, create_config
dizoo/classic_control/cartpole/envs/cartpole_env.py

 from typing import Any, List, Union, Optional
 import time
 import gym
-import torch
 import copy
 import numpy as np
 from easydict import EasyDict
 from ding.envs import BaseEnv, BaseEnvTimestep, BaseEnvInfo
 from ding.envs.common.env_element import EnvElementInfo
-from ding.torch_utils import to_tensor, to_ndarray, to_list
+from ding.torch_utils import to_ndarray, to_list
 from ding.utils import ENV_REGISTRY
@@ -34,7 +33,7 @@ class CartPoleEnv(BaseEnv):
         self._init_flag = False
         self._replay_path = None

-    def reset(self) -> torch.Tensor:
+    def reset(self) -> np.ndarray:
         if not self._init_flag:
             self._env = gym.make('CartPole-v0')
             if self._replay_path is not None:
@@ -65,13 +64,13 @@ class CartPoleEnv(BaseEnv):
     def step(self, action: np.ndarray) -> BaseEnvTimestep:
         assert isinstance(action, np.ndarray), type(action)
         if action.shape == (1, ):
-            action = action.squeeze()  # 0-dim tensor
+            action = action.squeeze()  # 0-dim array
         obs, rew, done, info = self._env.step(action)
         self._final_eval_reward += rew
         if done:
             info['final_eval_reward'] = self._final_eval_reward
         obs = to_ndarray(obs).astype(np.float32)
-        rew = to_ndarray([rew])  # wrapped to be transfered to a Tensor with shape (1,)
+        rew = to_ndarray([rew])  # wrapped to be transfered to a array with shape (1,)
         return BaseEnvTimestep(obs, rew, done, info)

     def info(self) -> BaseEnvInfo:
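The reset() annotation change (torch.Tensor to np.ndarray) simply matches what gym returns. A quick check under the classic gym API this code assumes (gym <= 0.25, where reset() returns only the observation):

import gym
import numpy as np

env = gym.make('CartPole-v0')
obs = env.reset()
assert isinstance(obs, np.ndarray) and obs.shape == (4, )
obs = obs.astype(np.float32)  # the env casts observations to float32, as in step() above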
dizoo/classic_control/pendulum/config/pendulum_sac_data_generation_default_config.py

@@ -33,7 +33,7 @@ pendulum_sac_data_genearation_default_config = dict(
             discount_factor=0.99,
             alpha=0.2,
             auto_alpha=False,
-            learner = dict(
+            learner=dict(
                 load_path='./default_experiment/ckpt/ckpt_best.pth.tar',
                 hook=dict(
                     load_ckpt_before_run='./default_experiment/ckpt/ckpt_best.pth.tar',
dizoo/classic_control/pendulum/entry/pendulum_cql_main.py

@@ -21,8 +21,13 @@ def eval_ckpt(args):
 def generate(args):
     config = deepcopy([main_config, create_config])
     state_dict = torch.load(main_config.policy.learn.learner.load_path, map_location='cpu')
-    collect_demo_data(config, collect_count=main_config.policy.other.replay_buffer.replay_buffer_size,
-                      seed=args.seed, expert_data_path=main_config.policy.collect.save_path, state_dict=state_dict)
+    collect_demo_data(
+        config,
+        collect_count=main_config.policy.other.replay_buffer.replay_buffer_size,
+        seed=args.seed,
+        expert_data_path=main_config.policy.collect.save_path,
+        state_dict=state_dict
+    )


 def train_expert(args):
dizoo/classic_control/pendulum/envs/pendulum_env.py

 from typing import Any, Union, Optional
 import gym
-import torch
 import numpy as np
 from ding.envs import BaseEnv, BaseEnvTimestep, BaseEnvInfo
 from ding.envs.common.env_element import EnvElement, EnvElementInfo
 from ding.envs.common.common_function import affine_transform
 from ding.utils import ENV_REGISTRY
-from ding.torch_utils import to_tensor, to_ndarray, to_list
+from ding.torch_utils import to_ndarray, to_list


 @ENV_REGISTRY.register('pendulum')
@@ -19,7 +18,7 @@ class PendulumEnv(BaseEnv):
         self._init_flag = False
         self._replay_path = None

-    def reset(self) -> torch.Tensor:
+    def reset(self) -> np.ndarray:
         if not self._init_flag:
             self._env = gym.make('Pendulum-v0')
             if self._replay_path is not None:
@@ -55,7 +54,7 @@ class PendulumEnv(BaseEnv):
         obs, rew, done, info = self._env.step(action)
         self._final_eval_reward += rew
         obs = to_ndarray(obs).astype(np.float32)
-        rew = to_ndarray([rew])  # wrapped to be transfered to a Tensor with shape (1,)
+        rew = to_ndarray([rew])  # wrapped to be transfered to a array with shape (1,)
         if done:
             info['final_eval_reward'] = self._final_eval_reward
         return BaseEnvTimestep(obs, rew, done, info)
dizoo/common/policy/md_ppo.py

@@ -165,7 +165,7 @@ class MultiDiscretePPOOffPolicy(PPOOffPolicy):
         avg_entropy_loss = sum([item.entropy_loss for item in loss_list]) / action_num
         avg_approx_kl = sum([item.approx_kl for item in info_list]) / action_num
         avg_clipfrac = sum([item.clipfrac for item in info_list]) / action_num
         wv, we = self._value_weight, self._entropy_weight
         total_loss = avg_policy_loss + wv * avg_value_loss - we * avg_entropy_loss
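The multi-discrete loss averaging above is plain arithmetic over per-action-branch results; a self-contained sketch (the namedtuple and the weights are placeholders for ding's internal structures):

from collections import namedtuple

Loss = namedtuple('Loss', ['policy_loss', 'value_loss', 'entropy_loss'])
loss_list = [Loss(0.5, 0.2, 0.01), Loss(0.7, 0.4, 0.03)]  # one entry per action branch
action_num = len(loss_list)
avg_policy_loss = sum(l.policy_loss for l in loss_list) / action_num
avg_value_loss = sum(l.value_loss for l in loss_list) / action_num
avg_entropy_loss = sum(l.entropy_loss for l in loss_list) / action_num
wv, we = 0.5, 0.01  # illustrative value/entropy weights
total_loss = avg_policy_loss + wv * avg_value_loss - we * avg_entropy_loss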
dizoo/competitive_rl/envs/competitive_rl_env.py

 from typing import Any, Union, List
 import copy
-import torch
 import numpy as np
 import gym
 import competitive_rl
@@ -8,7 +7,7 @@ import competitive_rl
 from ding.envs import BaseEnv, BaseEnvTimestep, BaseEnvInfo, update_shape
 from ding.envs.common.env_element import EnvElement, EnvElementInfo
 from ding.envs.common.common_function import affine_transform
-from ding.torch_utils import to_tensor, to_ndarray, to_list
+from ding.torch_utils import to_ndarray, to_list
 from .competitive_rl_env_wrapper import BuiltinOpponentWrapper, wrap_env
 from ding.utils import ENV_REGISTRY
@@ -108,7 +107,7 @@ class CompetitiveRlEnv(BaseEnv):
         self._dynamic_seed = dynamic_seed
         np.random.seed(self._seed)

-    def step(self, action: Union[torch.Tensor, np.ndarray, list]) -> BaseEnvTimestep:
+    def step(self, action: Union[np.ndarray, list]) -> BaseEnvTimestep:
         action = to_ndarray(action)
         action = self.process_action(action)  # process
dizoo/d4rl/envs/d4rl_env.py

 from typing import Any, Union, List
 import copy
-import torch
 import numpy as np
 from ding.envs import BaseEnv, BaseEnvTimestep, BaseEnvInfo, update_shape
 from ding.envs.common.env_element import EnvElement, EnvElementInfo
 from ding.envs.common.common_function import affine_transform
-from ding.torch_utils import to_tensor, to_ndarray, to_list
+from ding.torch_utils import to_ndarray, to_list
 from .d4rl_wrappers import wrap_d4rl
 from ding.utils import ENV_REGISTRY
@@ -320,7 +319,7 @@ class D4RLEnv(BaseEnv):
         self._use_act_scale = cfg.use_act_scale
         self._init_flag = False

-    def reset(self) -> torch.FloatTensor:
+    def reset(self) -> np.ndarray:
         if not self._init_flag:
             self._env = self._make_env(only_info=False)
             self._init_flag = True
@@ -344,7 +343,7 @@ class D4RLEnv(BaseEnv):
         self._dynamic_seed = dynamic_seed
         np.random.seed(self._seed)

-    def step(self, action: Union[torch.Tensor, np.ndarray, list]) -> BaseEnvTimestep:
+    def step(self, action: Union[np.ndarray, list]) -> BaseEnvTimestep:
         action = to_ndarray(action)
         if self._use_act_scale:
             action_range = self.info().act_space.value
@@ -352,7 +351,7 @@ class D4RLEnv(BaseEnv):
         obs, rew, done, info = self._env.step(action)
         self._final_eval_reward += rew
         obs = to_ndarray(obs).astype('float32')
-        rew = to_ndarray([rew])  # wrapped to be transfered to a Tensor with shape (1,)
+        rew = to_ndarray([rew])  # wrapped to be transfered to a array with shape (1,)
         if done:
             info['final_eval_reward'] = self._final_eval_reward
         return BaseEnvTimestep(obs, rew, done, info)
@@ -369,8 +368,9 @@ class D4RLEnv(BaseEnv):
             info.rew_space.shape = rew_shape
             return info
         else:
-            raise NotImplementedError('{} not found in D4RL_INFO_DICT [{}]'
-                                      .format(self._cfg.env_id, D4RL_INFO_DICT.keys()))
+            raise NotImplementedError(
+                '{} not found in D4RL_INFO_DICT [{}]'.format(self._cfg.env_id, D4RL_INFO_DICT.keys())
+            )

     def _make_env(self, only_info=False):
         return wrap_d4rl(
dizoo/gym_hybrid/envs/gym_hybrid_env.py

@@ -2,13 +2,12 @@ from typing import Any, List, Union, Optional
 import time
 import gym
 import gym_hybrid
-import torch
 import copy
 import numpy as np
 from easydict import EasyDict
 from ding.envs import BaseEnv, BaseEnvTimestep, BaseEnvInfo
 from ding.envs.common.env_element import EnvElementInfo
-from ding.torch_utils import to_tensor, to_ndarray, to_list
+from ding.torch_utils import to_ndarray, to_list
 from ding.utils import ENV_REGISTRY
@@ -23,7 +22,7 @@ class GymHybridEnv(BaseEnv):
         self._init_flag = False
         self._replay_path = None

-    def reset(self) -> torch.Tensor:
+    def reset(self) -> np.ndarray:
         if not self._init_flag:
             self._env = gym.make(self._env_id)
             if self._replay_path is not None:
@@ -58,7 +57,7 @@ class GymHybridEnv(BaseEnv):
         if done:
             info['final_eval_reward'] = self._final_eval_reward
         obs = to_ndarray(obs).astype(np.float32)
-        rew = to_ndarray([rew])  # wrapped to be transfered to a Tensor with shape (1,)
+        rew = to_ndarray([rew])  # wrapped to be transfered to a array with shape (1,)
         info['action_args_mask'] = np.array([[1, 0], [0, 1], [0, 0]])
         return BaseEnvTimestep(obs, rew, done, info)
dizoo/image_classification/entry/imagenet_res18_config.py

@@ -28,16 +28,7 @@ imagenet_res18_config = dict(
         eval_data_path='/mnt/lustre/share/images/val',
     ),
     eval=dict(
-        batch_size=32,
-        evaluator=dict(
-            eval_freq=1,
-            multi_gpu=True,
-            stop_value=dict(
-                loss=0.5,
-                acc1=75.0,
-                acc5=95.0,
-            )
-        )
+        batch_size=32, evaluator=dict(eval_freq=1, multi_gpu=True, stop_value=dict(loss=0.5, acc1=75.0, acc5=95.0))
     ),
     env=dict(),
dizoo/league_demo/league_demo_ppo_config.py

@@ -87,7 +87,6 @@ league_demo_ppo_config = dict(
             ),
         ),
     ),
-
 )
 league_demo_ppo_config = EasyDict(league_demo_ppo_config)
dizoo/league_demo/league_demo_ppo_main.py

@@ -37,6 +37,7 @@ class EvalPolicy1:

+
 class EvalPolicy2:

     def forward(self, data: dict) -> dict:
         return {
             env_id: {
dizoo/minigrid/config/minigrid_r2d2_config.py

@@ -33,10 +33,10 @@ minigrid_r2d2_config = dict(
             # samples, the length of each samlpe sequence is <burnin_step> + <unroll_len>,
             # which is 100 in our seeting, 32*100/400=8, so we set update_per_collect=8
             # in most environments
-            update_per_collect = 8,
+            update_per_collect=8,
             batch_size=64,
             learning_rate=0.0005,
-            target_update_freq = 2500,
+            target_update_freq=2500,
         ),
         collect=dict(
             n_sample=32,
@@ -67,4 +67,4 @@ minigrid_r2d2_create_config = EasyDict(minigrid_r2d2_create_config)
 create_config = minigrid_r2d2_create_config

 if __name__ == "__main__":
-    serial_pipeline([main_config, create_config], seed=0)
\ No newline at end of file
+    serial_pipeline([main_config, create_config], seed=0)
dizoo/minigrid/envs/minigrid_env.py

@@ -13,7 +13,7 @@ from gym_minigrid.window import Window
 from ding.envs import BaseEnv, BaseEnvTimestep
 from ding.envs.common.env_element import EnvElement, EnvElementInfo
-from ding.torch_utils import to_tensor, to_ndarray, to_list
+from ding.torch_utils import to_ndarray, to_list
 from ding.utils import ENV_REGISTRY

 MiniGridEnvInfo = namedtuple(
@@ -170,7 +170,7 @@ class MiniGridEnv(BaseEnv):
     def step(self, action: np.ndarray) -> BaseEnvTimestep:
         assert isinstance(action, np.ndarray), type(action)
         if action.shape == (1, ):
-            action = action.squeeze()  # 0-dim tensor
+            action = action.squeeze()  # 0-dim array
         if self._save_replay:
             self._frames.append(self._env.render(mode='rgb_array'))
         obs, rew, done, info = self._env.step(action)
@@ -190,7 +190,7 @@ class MiniGridEnv(BaseEnv):
                 self.display_frames_as_gif(self._frames, path)
             self._save_replay_count += 1
         obs = to_ndarray(obs).astype(np.float32)
-        rew = to_ndarray([rew])  # wrapped to be transfered to a Tensor with shape (1,)
+        rew = to_ndarray([rew])  # wrapped to be transfered to a array with shape (1,)
         return BaseEnvTimestep(obs, rew, done, info)

     def info(self) -> MiniGridEnvInfo:
dizoo/mujoco/config/hopper_sac_data_generation_default_config.py

@@ -35,7 +35,7 @@ hopper_sac_data_genearation_default_config = dict(
             alpha=0.2,
             reparameterization=True,
             auto_alpha=False,
-            learner = dict(
+            learner=dict(
                 load_path='./default_experiment/ckpt/ckpt_best.pth.tar',
                 hook=dict(
                     load_ckpt_before_run='./default_experiment/ckpt/ckpt_best.pth.tar',
dizoo/mujoco/entry/mujoco_cql_generation_main.py

@@ -12,8 +12,13 @@ def eval_ckpt(args):
 def generate(args):
     config = copy.deepcopy([main_config, create_config])
     state_dict = torch.load(main_config.policy.learn.learner.load_path, map_location='cpu')
-    collect_demo_data(config, collect_count=main_config.policy.other.replay_buffer.replay_buffer_size,
-                      seed=args.seed, expert_data_path=main_config.policy.collect.save_path, state_dict=state_dict)
+    collect_demo_data(
+        config,
+        collect_count=main_config.policy.other.replay_buffer.replay_buffer_size,
+        seed=args.seed,
+        expert_data_path=main_config.policy.collect.save_path,
+        state_dict=state_dict
+    )


 if __name__ == "__main__":
dizoo/mujoco/envs/mujoco_env.py

 from typing import Any, Union, List
 import copy
-import torch
 import numpy as np
 from ding.envs import BaseEnv, BaseEnvTimestep, BaseEnvInfo, update_shape
 from ding.envs.common.env_element import EnvElement, EnvElementInfo
 from ding.envs.common.common_function import affine_transform
-from ding.torch_utils import to_tensor, to_ndarray, to_list
+from ding.torch_utils import to_ndarray, to_list
 from .mujoco_wrappers import wrap_mujoco
 from ding.utils import ENV_REGISTRY
@@ -265,7 +264,7 @@ class MujocoEnv(BaseEnv):
         self._use_act_scale = cfg.use_act_scale
         self._init_flag = False

-    def reset(self) -> torch.FloatTensor:
+    def reset(self) -> np.ndarray:
         if not self._init_flag:
             self._env = self._make_env(only_info=False)
             self._init_flag = True
@@ -289,7 +288,7 @@ class MujocoEnv(BaseEnv):
         self._dynamic_seed = dynamic_seed
         np.random.seed(self._seed)

-    def step(self, action: Union[torch.Tensor, np.ndarray, list]) -> BaseEnvTimestep:
+    def step(self, action: Union[np.ndarray, list]) -> BaseEnvTimestep:
         action = to_ndarray(action)
         if self._use_act_scale:
             action_range = self.info().act_space.value
@@ -297,7 +296,7 @@ class MujocoEnv(BaseEnv):
         obs, rew, done, info = self._env.step(action)
         self._final_eval_reward += rew
         obs = to_ndarray(obs).astype('float32')
-        rew = to_ndarray([rew])  # wrapped to be transfered to a Tensor with shape (1,)
+        rew = to_ndarray([rew])  # wrapped to be transfered to a array with shape (1,)
         if done:
             info['final_eval_reward'] = self._final_eval_reward
         return BaseEnvTimestep(obs, rew, done, info)
dizoo/multiagent_particle/envs/particle_env.py

@@ -3,12 +3,11 @@ from typing import Any, Optional
 from easydict import EasyDict
 import copy
 import numpy as np
-import torch
 from ding.envs import BaseEnv, BaseEnvTimestep, BaseEnvInfo
 from ding.envs.common.env_element import EnvElement, EnvElementInfo
 from ding.utils import ENV_REGISTRY
-from ding.torch_utils import to_tensor, to_ndarray, to_list
+from ding.torch_utils import to_ndarray, to_list
 from dizoo.multiagent_particle.envs.make_env import make_env
 from dizoo.multiagent_particle.envs.multiagent.multi_discrete import MultiDiscrete
 import gym
@@ -29,7 +28,7 @@ class ParticleEnv(BaseEnv):
         self._env.force_discrete_action = True
         self.agent_num = self._env.n

-    def reset(self) -> torch.Tensor:
+    def reset(self) -> np.ndarray:
         self._step_count = 0
         if hasattr(self, '_seed'):
             # Note: the real env instance only has a empty seed method, only pass
@@ -153,7 +152,7 @@ class ModifiedPredatorPrey(BaseEnv):
         self.global_obs_dim = self._n_agent * 2 + self._num_landmarks * 2 + self._n_agent * 2
         self.obs_alone_dim = 2 + 2 + (self._num_landmarks) * 2

-    def reset(self) -> torch.Tensor:
+    def reset(self) -> np.ndarray:
         self._step_count = 0
         self._sum_reward = 0
         # if hasattr(self, '_seed'):
@@ -314,7 +313,7 @@ class CooperativeNavigation(BaseEnv):
         self.global_obs_dim = self._n_agent * 2 + self._num_landmarks * 2 + self._n_agent * 2
         self.obs_alone_dim = 2 + 2 + (self._num_landmarks) * 2

-    def reset(self) -> torch.Tensor:
+    def reset(self) -> np.ndarray:
         self._step_count = 0
         self._sum_reward = 0
         if hasattr(self, '_seed'):
dizoo/multiagent_particle/envs/test_particle_env.py

 import pytest
-import torch
 import numpy as np
 from dizoo.multiagent_particle.envs import ParticleEnv, CooperativeNavigation

 use_discrete = [True, False]
@@ -57,25 +57,22 @@ class TestParticleEnv:
             min_val, max_val = act_val['min'], act_val['max']
             if act_sp.shape == (1, ):
                 if discrete_action:
-                    random_action.append(torch.randint(min_val, max_val, act_sp.shape))
+                    random_action.append(np.random.randint(min_val, max_val, act_sp.shape))
                 else:
-                    random_action.append(torch.rand(max_val + 1 - min_val, ))
+                    random_action.append(np.random.random(max_val + 1 - min_val, ))
             else:
                 # print(act_sp.shape)
                 if discrete_action:
                     random_action.append(
-                        torch.cat([torch.randint(min_val[t], max_val[t], (1, )) for t in range(act_sp.shape[0])])
+                        np.concatenate([np.random.randint(min_val[t], max_val[t], (1, )) for t in range(act_sp.shape[0])])
+                        # [torch.randint(min_val[t], max_val[t], (1, )) for t in range(act_sp.shape[0])]
                     )
                 else:
                     # print("i = ", i)
                     # print('randon_action = ', random_action)
                     # print([torch.rand(max_val[t]+1 - min_val[t], ) for t in range(act_sp.shape[0])])
                     random_action.append(
                         # torch.stack([torch.rand(max_val[t]+1 - min_val[t], ) for t in range(act_sp.shape[0])])
-                        [torch.rand(max_val[t] + 1 - min_val[t], ) for t in range(act_sp.shape[0])]
+                        [np.random.random(max_val[t] + 1 - min_val[t], ) for t in range(act_sp.shape[0])]
                     )
             # print('randon_action = ', random_action)
             timestep = env.step(random_action)
@@ -99,7 +96,7 @@ class TestCooperativeNavigation:
         for k, v in obs.items():
             assert v.shape == env.info().obs_space.shape[k]
         for _ in range(env._max_step):
-            action = torch.randint(0, 5, (num_agent, ))
+            action = np.random.randint(0, 5, (num_agent, ))
             timestep = env.step(action)
             obs = timestep.obs
             for k, v in obs.items():
@@ -125,7 +122,7 @@ class TestCooperativeNavigation:
         for k, v in obs.items():
             assert v.shape == env.info().obs_space.shape[k]
         for _ in range(env._max_step):
-            action = torch.randn((num_agent, 5, ))
+            action = np.random.random((num_agent, 5, ))
             timestep = env.step(action)
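The test replacements map torch sampling calls onto numpy one-for-one; side by side (one behavioral nuance: torch.randn draws from a normal distribution, while np.random.random is uniform on [0, 1)):

import numpy as np

a = np.random.randint(0, 5, (2, ))   # was: torch.randint(0, 5, (2, ))
b = np.random.random((2, 5))         # was: torch.randn((2, 5)); note the distribution differs
c = np.concatenate([np.random.randint(0, 3, (1, )) for _ in range(4)])  # was: torch.cat([...])
assert a.shape == (2, ) and b.shape == (2, 5) and c.shape == (4, )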
dizoo/overcooked/config/__init__.py

-from .overcooked_demo_ppo_config import overcooked_demo_ppo_config
\ No newline at end of file
+from .overcooked_demo_ppo_config import overcooked_demo_ppo_config
dizoo/overcooked/config/overcooked_demo_ppo_config.py

@@ -25,7 +25,7 @@ overcooked_league_demo_ppo_config = dict(
             value_weight=0.5,
             entropy_weight=0.01,
             clip_ratio=0.2,
-            nstep = 1,
+            nstep=1,
             nstep_return=False,
             adv_norm=True,
             value_norm=True,
dizoo/overcooked/entry/overcooked_selfplay_ppo_main.py

@@ -14,6 +14,7 @@ from ding.utils import set_pkg_seed
 from dizoo.overcooked.envs import OvercookGameEnv
 from dizoo.overcooked.config import overcooked_demo_ppo_config

+
 def wrapped_overcookgame():
     return OvercookGameEnv({})
dizoo/overcooked/envs/__init__.py

-from .overcooked_env import OvercookEnv, OvercookGameEnv
\ No newline at end of file
+from .overcooked_env import OvercookEnv, OvercookGameEnv
dizoo/overcooked/envs/test_overcooked_env.py

@@ -28,7 +28,7 @@ class TestOvercooked:
                 assert timestep.done
             sum_rew += timestep.info['final_eval_reward'][0]
         print("sum reward is:", sum_rew)

     def test_overcook_game(self):
         concat_obs = False
         num_agent = 2
@@ -42,4 +42,3 @@ class TestOvercooked:
             assert timestep.done
         print("agent 0 sum reward is:", timestep.info[0]['final_eval_reward'])
         print("agent 1 sum reward is:", timestep.info[1]['final_eval_reward'])
-
dizoo/pomdp/envs/atari_env.py

 from typing import Any, List, Union, Sequence
 import copy
-import torch
 import gym
 import numpy as np
 from ding.envs import BaseEnv, BaseEnvTimestep, BaseEnvInfo
 from ding.envs.common.env_element import EnvElement, EnvElementInfo
 from ding.utils import ENV_REGISTRY
-from ding.torch_utils import to_tensor, to_ndarray, to_list
+from ding.torch_utils import to_ndarray, to_list
 from .atari_wrappers import wrap_deepmind
 from pprint import pprint
@@ -101,7 +100,7 @@ class PomdpAtariEnv(BaseEnv):
         obs, rew, done, info = self._env.step(action)
         self._final_eval_reward += rew
         obs = to_ndarray(obs)
-        rew = to_ndarray([rew])  # wrapped to be transfered to a Tensor with shape (1,)
+        rew = to_ndarray([rew])  # wrapped to be transfered to a array with shape (1,)
         if done:
             info['final_eval_reward'] = self._final_eval_reward
         return BaseEnvTimestep(obs, rew, done, info)
@@ -112,8 +111,9 @@ class PomdpAtariEnv(BaseEnv):
             info.use_wrappers = self._make_env(only_info=True)
             return info
         else:
-            raise NotImplementedError('{} not found in POMDP_INFO_DICT [{}]' \
-                .format(self._cfg.env_id, POMDP_INFO_DICT.keys()))
+            raise NotImplementedError(
+                '{} not found in POMDP_INFO_DICT [{}]'.format(self._cfg.env_id, POMDP_INFO_DICT.keys())
+            )  # noqa

     def _make_env(self, only_info=False):
         return wrap_deepmind(
dizoo/procgen/coinrun/entry/coinrun_dqn_config.py

@@ -11,7 +11,7 @@ coinrun_dqn_default_config = dict(
         cuda=False,
         on_policy=False,
         model=dict(
-            obs_shape = [3, 64, 64],
+            obs_shape=[3, 64, 64],
             action_shape=15,
             encoder_hidden_size_list=[128, 128, 512],
             dueling=False,
@@ -23,9 +23,7 @@ coinrun_dqn_default_config = dict(
             learning_rate=0.0005,
             target_update_freq=500,
         ),
-        collect=dict(
-            n_sample=100,
-        ),
+        collect=dict(n_sample=100, ),
         eval=dict(evaluator=dict(eval_freq=5000, )),
         other=dict(
             eps=dict(
@@ -46,7 +44,7 @@ coinrun_dqn_create_config = dict(
         type='coinrun',
         import_names=['dizoo.procgen.coinrun.envs.coinrun_env'],
     ),
-    env_manager=dict(type='subprocess',),
+    env_manager=dict(type='subprocess', ),
     policy=dict(type='dqn'),
 )
 coinrun_dqn_create_config = EasyDict(coinrun_dqn_create_config)
dizoo/procgen/coinrun/entry/coinrun_ppo_config.py

@@ -24,9 +24,7 @@ coinrun_ppo_default_config = dict(
             entropy_weight=0.01,
             clip_ratio=0.2,
         ),
-        collect=dict(
-            n_sample=100,
-        ),
+        collect=dict(n_sample=100, ),
         eval=dict(evaluator=dict(eval_freq=5000, )),
         other=dict(
             eps=dict(
@@ -48,8 +46,8 @@ coinrun_ppo_create_config = dict(
         type='coinrun',
         import_names=['dizoo.procgen.coinrun.envs.coinrun_env'],
     ),
-    env_manager=dict(type='subprocess',),
+    env_manager=dict(type='subprocess', ),
     policy=dict(type='ppo'),
 )
 coinrun_ppo_create_config = EasyDict(coinrun_ppo_create_config)
-create_config = coinrun_ppo_create_config
\ No newline at end of file
+create_config = coinrun_ppo_create_config
dizoo/procgen/coinrun/envs/coinrun_env.py

 from typing import Any, List, Union, Optional
 import time
 import gym
-import torch
 import numpy as np
 from ding.envs import BaseEnv, BaseEnvTimestep, BaseEnvInfo
 from ding.envs.common.env_element import EnvElement, EnvElementInfo
-from ding.torch_utils import to_tensor, to_ndarray, to_list
+from ding.torch_utils import to_ndarray, to_list
 from ding.utils import ENV_REGISTRY
@@ -32,7 +31,7 @@ class CoinRunEnv(BaseEnv):
         self._seed = 0
         self._init_flag = False

-    def reset(self) -> torch.Tensor:
+    def reset(self) -> np.ndarray:
         if not self._init_flag:
             self._env = gym.make('procgen:procgen-coinrun-v0', start_level=0, num_levels=1)
             self._init_flag = True
@@ -59,18 +58,17 @@ class CoinRunEnv(BaseEnv):
         self._dynamic_seed = dynamic_seed
         np.random.seed(self._seed)

     def step(self, action: np.ndarray) -> BaseEnvTimestep:
         assert isinstance(action, np.ndarray), type(action)
         if action.shape == (1, ):
-            action = action.squeeze()  # 0-dim tensor
+            action = action.squeeze()  # 0-dim array
         obs, rew, done, info = self._env.step(action)
         self._final_eval_reward += rew
         if done:
             info['final_eval_reward'] = self._final_eval_reward
         obs = to_ndarray(obs)
         obs = np.transpose(obs, (2, 0, 1))
-        rew = to_ndarray([rew])  # wrapped to be transfered to a Tensor with shape (1,)
+        rew = to_ndarray([rew])  # wrapped to be transfered to a array with shape (1,)
         return BaseEnvTimestep(obs, rew, bool(done), info)

     def info(self) -> BaseEnvInfo:
@@ -117,4 +115,3 @@ class CoinRunEnv(BaseEnv):
         self._env = gym.wrappers.Monitor(
             self._env, self._replay_path, video_callable=lambda episode_id: True, force=True
         )
-
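Procgen frames arrive channel-last; the transpose in step() converts them to channel-first for convolutional networks, matching obs_shape=[3, 64, 64] in the coinrun and maze configs:

import numpy as np

obs = np.zeros((64, 64, 3), dtype=np.uint8)   # HWC frame as returned by procgen
obs_chw = np.transpose(obs, (2, 0, 1))        # CHW layout expected by the model
assert obs_chw.shape == (3, 64, 64)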
dizoo/procgen/maze/entry/maze_dqn_config.py

@@ -11,7 +11,7 @@ maze_dqn_default_config = dict(
         cuda=False,
         on_policy=False,
         model=dict(
-            obs_shape = [3, 64, 64],
+            obs_shape=[3, 64, 64],
             action_shape=15,
             encoder_hidden_size_list=[128, 128, 512],
             dueling=False,
@@ -24,9 +24,7 @@ maze_dqn_default_config = dict(
             target_update_freq=500,
             discount_factor=0.99,
         ),
-        collect=dict(
-            n_sample=100,
-        ),
+        collect=dict(n_sample=100, ),
         eval=dict(evaluator=dict(eval_freq=5000, )),
         other=dict(
             eps=dict(
@@ -37,7 +35,7 @@ maze_dqn_default_config = dict(
             ),
             replay_buffer=dict(replay_buffer_size=100000, ),
         ),
-        cuda = True,
+        cuda=True,
     ),
 )
 maze_dqn_default_config = EasyDict(maze_dqn_default_config)
@@ -48,7 +46,7 @@ maze_dqn_create_config = dict(
         type='maze',
         import_names=['dizoo.procgen.maze.envs.maze_env'],
     ),
-    env_manager=dict(type='subprocess',),
+    env_manager=dict(type='subprocess', ),
     policy=dict(type='dqn'),
 )
 maze_dqn_create_config = EasyDict(maze_dqn_create_config)
dizoo/procgen/maze/entry/maze_ppo_config.py

@@ -17,7 +17,6 @@ maze_ppo_default_config = dict(
             action_shape=15,
             encoder_hidden_size_list=[32, 32, 64],
         ),
-
         learn=dict(
             update_per_collect=5,
             batch_size=64,
@@ -26,9 +25,7 @@ maze_ppo_default_config = dict(
             clip_ratio=0.2,
             learning_rate=0.0001,
         ),
-        collect=dict(
-            n_sample=100,
-        ),
+        collect=dict(n_sample=100, ),
         eval=dict(evaluator=dict(eval_freq=5000, )),
         other=dict(
             eps=dict(
@@ -49,8 +46,8 @@ maze_ppo_create_config = dict(
         type='maze',
         import_names=['dizoo.procgen.maze.envs.maze_env'],
     ),
-    env_manager=dict(type='subprocess',),
+    env_manager=dict(type='subprocess', ),
     policy=dict(type='ppo'),
 )
 maze_ppo_create_config = EasyDict(maze_ppo_create_config)
-create_config = maze_ppo_create_config
\ No newline at end of file
+create_config = maze_ppo_create_config
dizoo/procgen/maze/envs/maze_env.py
浏览文件 @
4b7e50c4
from
typing
import
Any
,
List
,
Union
,
Optional
import
time
import
gym
import
torch
import
numpy
as
np
from
ding.envs
import
BaseEnv
,
BaseEnvTimestep
,
BaseEnvInfo
from
ding.envs.common.env_element
import
EnvElement
,
EnvElementInfo
from
ding.utils
import
ENV_REGISTRY
from
ding.torch_utils
import
to_
tensor
,
to_
ndarray
,
to_list
from
ding.torch_utils
import
to_ndarray
,
to_list
def
disable_gym_view_window
():
...
...
@@ -32,14 +31,14 @@ class MazeEnv(BaseEnv):
self
.
_seed
=
0
self
.
_init_flag
=
False
def
reset
(
self
)
->
torch
.
Tensor
:
def
reset
(
self
)
->
np
.
ndarray
:
if
not
self
.
_init_flag
:
self
.
_env
=
gym
.
make
(
'procgen:procgen-maze-v0'
,
start_level
=
0
,
num_levels
=
1
)
self
.
_init_flag
=
True
if
hasattr
(
self
,
'_seed'
)
and
hasattr
(
self
,
'_dynamic_seed'
)
and
self
.
_dynamic_seed
:
np_seed
=
100
*
np
.
random
.
randint
(
1
,
1000
)
self
.
_env
.
close
()
self
.
_env
=
gym
.
make
(
'procgen:procgen-maze-v0'
,
start_level
=
self
.
_seed
+
np_seed
,
num_levels
=
1
)
self
.
_env
=
gym
.
make
(
'procgen:procgen-maze-v0'
,
start_level
=
self
.
_seed
+
np_seed
,
num_levels
=
1
)
elif
hasattr
(
self
,
'_seed'
):
self
.
_env
.
close
()
self
.
_env
=
gym
.
make
(
'procgen:procgen-maze-v0'
,
start_level
=
self
.
_seed
,
num_levels
=
1
)
...
...
@@ -62,14 +61,14 @@ class MazeEnv(BaseEnv):
def
step
(
self
,
action
:
np
.
ndarray
)
->
BaseEnvTimestep
:
assert
isinstance
(
action
,
np
.
ndarray
),
type
(
action
)
if
action
.
shape
==
(
1
,
):
action
=
action
.
squeeze
()
# 0-dim
tensor
action
=
action
.
squeeze
()
# 0-dim
array
obs
,
rew
,
done
,
info
=
self
.
_env
.
step
(
action
)
self
.
_final_eval_reward
+=
rew
if
done
:
info
[
'final_eval_reward'
]
=
self
.
_final_eval_reward
obs
=
to_ndarray
(
obs
)
obs
=
np
.
transpose
(
obs
,
(
2
,
0
,
1
))
rew
=
to_ndarray
([
rew
])
# wrapped to be transfered to a
Tensor
with shape (1,)
rew
=
to_ndarray
([
rew
])
# wrapped to be transfered to a
array
with shape (1,)
return
BaseEnvTimestep
(
obs
,
rew
,
bool
(
done
),
info
)
def
info
(
self
)
->
BaseEnvInfo
:
...
...
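The dynamic-seed branch in MazeEnv.reset() draws a fresh level offset on every reset, so collector episodes vary while the base seed keeps runs reproducible; condensed from the hunk above:

import numpy as np

base_seed, dynamic_seed = 0, True
if dynamic_seed:
    np_seed = 100 * np.random.randint(1, 1000)  # fresh offset each reset
    start_level = base_seed + np_seed
else:
    start_level = base_seed                     # evaluator: fixed level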
dizoo/pybullet/envs/pybullet_env.py

 from typing import Any, Union, List
 import copy
-import torch
 import numpy as np
 from ding.envs import BaseEnv, BaseEnvTimestep, BaseEnvInfo, update_shape
 from ding.envs.common.env_element import EnvElement, EnvElementInfo
 from ding.envs.common.common_function import affine_transform
-from ding.torch_utils import to_tensor, to_ndarray, to_list
-from .pybullet_wrappers import wrap_pybullet
+from ding.torch_utils import to_ndarray, to_list
 from ding.utils import ENV_REGISTRY
+from .pybullet_wrappers import wrap_pybullet

 Pybullet_INFO_DICT = {
     # pybullet env
@@ -293,7 +292,7 @@ class PybulletEnv(BaseEnv):
         self._use_act_scale = cfg.use_act_scale
         self._init_flag = False

-    def reset(self) -> torch.FloatTensor:
+    def reset(self) -> np.ndarray:
         if not self._init_flag:
             self._env = self._make_env(only_info=False)
             self._init_flag = True
@@ -317,7 +316,7 @@ class PybulletEnv(BaseEnv):
         self._dynamic_seed = dynamic_seed
         np.random.seed(self._seed)

-    def step(self, action: Union[torch.Tensor, np.ndarray, list]) -> BaseEnvTimestep:
+    def step(self, action: Union[np.ndarray, list]) -> BaseEnvTimestep:
         action = to_ndarray(action)
         if self._use_act_scale:
             action_range = self.info().act_space.value
@@ -325,7 +324,7 @@ class PybulletEnv(BaseEnv):
         obs, rew, done, info = self._env.step(action)
         self._final_eval_reward += rew
         obs = to_ndarray(obs).astype('float32')
-        rew = to_ndarray([rew])  # wrapped to be transfered to a Tensor with shape (1,)
+        rew = to_ndarray([rew])  # wrapped to be transfered to a array with shape (1,)
         if done:
             info['final_eval_reward'] = self._final_eval_reward
         return BaseEnvTimestep(obs, rew, done, info)
dizoo/pybullet/envs/pybullet_wrappers.py

@@ -10,7 +10,6 @@ except ImportError:
     logging.warning("not found pybullet env, please install it, refer to https://github.com/benelot/pybullet-gym")

-
 def wrap_pybullet(env_id, norm_obs=True, norm_reward=True, only_info=False) -> gym.Env:
     r"""
     Overview:
dizoo/slime_volley/envs/slime_volley_env.py

@@ -8,7 +8,7 @@ import slimevolleygym
 from ding.envs import BaseEnv, BaseEnvTimestep, BaseEnvInfo
 from ding.envs.common.env_element import EnvElement, EnvElementInfo
 from ding.utils import ENV_REGISTRY
-from ding.torch_utils import to_tensor, to_ndarray
+from ding.torch_utils import to_ndarray


 class GymSelfPlayMonitor(gym.wrappers.Monitor):
@@ -57,9 +57,9 @@ class SlimeVolleyEnv(BaseEnv):
         assert isinstance(action1, np.ndarray), type(action1)
         assert action2 is None or isinstance(action1, np.ndarray), type(action2)
         if action1.shape == (1, ):
-            action1 = action1.squeeze()  # 0-dim tensor
+            action1 = action1.squeeze()  # 0-dim array
         if action2 is not None and action2.shape == (1, ):
-            action2 = action2.squeeze()  # 0-dim tensor
+            action2 = action2.squeeze()  # 0-dim array
         action1 = SlimeVolleyEnv._process_action(action1)
         action2 = SlimeVolleyEnv._process_action(action2)
         obs1, rew, done, info = self._env.step(action1, action2)
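SlimeVolley's step() is two-player: action2 is optional and, when None, the built-in opponent acts. A condensed sketch of the squeeze-then-step flow above (stub env; _process_action omitted):

import numpy as np

class StubEnv:
    def step(self, action1, action2=None):
        return 'obs', 0.0, False, {}

def two_player_step(env, action1: np.ndarray, action2: np.ndarray = None):
    assert isinstance(action1, np.ndarray), type(action1)
    if action1.shape == (1, ):
        action1 = action1.squeeze()  # 0-dim array
    if action2 is not None and action2.shape == (1, ):
        action2 = action2.squeeze()  # 0-dim array
    return env.step(action1, action2)

two_player_step(StubEnv(), np.array([3]))                  # vs the built-in opponent
two_player_step(StubEnv(), np.array([3]), np.array([1]))   # self-play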
dizoo/smac/config/smac_3s5z_wqmix_config.py

@@ -7,7 +7,7 @@ collector_env_num = 16
 evaluator_env_num = 8
 main_config = dict(
-    exp_name = '3s5s_wqmix_ow_ff3-256_hsl64',
+    exp_name='3s5s_wqmix_ow_ff3-256_hsl64',
     env=dict(
         map_name='3s5z',
         difficulty=7,
dizoo/smac/config/smac_5m6m_wqmix_config.py

@@ -7,7 +7,7 @@ collector_env_num = 16
 evaluator_env_num = 8
 main_config = dict(
-    exp_name = '5m6m_wqmix_ow_ff3-256_hsl64',
+    exp_name='5m6m_wqmix_ow_ff3-256_hsl64',
     env=dict(
         map_name='5m_vs_6m',
         difficulty=7,
dizoo/smac/config/smac_MMM2_wqmix_config.py

@@ -7,7 +7,7 @@ collector_env_num = 16
 evaluator_env_num = 8
 main_config = dict(
-    exp_name = 'MMM2_wqmix_ow_ff3-256_hsl64',
+    exp_name='MMM2_wqmix_ow_ff3-256_hsl64',
     env=dict(
         map_name='MMM2',
         difficulty=7,
dizoo/smac/config/smac_MMM_wqmix_config.py

@@ -7,7 +7,7 @@ collector_env_num = 16
 evaluator_env_num = 8
 main_config = dict(
-    exp_name = 'MMM_wqmix_ow_ff3-256_hsl64',
+    exp_name='MMM_wqmix_ow_ff3-256_hsl64',
     env=dict(
         map_name='MMM',
         difficulty=7,