提交 b4ccbad3 编写于 作者: N niuyazhe

feature(nyz): add naive minigrid env

上级 414b5305
from abc import ABC, abstractmethod
from typing import Any, List, Tuple
import gym
import copy
from easydict import EasyDict
from namedlist import namedlist
from collections import namedtuple
......@@ -16,10 +17,15 @@ class BaseEnv(ABC, gym.Env):
basic environment class, extended from ``gym.Env``
Interface:
``__init__``, ``reset``, ``close``, ``step``, ``info``, ``create_collector_env_cfg``, \
``create_evaluator_env_cfg``,
``enable_save_replay``
``create_evaluator_env_cfg``, ``enable_save_replay``, ``default_config``
"""
@classmethod
def default_config(cls: type) -> EasyDict:
cfg = EasyDict(copy.deepcopy(cls.config))
cfg.cfg_type = cls.__name__ + 'Dict'
return cfg
@abstractmethod
def __init__(self, cfg: dict) -> None:
"""
......
from .minigrid_env import MiniGridEnv
from typing import Any, List, Union, Optional
from collections import namedtuple
import time
import gym
import numpy as np
from gym_minigrid.wrappers import FlatObsWrapper, RGBImgPartialObsWrapper, ImgObsWrapper
from ding.envs import BaseEnv, BaseEnvTimestep, BaseEnvInfo
from ding.envs.common.env_element import EnvElement, EnvElementInfo
from ding.torch_utils import to_tensor, to_ndarray, to_list
from ding.utils import ENV_REGISTRY
MINIGRID_INFO_DICT = {
'MiniGrid-Empty-8x8-v0': BaseEnvInfo(
agent_num=1,
obs_space=EnvElementInfo(shape=(2739, ), value={
'min': 0,
'max': 5,
'dtype': np.float32
}),
act_space=EnvElementInfo(shape=(1, ), value={
'min': 0,
'max': 7,
'dtype': np.int64,
}),
rew_space=EnvElementInfo(shape=(1, ), value={
'min': 0,
'max': 1,
'dtype': np.float32
}),
use_wrappers=None,
),
}
@ENV_REGISTRY.register('minigrid')
class MiniGridEnv(BaseEnv):
config = dict(
env_id='MiniGrid-Empty-8x8-v0',
flat_obs=True,
)
def __init__(self, cfg: dict) -> None:
self._cfg = cfg
self._init_flag = False
self._env_id = cfg.env_id
self._flat_obs = cfg.flat_obs
def reset(self) -> np.ndarray:
if not self._init_flag:
self._env = gym.make(self._env_id)
if self._flat_obs:
self._env = FlatObsWrapper(self._env)
# self._env = RGBImgPartialObsWrapper(self._env)
# self._env = ImgObsWrapper(self._env)
self._init_flag = True
if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
np_seed = 100 * np.random.randint(1, 1000)
self._env.seed(self._seed + np_seed)
elif hasattr(self, '_seed'):
self._env.seed(self._seed)
self._final_eval_reward = 0
obs = self._env.reset()
obs = to_ndarray(obs).astype(np.float32)
return obs
def close(self) -> None:
if self._init_flag:
self._env.close()
self._init_flag = False
def render(self) -> None:
self._env.render()
def seed(self, seed: int, dynamic_seed: bool = True) -> None:
self._seed = seed
self._dynamic_seed = dynamic_seed
np.random.seed(self._seed)
def step(self, action: np.ndarray) -> BaseEnvTimestep:
assert isinstance(action, np.ndarray), type(action)
if action.shape == (1, ):
action = action.squeeze() # 0-dim tensor
obs, rew, done, info = self._env.step(action)
rew = float(rew)
self._final_eval_reward += rew
if done:
info['final_eval_reward'] = self._final_eval_reward
obs = to_ndarray(obs).astype(np.float32)
rew = to_ndarray([rew]) # wrapped to be transfered to a Tensor with shape (1,)
return BaseEnvTimestep(obs, rew, done, info)
def info(self) -> BaseEnvInfo:
return MINIGRID_INFO_DICT[self._env_id]
def __repr__(self) -> str:
return "DI-engine MiniGrid Env"
def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
if replay_path is None:
replay_path = './video'
self._replay_path = replay_path
raise NotImplementedError
import pytest
import numpy as np
from dizoo.minigrid.envs import MiniGridEnv
@pytest.mark.unittest
class TestMiniGridEnv:
def test_naive(self):
env = MiniGridEnv(MiniGridEnv.default_config())
env.seed(314)
assert env._seed == 314
obs = env.reset()
act_val = env.info().act_space.value
min_val, max_val = act_val['min'], act_val['max']
for i in range(10):
random_action = np.random.randint(min_val, max_val, size=(1, ))
timestep = env.step(random_action)
print(timestep)
assert isinstance(timestep.obs, np.ndarray)
assert isinstance(timestep.done, bool)
assert timestep.obs.shape == (2739, )
assert timestep.reward.shape == (1, )
assert timestep.reward >= env.info().rew_space.value['min']
assert timestep.reward <= env.info().rew_space.value['max']
print(env.info())
env.close()
......@@ -105,6 +105,9 @@ setup(
'procgen_env': [
'procgen',
],
'minigrid_env': [
'gym-minigrid',
],
'sc2_env': [
'absl-py>=0.1.0',
'future',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册