Unverified commit 490691fb, authored by S Swain, committed by GitHub

feature(nyz): add delay reward mujoco env (#145)

* feature(nyz): add delay reward mujoco env

* test(nyz): add delay reward mujoco env test and fix bug
Parent bc0102ba
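In short, this commit makes the MuJoCo env withhold rewards: with delay_reward_step = n > 1, per-step rewards are accumulated and the running sum is only emitted every n steps (or when the episode ends), while the intermediate steps return 0. Below is a minimal standalone sketch of that accumulate-and-flush scheme; the function name and the example numbers are illustrative, not part of this commit:

from typing import List

def delay_rewards(rewards: List[float], n: int) -> List[float]:
    # Hypothetical helper mirroring the logic added to MujocoEnv.step():
    # flush the running sum every n steps (or at episode end), else emit 0.
    out, acc, duration = [], 0., 0
    for t, r in enumerate(rewards):
        acc += r
        duration += 1
        done = (t == len(rewards) - 1)
        if done or duration >= n:
            out.append(acc)
            acc, duration = 0., 0
        else:
            out.append(0.)
    return out

print(delay_rewards([1.] * 5, n=2))  # [0.0, 2.0, 0.0, 2.0, 1.0]

Note that the total return is preserved; only its timing changes.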
from typing import Any, Union, List
import copy
import numpy as np
from easydict import EasyDict
from ding.envs import BaseEnv, BaseEnvTimestep, BaseEnvInfo, update_shape
from ding.envs.common.env_element import EnvElement, EnvElementInfo
from ding.envs.common.common_function import affine_transform
from ding.torch_utils import to_ndarray, to_list
from .mujoco_wrappers import wrap_mujoco
from ding.utils import ENV_REGISTRY
MUJOCO_INFO_DICT = {
'Ant-v3': BaseEnvInfo(
@@ -259,9 +260,24 @@ MUJOCO_INFO_DICT = {
@ENV_REGISTRY.register('mujoco')
class MujocoEnv(BaseEnv):
@classmethod
def default_config(cls: type) -> EasyDict:
cfg = EasyDict(copy.deepcopy(cls.config))
cfg.cfg_type = cls.__name__ + 'Dict'
return cfg
    config = dict(
        use_act_scale=False,
        # number of env steps over which rewards are accumulated before being
        # emitted; 0 disables reward delaying
        delay_reward_step=0,
    )
def __init__(self, cfg: dict) -> None:
self._cfg = cfg
self._use_act_scale = cfg.use_act_scale
        if 'delay_reward_step' in cfg:
            self._delay_reward_step = cfg.delay_reward_step
        else:
            # fall back to the class default (0, i.e. no delay) when unspecified
            self._delay_reward_step = self.default_config().delay_reward_step
self._init_flag = False
def reset(self) -> np.ndarray:
@@ -276,6 +292,9 @@ class MujocoEnv(BaseEnv):
obs = self._env.reset()
obs = to_ndarray(obs).astype('float32')
self._final_eval_reward = 0.
        if self._delay_reward_step > 1:
            # reset the delayed-reward accumulator at the start of each episode
            self._delay_reward_duration = 0
            self._current_delay_reward = 0.
return obs
def close(self) -> None:
@@ -296,7 +315,17 @@
obs, rew, done, info = self._env.step(action)
self._final_eval_reward += rew
obs = to_ndarray(obs).astype('float32')
        # accumulate rewards and flush the running sum every `delay_reward_step`
        # env steps (or at episode end); intermediate steps emit 0
        if self._delay_reward_step > 1:
            self._current_delay_reward += rew
            self._delay_reward_duration += 1
            if done or self._delay_reward_duration >= self._delay_reward_step:
                rew = to_ndarray([self._current_delay_reward])
                self._current_delay_reward = 0.
                self._delay_reward_duration = 0
            else:
                rew = to_ndarray([0.])
        else:
            rew = to_ndarray([rew])  # wrapped to be transferred to an array with shape (1,)
if done:
info['final_eval_reward'] = self._final_eval_reward
return BaseEnvTimestep(obs, rew, done, info)
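One design detail worth noting: self._final_eval_reward is accumulated from the raw per-step reward before any delaying, so the final_eval_reward reported at episode end is identical with or without the delay; only the timing of the emitted rewards changes. A small sanity-check sketch of that invariant, assuming a working MuJoCo installation and the import path used in the test below; the tolerance only accounts for floating-point summation order:

import numpy as np
from easydict import EasyDict
from dizoo.mujoco.envs import MujocoEnv

env = MujocoEnv(EasyDict({'env_id': 'Ant-v3', 'use_act_scale': False, 'delay_reward_step': 10}))
env.reset()
action_dim = env.info().act_space.shape
emitted = 0.
while True:
    timestep = env.step(np.random.random(size=action_dim))
    emitted += timestep.reward[0]
    if timestep.done:
        # the accumulator is flushed on done, so both sums must agree
        assert abs(emitted - timestep.info['final_eval_reward']) < 1e-6
        break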
import pytest
import numpy as np
from easydict import EasyDict
from dizoo.mujoco.envs import MujocoEnv
@pytest.mark.envtest
@pytest.mark.parametrize('delay_reward_step', [0, 10])
def test_mujoco_env(delay_reward_step):
env = MujocoEnv(EasyDict({'env_id': 'Ant-v3', 'use_act_scale': False, 'delay_reward_step': delay_reward_step}))
env.reset()
action_dim = env.info().act_space.shape
    for i in range(25):
        action = np.random.random(size=action_dim)
        timestep = env.step(action)
        print(i, timestep.reward)
assert timestep.reward.shape == (1, )
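As a possible follow-up (not part of this commit), the test could also check the delay pattern itself: with delay_reward_step=10, every emitted reward that is neither a 10-step flush nor the terminal flush must be exactly zero. A hedged sketch, reusing the imports of the test file above:

@pytest.mark.envtest
def test_mujoco_env_delay_pattern():
    env = MujocoEnv(EasyDict({'env_id': 'Ant-v3', 'use_act_scale': False, 'delay_reward_step': 10}))
    env.reset()
    action_dim = env.info().act_space.shape
    rewards = []
    for _ in range(25):
        timestep = env.step(np.random.random(size=action_dim))
        rewards.append(timestep.reward[0])
        if timestep.done:
            break
    # drop the last reward (it may be a terminal flush); every remaining
    # step that is not a 10-step boundary must have emitted exactly 0
    non_flush = [r for i, r in enumerate(rewards[:-1]) if (i + 1) % 10 != 0]
    assert all(r == 0. for r in non_flush)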