Unverified commit 8f47f4cb, authored by Ke Li, committed by GitHub

feature(lk): add gym-soccer (HFO) env (#94)

* add_soccer_env

* add_info

* close

* format

* test_gym_soccer

* rm_torch

* replay_log

* format_style

* add_gym_soccer_to_readme

* separate render_func

* add_gif_file

* scale_action

* flake_style_format

* resolve_review_comments

* add branch info for gym hybrid
Parent f04b9eb7
@@ -171,6 +171,8 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
| 19 | [ImageNet](https://www.image-net.org/) | ![IL](https://img.shields.io/badge/-IL/SL-purple) | ![original](./dizoo/image_classification/imagenet.png) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/image_classification) |
| 20 | [slime_volleyball](https://github.com/hardmaru/slimevolleygym) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![selfplay](https://img.shields.io/badge/-selfplay-blue) | ![ori](dizoo/slime_volley/slime_volley.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/slime_volley) |
| 21 | [gym_hybrid](https://github.com/thomashirtz/gym-hybrid) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | ![ori](dizoo/gym_hybrid/moving_v0.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/gym_hybrid) |
| 22 | [gym_soccer](https://github.com/openai/gym-soccer) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | ![ori](dizoo/gym_soccer/half_offensive.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/gym_soccer) |
![discrete](https://img.shields.io/badge/-discrete-brightgreen) means discrete action space
......
@@ -57,7 +57,7 @@ class GymHybridEnv(BaseEnv):
if done:
info['final_eval_reward'] = self._final_eval_reward
obs = to_ndarray(obs).astype(np.float32)
-        rew = to_ndarray([rew])  # wrapped to be transfered to a array with shape (1,)
+        rew = to_ndarray([rew])  # wrapped to be transferred to a numpy array with shape (1,)
info['action_args_mask'] = np.array([[1, 0], [0, 1], [0, 0]])
return BaseEnvTimestep(obs, rew, done, info)
......
# How to replay a log
1. Set the log path used to store episode logs with the following command:
   `env.enable_save_replay('./game_log')`
2. After running the game, you will find several log files in the `game_log` directory.
3. Execute the following command to replay a log file (`*.rcg`):
   `env.replay_log("game_log/20211019011053-base_left_0-vs-base_right_0.rcg")`
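Putting the steps together, below is a minimal sketch of recording and then replaying an episode. It assumes a working gym-soccer/HFO installation; `act_scale` is set to `False` here because `get_random_action` already samples in HFO's native ranges, and the `.rcg` file name is illustrative, since HFO timestamps its logs.

```python
from easydict import EasyDict
from dizoo.gym_soccer.envs.gym_soccer_env import GymSoccerEnv

# act_scale=False: random actions are already in HFO's native ranges,
# so no affine mapping from [-1, 1] is needed.
env = GymSoccerEnv(EasyDict({'env_id': 'Soccer-v0', 'act_scale': False}))
env.enable_save_replay('./game_log')  # must be called before the first reset
obs = env.reset()
for _ in range(1000):
    timestep = env.step(env.get_random_action())
    if timestep.done:
        break
env.close()
# Replay one of the .rcg files written to ./game_log (the name is illustrative).
env.replay_log('game_log/20211019011053-base_left_0-vs-base_right_0.rcg')
```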
from typing import List, Optional
import gym
import gym_soccer  # imported for its side effect: registers the Soccer-* envs with gym
import numpy as np
from ding.envs import BaseEnv, BaseEnvInfo, BaseEnvTimestep
from ding.envs.common.common_function import affine_transform
from ding.envs.common.env_element import EnvElementInfo
from ding.torch_utils import to_ndarray
from ding.utils import ENV_REGISTRY
@ENV_REGISTRY.register('gym_soccer')
class GymSoccerEnv(BaseEnv):
default_env_id = ['Soccer-v0', 'SoccerEmptyGoal-v0', 'SoccerAgainstKeeper-v0']
def __init__(self, cfg: dict = {}) -> None:
self._cfg = cfg
self._act_scale = cfg.act_scale
self._env_id = cfg.env_id
assert self._env_id in self.default_env_id
self._init_flag = False
self._replay_path = None
    def reset(self) -> np.ndarray:
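        # Lazy initialization: the HFO env is only created on the first reset,
        # so enable_save_replay can set the replay path before gym.make is called.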
if not self._init_flag:
self._env = gym.make(self._env_id, replay_path=self._replay_path)
self._init_flag = True
self._final_eval_reward = 0
obs = self._env.reset()
obs = to_ndarray(obs).astype(np.float32)
return obs
def step(self, action: List) -> BaseEnvTimestep:
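        # Hybrid action layout (6 entries): action[0] is the discrete action type;
        # action[1:3] are (power, direction) for action 0, action[3] is the direction
        # for action 1, and action[4:6] are (power, direction) for action 2
        # (DASH, TURN and KICK in HFO terms).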
if self._act_scale:
            # Each continuous action argument has shape (1,);
            # index [0] to read and write it as a scalar.
action[1][0] = affine_transform(action[1][0], min_val=0, max_val=100)
action[2][0] = affine_transform(action[2][0], min_val=-180, max_val=180)
action[3][0] = affine_transform(action[3][0], min_val=-180, max_val=180)
action[4][0] = affine_transform(action[4][0], min_val=0, max_val=100)
action[5][0] = affine_transform(action[5][0], min_val=-180, max_val=180)
obs, rew, done, info = self._env.step(action)
self._final_eval_reward += rew
if done:
info['final_eval_reward'] = self._final_eval_reward
obs = to_ndarray(obs).astype(np.float32)
        # the reward is wrapped into a numpy array with shape (1,)
rew = to_ndarray([rew])
        # a 1 at entry (i, j) indicates that discrete action i uses continuous parameter j
info['action_args_mask'] = np.array([[1, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 1, 1]])
return BaseEnvTimestep(obs, rew, done, info)
def seed(self, seed: int, dynamic_seed: bool = True) -> None:
self._seed = seed
self._dynamic_seed = dynamic_seed
np.random.seed(self._seed)
    def close(self) -> None:
        if self._init_flag:
            self._env.close()
        self._init_flag = False
    def get_random_action(self):
        # discrete action type: 0, 1, 2
        # continuous action_args:
        #   - power: [0, 100]
        #   - direction: [-180, 180]
        # The action space has shape (6,): the first entry is the discrete action
        # type and the remaining five are the continuous parameters.
        # Discrete action 0 is associated with the first and second continuous
        # parameters, action 1 with the third, and action 2 with the fourth and fifth.
        return self._env.action_space.sample()
def info(self) -> BaseEnvInfo:
T = EnvElementInfo
return BaseEnvInfo(
agent_num=1,
obs_space=T(
(59, ),
{
# [min, max]
'min': -1,
'max': 1,
'dtype': np.float32,
},
),
act_space=T(
                # Only the discrete part of the action space (3 choices) is described
                # here; the 5 continuous parameters are not captured by this info.
                (3, ),
{
# [min, max)
'min': 0,
'max': 3,
'dtype': int,
},
),
rew_space=T(
(1, ),
{
# [min, max)
'min': 0,
'max': 2.0,
                    'dtype': np.float32,
},
),
use_wrappers=None,
)
    def render(self, close: bool = False) -> None:
        # pass close as a keyword so it is not mistaken for gym's `mode` argument
        self._env.render(close=close)
def __repr__(self) -> str:
return "DI-engine gym soccer Env"
def replay_log(self, log_path):
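        # Delegate to the underlying env; log_path should point to an .rcg log
        # recorded via enable_save_replay (see "How to replay a log" above).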
self._env.replay_log(log_path)
def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
if replay_path is None:
replay_path = './game_log'
self._replay_path = replay_path
import numpy as np
import pytest
from dizoo.gym_soccer.envs.gym_soccer_env import GymSoccerEnv
from easydict import EasyDict
@pytest.mark.envtest
class TestGymSoccerEnv:
def test_naive(self):
        env = GymSoccerEnv(EasyDict({'env_id': 'Soccer-v0', 'act_scale': False}))
env.enable_save_replay('./video')
env.seed(25, dynamic_seed=False)
assert env._seed == 25
obs = env.reset()
assert obs.shape == (59, )
for i in range(1000):
random_action = env.get_random_action()
# print('random_action', random_action)
timestep = env.step(random_action)
env.render()
assert isinstance(timestep.obs, np.ndarray)
assert isinstance(timestep.done, bool)
assert timestep.obs.shape == (59, )
# print(timestep.obs)
assert timestep.reward.shape == (1, )
assert timestep.info['action_args_mask'].shape == (3, 5)
if timestep.done:
print('reset env')
env.reset()
assert env._final_eval_reward == 0
print(env.info())
# env.replay_log("./video/20211019011053-base_left_0-vs-base_right_0.rcg")
env.close()
@@ -124,8 +124,12 @@ setup(
# 'pybulletgym @ git+https://github.com/benelot/pybullet-gym@master#egg=pybulletgym',
# ],
# 'gym_hybrid_env': [
-        # 'gym-hybrid @ git+https://github.com/thomashirtz/gym-hybrid#egg=gym-hybrid',
+        # 'gym-hybrid @ git+https://github.com/thomashirtz/gym-hybrid@master#egg=gym-hybrid',
# ],
# 'gym_soccer_env': [
# 'gym-soccer @ git+https://github.com/LikeJulia/gym-soccer@dev-install-packages#egg=gym-soccer',
# ],
'sc2_env': [
'absl-py>=0.1.0',
'future',
......