Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
OpenDILab开源决策智能平台
DI-engine
提交
b4ccbad3
D
DI-engine
项目概览
OpenDILab开源决策智能平台
/
DI-engine
上一次同步 2 年多
通知
59
Star
321
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
1
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-engine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
1
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
b4ccbad3
编写于
7月 19, 2021
作者:
N
niuyazhe
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feature(nyz): add naive minigrid env
上级
414b5305
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
142 addition
and
2 deletion
+142
-2
ding/envs/env/base_env.py
ding/envs/env/base_env.py
+8
-2
dizoo/minigrid/__init__.py
dizoo/minigrid/__init__.py
+0
-0
dizoo/minigrid/envs/__init__.py
dizoo/minigrid/envs/__init__.py
+1
-0
dizoo/minigrid/envs/minigrid_env.py
dizoo/minigrid/envs/minigrid_env.py
+103
-0
dizoo/minigrid/envs/test_minigrid_env.py
dizoo/minigrid/envs/test_minigrid_env.py
+27
-0
setup.py
setup.py
+3
-0
未找到文件。
ding/envs/env/base_env.py
浏览文件 @
b4ccbad3
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
,
List
,
Tuple
import
gym
import
copy
from
easydict
import
EasyDict
from
namedlist
import
namedlist
from
collections
import
namedtuple
...
...
@@ -16,10 +17,15 @@ class BaseEnv(ABC, gym.Env):
basic environment class, extended from ``gym.Env``
Interface:
``__init__``, ``reset``, ``close``, ``step``, ``info``, ``create_collector_env_cfg``,
\
``create_evaluator_env_cfg``,
``enable_save_replay``
``create_evaluator_env_cfg``, ``enable_save_replay``, ``default_config``
"""
@
classmethod
def
default_config
(
cls
:
type
)
->
EasyDict
:
cfg
=
EasyDict
(
copy
.
deepcopy
(
cls
.
config
))
cfg
.
cfg_type
=
cls
.
__name__
+
'Dict'
return
cfg
@
abstractmethod
def
__init__
(
self
,
cfg
:
dict
)
->
None
:
"""
...
...
dizoo/minigrid/__init__.py
0 → 100644
浏览文件 @
b4ccbad3
dizoo/minigrid/envs/__init__.py
0 → 100644
浏览文件 @
b4ccbad3
from
.minigrid_env
import
MiniGridEnv
dizoo/minigrid/envs/minigrid_env.py
0 → 100644
浏览文件 @
b4ccbad3
from
typing
import
Any
,
List
,
Union
,
Optional
from
collections
import
namedtuple
import
time
import
gym
import
numpy
as
np
from
gym_minigrid.wrappers
import
FlatObsWrapper
,
RGBImgPartialObsWrapper
,
ImgObsWrapper
from
ding.envs
import
BaseEnv
,
BaseEnvTimestep
,
BaseEnvInfo
from
ding.envs.common.env_element
import
EnvElement
,
EnvElementInfo
from
ding.torch_utils
import
to_tensor
,
to_ndarray
,
to_list
from
ding.utils
import
ENV_REGISTRY
MINIGRID_INFO_DICT
=
{
'MiniGrid-Empty-8x8-v0'
:
BaseEnvInfo
(
agent_num
=
1
,
obs_space
=
EnvElementInfo
(
shape
=
(
2739
,
),
value
=
{
'min'
:
0
,
'max'
:
5
,
'dtype'
:
np
.
float32
}),
act_space
=
EnvElementInfo
(
shape
=
(
1
,
),
value
=
{
'min'
:
0
,
'max'
:
7
,
'dtype'
:
np
.
int64
,
}),
rew_space
=
EnvElementInfo
(
shape
=
(
1
,
),
value
=
{
'min'
:
0
,
'max'
:
1
,
'dtype'
:
np
.
float32
}),
use_wrappers
=
None
,
),
}
@
ENV_REGISTRY
.
register
(
'minigrid'
)
class
MiniGridEnv
(
BaseEnv
):
config
=
dict
(
env_id
=
'MiniGrid-Empty-8x8-v0'
,
flat_obs
=
True
,
)
def
__init__
(
self
,
cfg
:
dict
)
->
None
:
self
.
_cfg
=
cfg
self
.
_init_flag
=
False
self
.
_env_id
=
cfg
.
env_id
self
.
_flat_obs
=
cfg
.
flat_obs
def
reset
(
self
)
->
np
.
ndarray
:
if
not
self
.
_init_flag
:
self
.
_env
=
gym
.
make
(
self
.
_env_id
)
if
self
.
_flat_obs
:
self
.
_env
=
FlatObsWrapper
(
self
.
_env
)
# self._env = RGBImgPartialObsWrapper(self._env)
# self._env = ImgObsWrapper(self._env)
self
.
_init_flag
=
True
if
hasattr
(
self
,
'_seed'
)
and
hasattr
(
self
,
'_dynamic_seed'
)
and
self
.
_dynamic_seed
:
np_seed
=
100
*
np
.
random
.
randint
(
1
,
1000
)
self
.
_env
.
seed
(
self
.
_seed
+
np_seed
)
elif
hasattr
(
self
,
'_seed'
):
self
.
_env
.
seed
(
self
.
_seed
)
self
.
_final_eval_reward
=
0
obs
=
self
.
_env
.
reset
()
obs
=
to_ndarray
(
obs
).
astype
(
np
.
float32
)
return
obs
def
close
(
self
)
->
None
:
if
self
.
_init_flag
:
self
.
_env
.
close
()
self
.
_init_flag
=
False
def
render
(
self
)
->
None
:
self
.
_env
.
render
()
def
seed
(
self
,
seed
:
int
,
dynamic_seed
:
bool
=
True
)
->
None
:
self
.
_seed
=
seed
self
.
_dynamic_seed
=
dynamic_seed
np
.
random
.
seed
(
self
.
_seed
)
def
step
(
self
,
action
:
np
.
ndarray
)
->
BaseEnvTimestep
:
assert
isinstance
(
action
,
np
.
ndarray
),
type
(
action
)
if
action
.
shape
==
(
1
,
):
action
=
action
.
squeeze
()
# 0-dim tensor
obs
,
rew
,
done
,
info
=
self
.
_env
.
step
(
action
)
rew
=
float
(
rew
)
self
.
_final_eval_reward
+=
rew
if
done
:
info
[
'final_eval_reward'
]
=
self
.
_final_eval_reward
obs
=
to_ndarray
(
obs
).
astype
(
np
.
float32
)
rew
=
to_ndarray
([
rew
])
# wrapped to be transfered to a Tensor with shape (1,)
return
BaseEnvTimestep
(
obs
,
rew
,
done
,
info
)
def
info
(
self
)
->
BaseEnvInfo
:
return
MINIGRID_INFO_DICT
[
self
.
_env_id
]
def
__repr__
(
self
)
->
str
:
return
"DI-engine MiniGrid Env"
def
enable_save_replay
(
self
,
replay_path
:
Optional
[
str
]
=
None
)
->
None
:
if
replay_path
is
None
:
replay_path
=
'./video'
self
.
_replay_path
=
replay_path
raise
NotImplementedError
dizoo/minigrid/envs/test_minigrid_env.py
0 → 100644
浏览文件 @
b4ccbad3
import
pytest
import
numpy
as
np
from
dizoo.minigrid.envs
import
MiniGridEnv
@
pytest
.
mark
.
unittest
class
TestMiniGridEnv
:
def
test_naive
(
self
):
env
=
MiniGridEnv
(
MiniGridEnv
.
default_config
())
env
.
seed
(
314
)
assert
env
.
_seed
==
314
obs
=
env
.
reset
()
act_val
=
env
.
info
().
act_space
.
value
min_val
,
max_val
=
act_val
[
'min'
],
act_val
[
'max'
]
for
i
in
range
(
10
):
random_action
=
np
.
random
.
randint
(
min_val
,
max_val
,
size
=
(
1
,
))
timestep
=
env
.
step
(
random_action
)
print
(
timestep
)
assert
isinstance
(
timestep
.
obs
,
np
.
ndarray
)
assert
isinstance
(
timestep
.
done
,
bool
)
assert
timestep
.
obs
.
shape
==
(
2739
,
)
assert
timestep
.
reward
.
shape
==
(
1
,
)
assert
timestep
.
reward
>=
env
.
info
().
rew_space
.
value
[
'min'
]
assert
timestep
.
reward
<=
env
.
info
().
rew_space
.
value
[
'max'
]
print
(
env
.
info
())
env
.
close
()
setup.py
浏览文件 @
b4ccbad3
...
...
@@ -105,6 +105,9 @@ setup(
'procgen_env'
:
[
'procgen'
,
],
'minigrid_env'
:
[
'gym-minigrid'
,
],
'sc2_env'
:
[
'absl-py>=0.1.0'
,
'future'
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录