From 490691fbe25b0f5ae55601bb1790d399fff7bdd0 Mon Sep 17 00:00:00 2001
From: Swain
Date: Mon, 13 Dec 2021 15:45:16 +0800
Subject: [PATCH] feature(nyz): add delay reward mujoco env (#145)

* feature(nyz): add delay reward mujoco env

* test(nyz): add delay reward mujoco env test and fix bug
---
 dizoo/mujoco/envs/mujoco_env.py      | 33 ++++++++++++++++++++++++++--
 dizoo/mujoco/envs/test_mujoco_env.py | 17 ++++++++++++++
 2 files changed, 48 insertions(+), 2 deletions(-)
 create mode 100644 dizoo/mujoco/envs/test_mujoco_env.py

diff --git a/dizoo/mujoco/envs/mujoco_env.py b/dizoo/mujoco/envs/mujoco_env.py
index 0da502d..5387f93 100644
--- a/dizoo/mujoco/envs/mujoco_env.py
+++ b/dizoo/mujoco/envs/mujoco_env.py
@@ -1,13 +1,14 @@
 from typing import Any, Union, List
 import copy
 import numpy as np
+from easydict import EasyDict
 
 from ding.envs import BaseEnv, BaseEnvTimestep, BaseEnvInfo, update_shape
 from ding.envs.common.env_element import EnvElement, EnvElementInfo
 from ding.envs.common.common_function import affine_transform
 from ding.torch_utils import to_ndarray, to_list
-from .mujoco_wrappers import wrap_mujoco
 from ding.utils import ENV_REGISTRY
+from .mujoco_wrappers import wrap_mujoco
 
 MUJOCO_INFO_DICT = {
     'Ant-v3': BaseEnvInfo(
@@ -259,9 +260,24 @@ MUJOCO_INFO_DICT = {
 @ENV_REGISTRY.register('mujoco')
 class MujocoEnv(BaseEnv):
 
+    @classmethod
+    def default_config(cls: type) -> EasyDict:
+        cfg = EasyDict(copy.deepcopy(cls.config))
+        cfg.cfg_type = cls.__name__ + 'Dict'
+        return cfg
+
+    config = dict(
+        use_act_scale=False,
+        delay_reward_step=0,
+    )
+
     def __init__(self, cfg: dict) -> None:
         self._cfg = cfg
         self._use_act_scale = cfg.use_act_scale
+        if 'delay_reward_step' in cfg:
+            self._delay_reward_step = cfg.delay_reward_step
+        else:
+            self._delay_reward_step = self.default_config().delay_reward_step
         self._init_flag = False
 
     def reset(self) -> np.ndarray:
@@ -276,6 +292,9 @@ class MujocoEnv(BaseEnv):
         obs = self._env.reset()
         obs = to_ndarray(obs).astype('float32')
         self._final_eval_reward = 0.
+        if self._delay_reward_step > 1:
+            self._delay_reward_duration = 0
+            self._current_delay_reward = 0.
         return obs
 
     def close(self) -> None:
@@ -296,7 +315,17 @@ class MujocoEnv(BaseEnv):
         obs, rew, done, info = self._env.step(action)
         self._final_eval_reward += rew
         obs = to_ndarray(obs).astype('float32')
-        rew = to_ndarray([rew])  # wrapped to be transfered to a array with shape (1,)
+        if self._delay_reward_step > 1:
+            self._current_delay_reward += rew
+            self._delay_reward_duration += 1
+            if done or self._delay_reward_duration >= self._delay_reward_step:
+                rew = to_ndarray([self._current_delay_reward])
+                self._current_delay_reward = 0.
+                self._delay_reward_duration = 0
+            else:
+                rew = to_ndarray([0.])
+        else:
+            rew = to_ndarray([rew])  # wrapped to be transfered to a array with shape (1,)
         if done:
             info['final_eval_reward'] = self._final_eval_reward
         return BaseEnvTimestep(obs, rew, done, info)
diff --git a/dizoo/mujoco/envs/test_mujoco_env.py b/dizoo/mujoco/envs/test_mujoco_env.py
new file mode 100644
index 0000000..5d77ab9
--- /dev/null
+++ b/dizoo/mujoco/envs/test_mujoco_env.py
@@ -0,0 +1,17 @@
+import pytest
+import numpy as np
+from easydict import EasyDict
+from dizoo.mujoco.envs import MujocoEnv
+
+
+@pytest.mark.envtest
+@pytest.mark.parametrize('delay_reward_step', [0, 10])
+def test_mujoco_env(delay_reward_step):
+    env = MujocoEnv(EasyDict({'env_id': 'Ant-v3', 'use_act_scale': False, 'delay_reward_step': delay_reward_step}))
+    env.reset()
+    action_dim = env.info().act_space.shape
+    for _ in range(25):
+        action = np.random.random(size=action_dim)
+        timestep = env.step(action)
+        print(_, timestep.reward)
+        assert timestep.reward.shape == (1, )
-- 
GitLab