Unverified commit dbf432cd authored by LuciusMos, committed by GitHub

feature(zlx): add vs bot training and self-play training with slime volley env (#23)

* slime volley env in dizoo, first commit

* fix bug in slime volley env

* modify volley env to satisfy ding 1v1 requirements; add naive self-play and league training pipeline (evaluator is not finished; a very naive one is used for now)

* adopt volley builtin ai as default eval opponent

* polish(nyz): polish slime_volley_env and its test

* feature(nyz): add slime_volley vs bot ppo demo

* feature(nyz): add battle_sample_serial_collector and adapt abnormal check in subprocess env manager

* feature(nyz): add slime volley self-play demo

* style(nyz): add slime_volleyball env gif and split MARL and selfplay label

* feature(nyz): add save replay function in slime volleyball env
Co-authored-by: zlx-sensetime <zhaoliangxuan@sensetime.com>
Co-authored-by: niuyazhe <niuyazhe@sensetime.com>
Parent c500a2e5
@@ -154,20 +154,21 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
| 3 | [box2d/lunarlander](https://github.com/openai/gym/tree/master/gym/envs/box2d) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/box2d/lunarlander/lunarlander.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/box2d/lunarlander/envs) |
| 4 | [classic_control/cartpole](https://github.com/openai/gym/tree/master/gym/envs/classic_control) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/classic_control/cartpole/cartpole.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/classic_control/cartpole/envs) |
| 5 | [classic_control/pendulum](https://github.com/openai/gym/tree/master/gym/envs/classic_control) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/classic_control/pendulum/pendulum.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/classic_control/pendulum/envs) |
| 6 | [competitive_rl](https://github.com/cuhkrlcourse/competitive-rl) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![selfplay](https://img.shields.io/badge/-selfplay-blue) | ![original](./dizoo/competitive_rl/competitive_rl.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo.classic_control) |
| 7 | [gfootball](https://github.com/google-research/football) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![sparse](https://img.shields.io/badge/-sparse%20reward-orange)![selfplay](https://img.shields.io/badge/-selfplay-blue) | ![original](./dizoo/gfootball/gfootball.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo.gfootball/envs) |
| 8 | [minigrid](https://github.com/maximecb/gym-minigrid) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![sparse](https://img.shields.io/badge/-sparse%20reward-orange) | ![original](./dizoo/minigrid/minigrid.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/minigrid/envs) |
| 9 | [mujoco](https://github.com/openai/gym/tree/master/gym/envs/mujoco) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/mujoco/mujoco.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/majoco/envs) |
| 10 | [multiagent_particle](https://github.com/openai/multiagent-particle-envs) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![marl](https://img.shields.io/badge/-MARL-yellow) | ![original](./dizoo/multiagent_particle/multiagent_particle.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/multiagent_particle/envs) |
| 11 | [overcooked](https://github.com/HumanCompatibleAI/overcooked-demo) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![marl](https://img.shields.io/badge/-MARL-yellow) | ![original](./dizoo/overcooked/overcooked.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/overcooded/envs) |
| 12 | [procgen](https://github.com/openai/procgen) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/procgen/coinrun/coinrun.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/procgen) |
| 13 | [pybullet](https://github.com/benelot/pybullet-gym) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/pybullet/pybullet.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/pybullet/envs) |
| 14 | [smac](https://github.com/oxwhirl/smac) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![marl](https://img.shields.io/badge/-MARL-yellow)![selfplay](https://img.shields.io/badge/-selfplay-blue)![sparse](https://img.shields.io/badge/-sparse%20reward-orange) | ![original](./dizoo/smac/smac.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/smac/envs) |
| 15 | [d4rl](https://github.com/rail-berkeley/d4rl) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | ![ori](dizoo/d4rl/d4rl.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/d4rl) |
| 16 | league_demo | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![selfplay](https://img.shields.io/badge/-selfplay-blue) | ![original](./dizoo/league_demo/league_demo.png) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/league_demo/envs) |
| 17 | pomdp atari | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/pomdp/envs) |
| 18 | [bsuite](https://github.com/deepmind/bsuite) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/bsuite/bsuite.png) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/bsuite/envs) |
| 19 | [ImageNet](https://www.image-net.org/) | ![IL](https://img.shields.io/badge/-IL/SL-purple) | ![original](./dizoo/image_classification/imagenet.png) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/image_classification) |
| 20 | [slime_volleyball](https://github.com/hardmaru/slimevolleygym) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![selfplay](https://img.shields.io/badge/-selfplay-blue) | ![ori](dizoo/slime_volley/slime_volley.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/slime_volley) |
![discrete](https://img.shields.io/badge/-discrete-brightgreen) means discrete action space
@@ -181,6 +182,8 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
![IL](https://img.shields.io/badge/-IL/SL-purple) means Imitation Learning or Supervised Learning Dataset
![selfplay](https://img.shields.io/badge/-selfplay-blue) means environment that allows agent VS agent battle
P.S. some environments in Atari, such as **MontezumaRevenge**, are also sparse reward type
## Contribution
......
@@ -32,6 +32,15 @@ _NTYPE_TO_CTYPE = {
}
def is_abnormal_timestep(timestep: namedtuple) -> bool:
    # Single-agent env: ``info`` is a single dict.
    if isinstance(timestep.info, dict):
        return timestep.info.get('abnormal', False)
    # 1v1 (battle) env: ``info`` is a list/tuple of two per-agent dicts;
    # the step is abnormal if either agent reports it.
    elif isinstance(timestep.info, (list, tuple)):
        return timestep.info[0].get('abnormal', False) or timestep.info[1].get('abnormal', False)
    else:
        raise TypeError("invalid env timestep type: {}".format(type(timestep.info)))
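This helper lets the subprocess env managers treat single-agent timesteps (one `info` dict) and 1v1 battle timesteps (a pair of per-agent `info` dicts) uniformly. A minimal sketch of the expected behaviour, using a hypothetical `Timestep` namedtuple in place of the real `ding.envs.BaseEnvTimestep` and assuming `is_abnormal_timestep` from above is in scope:

```python
from collections import namedtuple

# Hypothetical container for illustration only; the real timestep class is
# ding.envs.BaseEnvTimestep, which also exposes an `info` field.
Timestep = namedtuple('Timestep', ['obs', 'reward', 'done', 'info'])

# Single-agent case: `info` is one dict.
single = Timestep(obs=None, reward=0.0, done=False, info={'abnormal': True})
assert is_abnormal_timestep(single)

# 1v1 (battle) case: `info` holds one dict per agent; the step counts as
# abnormal if either side reports it.
battle = Timestep(obs=None, reward=0.0, done=False, info=[{}, {'abnormal': True}])
assert is_abnormal_timestep(battle)
```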
class ShmBuffer():
    """
    Overview:
@@ -452,7 +461,7 @@ class AsyncSubprocessEnvManager(BaseEnvManager):
            timesteps[env_id] = timestep._replace(obs=self._obs_buffers[env_id].get())
        for env_id, timestep in timesteps.items():
            if is_abnormal_timestep(timestep):
                self._env_states[env_id] = EnvState.ERROR
                continue
            if timestep.done:
@@ -497,7 +506,7 @@ class AsyncSubprocessEnvManager(BaseEnvManager):
        elif cmd in method_name_list:
            if cmd == 'step':
                timestep = env.step(*args, **kwargs)
                if is_abnormal_timestep(timestep):
                    ret = timestep
                else:
                    if obs_buffer is not None:
@@ -554,7 +563,7 @@ class AsyncSubprocessEnvManager(BaseEnvManager):
        @timeout_wrapper(timeout=step_timeout)
        def step_fn(*args, **kwargs):
            timestep = env.step(*args, **kwargs)
            if is_abnormal_timestep(timestep):
                ret = timestep
            else:
                if obs_buffer is not None:
@@ -768,7 +777,7 @@ class SyncSubprocessEnvManager(AsyncSubprocessEnvManager):
        for i, (env_id, timestep) in enumerate(timesteps.items()):
            timesteps[env_id] = timestep._replace(obs=self._obs_buffers[env_id].get())
        for env_id, timestep in timesteps.items():
            if is_abnormal_timestep(timestep):
                self._env_states[env_id] = EnvState.ERROR
                continue
            if timestep.done:
......
@@ -121,8 +121,8 @@ class VAC(nn.Module):
        else:
            self.actor = [self.actor_encoder, self.actor_head]
            self.critic = [self.critic_encoder, self.critic_head]
        # Convenient for calling some apis (e.g. self.critic.parameters()),
        # but may cause misunderstanding when `print(self)`
        self.actor = nn.ModuleList(self.actor)
        self.critic = nn.ModuleList(self.critic)
......
@@ -4,6 +4,7 @@ from .base_serial_collector import ISerialCollector, create_serial_collector, ge
from .sample_serial_collector import SampleSerialCollector
from .episode_serial_collector import EpisodeSerialCollector
from .battle_episode_serial_collector import BattleEpisodeSerialCollector
from .battle_sample_serial_collector import BattleSampleSerialCollector
from .base_serial_evaluator import ISerialEvaluator, VectorEvalMonitor
from .interaction_serial_evaluator import InteractionSerialEvaluator
......
from typing import Optional, Any, List, Tuple
from collections import namedtuple, deque
from easydict import EasyDict
import numpy as np
import torch
from ding.envs import BaseEnvManager
from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, dicts_to_lists, one_time_warning
from ding.torch_utils import to_tensor, to_ndarray
from .base_serial_collector import ISerialCollector, CachePool, TrajBuffer, INF, to_tensor_transitions
@SERIAL_COLLECTOR_REGISTRY.register('sample_1v1')
class BattleSampleSerialCollector(ISerialCollector):
"""
Overview:
        Sample collector (n_sample) for 1v1 battles between two policies
Interfaces:
__init__, reset, reset_env, reset_policy, collect, close
Property:
envstep
"""
config = dict(deepcopy_obs=False, transform_obs=False, collect_print_freq=100)
def __init__(
self,
cfg: EasyDict,
env: BaseEnvManager = None,
policy: List[namedtuple] = None,
tb_logger: 'SummaryWriter' = None, # noqa
exp_name: Optional[str] = 'default_experiment',
instance_name: Optional[str] = 'collector'
) -> None:
"""
Overview:
Initialization method.
Arguments:
- cfg (:obj:`EasyDict`): Config dict
- env (:obj:`BaseEnvManager`): the subclass of vectorized env_manager(BaseEnvManager)
- policy (:obj:`List[namedtuple]`): the api namedtuple of collect_mode policy
- tb_logger (:obj:`SummaryWriter`): tensorboard handle
"""
self._exp_name = exp_name
self._instance_name = instance_name
self._collect_print_freq = cfg.collect_print_freq
self._deepcopy_obs = cfg.deepcopy_obs
self._transform_obs = cfg.transform_obs
self._cfg = cfg
self._timer = EasyTimer()
self._end_flag = False
if tb_logger is not None:
self._logger, _ = build_logger(
path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False
)
self._tb_logger = tb_logger
else:
self._logger, self._tb_logger = build_logger(
path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name
)
self._traj_len = float("inf")
self.reset(policy, env)
def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None:
"""
Overview:
Reset the environment.
If _env is None, reset the old environment.
If _env is not None, replace the old environment in the collector with the new passed \
in environment and launch.
Arguments:
- env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \
env_manager(BaseEnvManager)
"""
if _env is not None:
self._env = _env
self._env.launch()
self._env_num = self._env.env_num
else:
self._env.reset()
def reset_policy(self, _policy: Optional[List[namedtuple]] = None) -> None:
"""
Overview:
Reset the policy.
If _policy is None, reset the old policy.
If _policy is not None, replace the old policy in the collector with the new passed in policy.
Arguments:
- policy (:obj:`Optional[List[namedtuple]]`): the api namedtuple of collect_mode policy
"""
assert hasattr(self, '_env'), "please set env first"
if _policy is not None:
            assert len(_policy) == 2, "1v1 sample collector needs 2 policies, but found {}".format(len(_policy))
self._policy = _policy
self._default_n_sample = _policy[0].get_attribute('cfg').collect.get('n_sample', None)
self._unroll_len = _policy[0].get_attribute('unroll_len')
self._on_policy = _policy[0].get_attribute('cfg').on_policy
if self._default_n_sample is not None:
self._traj_len = max(
self._unroll_len,
self._default_n_sample // self._env_num + int(self._default_n_sample % self._env_num != 0)
)
self._logger.debug(
'Set default n_sample mode(n_sample({}), env_num({}), traj_len({}))'.format(
self._default_n_sample, self._env_num, self._traj_len
)
)
else:
self._traj_len = INF
for p in self._policy:
p.reset()
def reset(self, _policy: Optional[List[namedtuple]] = None, _env: Optional[BaseEnvManager] = None) -> None:
"""
Overview:
Reset the environment and policy.
If _env is None, reset the old environment.
If _env is not None, replace the old environment in the collector with the new passed \
in environment and launch.
If _policy is None, reset the old policy.
If _policy is not None, replace the old policy in the collector with the new passed in policy.
Arguments:
- policy (:obj:`Optional[List[namedtuple]]`): the api namedtuple of collect_mode policy
- env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \
env_manager(BaseEnvManager)
"""
if _env is not None:
self.reset_env(_env)
if _policy is not None:
self.reset_policy(_policy)
self._obs_pool = CachePool('obs', self._env_num, deepcopy=self._deepcopy_obs)
self._policy_output_pool = CachePool('policy_output', self._env_num)
# _traj_buffer is {env_id: {policy_id: TrajBuffer}}, is used to store traj_len pieces of transitions
self._traj_buffer = {
env_id: {policy_id: TrajBuffer(maxlen=self._traj_len)
for policy_id in range(2)}
for env_id in range(self._env_num)
}
self._env_info = {env_id: {'time': 0., 'step': 0, 'train_sample': 0} for env_id in range(self._env_num)}
self._episode_info = []
self._total_envstep_count = 0
self._total_episode_count = 0
self._total_train_sample_count = 0
self._total_duration = 0
self._last_train_iter = 0
self._end_flag = False
def _reset_stat(self, env_id: int) -> None:
"""
Overview:
Reset the collector's state. Including reset the traj_buffer, obs_pool, policy_output_pool\
and env_info. Reset these states according to env_id. You can refer to base_serial_collector\
to get more messages.
Arguments:
- env_id (:obj:`int`): the id where we need to reset the collector's state
"""
for i in range(2):
self._traj_buffer[env_id][i].clear()
self._obs_pool.reset(env_id)
self._policy_output_pool.reset(env_id)
self._env_info[env_id] = {'time': 0., 'step': 0, 'train_sample': 0}
@property
def envstep(self) -> int:
"""
Overview:
            Return the total envstep count.
Return:
- envstep (:obj:`int`): the total envstep count
"""
return self._total_envstep_count
def close(self) -> None:
"""
Overview:
Close the collector. If end_flag is False, close the environment, flush the tb_logger\
and close the tb_logger.
"""
if self._end_flag:
return
self._end_flag = True
self._env.close()
self._tb_logger.flush()
self._tb_logger.close()
def __del__(self) -> None:
"""
Overview:
Execute the close command and close the collector. __del__ is automatically called to \
destroy the collector instance when the collector finishes its work
"""
self.close()
def collect(self,
n_sample: Optional[int] = None,
train_iter: int = 0,
policy_kwargs: Optional[dict] = None) -> Tuple[List[Any], List[Any]]:
"""
Overview:
            Collect `n_sample` samples with policy_kwargs, using a policy that has already been trained for `train_iter` iterations
Arguments:
- n_sample (:obj:`int`): the number of collecting data sample
- train_iter (:obj:`int`): the number of training iteration
- policy_kwargs (:obj:`dict`): the keyword args for policy forward
Returns:
- return_data (:obj:`List`): A list containing training samples.
"""
if n_sample is None:
if self._default_n_sample is None:
raise RuntimeError("Please specify collect n_sample")
else:
n_sample = self._default_n_sample
if n_sample % self._env_num != 0:
            one_time_warning(
                "Please make sure n_sample is divisible by env_num: {}/{}, "
                "otherwise it may cause convergence problems in a few algorithms".format(n_sample, self._env_num)
            )
if policy_kwargs is None:
policy_kwargs = {}
collected_sample = [0 for _ in range(2)]
return_data = [[] for _ in range(2)]
return_info = [[] for _ in range(2)]
while any([c < n_sample for c in collected_sample]):
with self._timer:
# Get current env obs.
obs = self._env.ready_obs
# Policy forward.
self._obs_pool.update(obs)
if self._transform_obs:
obs = to_tensor(obs, dtype=torch.float32)
obs = dicts_to_lists(obs)
policy_output = [p.forward(obs[i], **policy_kwargs) for i, p in enumerate(self._policy)]
self._policy_output_pool.update(policy_output)
# Interact with env.
actions = {}
for policy_output_item in policy_output:
for env_id, output in policy_output_item.items():
if env_id not in actions:
actions[env_id] = []
actions[env_id].append(output['action'])
actions = to_ndarray(actions)
timesteps = self._env.step(actions)
# TODO(nyz) this duration may be inaccurate in async env
interaction_duration = self._timer.value / len(timesteps)
# TODO(nyz) vectorize this for loop
for env_id, timestep in timesteps.items():
self._env_info[env_id]['step'] += 1
self._total_envstep_count += 1
with self._timer:
for policy_id, policy in enumerate(self._policy):
policy_timestep_data = [d[policy_id] if not isinstance(d, bool) else d for d in timestep]
policy_timestep = type(timestep)(*policy_timestep_data)
transition = self._policy[policy_id].process_transition(
self._obs_pool[env_id][policy_id], self._policy_output_pool[env_id][policy_id],
policy_timestep
)
transition['collect_iter'] = train_iter
self._traj_buffer[env_id][policy_id].append(transition)
# prepare data
if timestep.done or len(self._traj_buffer[env_id][policy_id]) == self._traj_len:
transitions = to_tensor_transitions(self._traj_buffer[env_id][policy_id])
train_sample = self._policy[policy_id].get_train_sample(transitions)
return_data[policy_id].extend(train_sample)
self._total_train_sample_count += len(train_sample)
self._env_info[env_id]['train_sample'] += len(train_sample)
collected_sample[policy_id] += len(train_sample)
self._traj_buffer[env_id][policy_id].clear()
self._env_info[env_id]['time'] += self._timer.value + interaction_duration
# If env is done, record episode info and reset
if timestep.done:
self._total_episode_count += 1
info = {
'reward0': timestep.info[0]['final_eval_reward'],
'reward1': timestep.info[1]['final_eval_reward'],
'time': self._env_info[env_id]['time'],
'step': self._env_info[env_id]['step'],
'train_sample': self._env_info[env_id]['train_sample'],
}
self._episode_info.append(info)
for i, p in enumerate(self._policy):
p.reset([env_id])
self._reset_stat(env_id)
for policy_id in range(2):
return_info[policy_id].append(timestep.info[policy_id])
# log
self._output_log(train_iter)
return_data = [r[:n_sample] for r in return_data]
return return_data, return_info
def _output_log(self, train_iter: int) -> None:
"""
Overview:
Print the output log information. You can refer to Docs/Best Practice/How to understand\
training generated folders/Serial mode/log/collector for more details.
Arguments:
- train_iter (:obj:`int`): the number of training iteration.
"""
if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0:
self._last_train_iter = train_iter
episode_count = len(self._episode_info)
envstep_count = sum([d['step'] for d in self._episode_info])
duration = sum([d['time'] for d in self._episode_info])
episode_reward0 = [d['reward0'] for d in self._episode_info]
episode_reward1 = [d['reward1'] for d in self._episode_info]
self._total_duration += duration
info = {
'episode_count': episode_count,
'envstep_count': envstep_count,
'avg_envstep_per_episode': envstep_count / episode_count,
'avg_envstep_per_sec': envstep_count / duration,
'avg_episode_per_sec': episode_count / duration,
'collect_time': duration,
'reward0_mean': np.mean(episode_reward0),
'reward0_std': np.std(episode_reward0),
'reward0_max': np.max(episode_reward0),
'reward0_min': np.min(episode_reward0),
'reward1_mean': np.mean(episode_reward1),
'reward1_std': np.std(episode_reward1),
'reward1_max': np.max(episode_reward1),
'reward1_min': np.min(episode_reward1),
'total_envstep_count': self._total_envstep_count,
'total_episode_count': self._total_episode_count,
'total_duration': self._total_duration,
}
self._episode_info.clear()
self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()])))
for k, v in info.items():
self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter)
if k in ['total_envstep_count']:
continue
self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count)
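For orientation, the sketch below shows how this collector is typically driven in a naive self-play loop; it mirrors the self-play entry script further down in this commit, and the `cfg`, `collector_env`, `policy`, `tb_logger` and `learner` objects are assumed to be constructed as in that script:

```python
# Sketch only: objects below are assumed to be built as in the
# self-play entry script shown later in this diff.
collector = BattleSampleSerialCollector(
    cfg.policy.collect.collector,
    collector_env,
    [policy.collect_mode, policy.collect_mode],  # both sides share one policy in naive self-play
    tb_logger,
    exp_name=cfg.exp_name,
)
# collect() returns per-policy sample lists and per-policy episode infos.
new_data, new_info = collector.collect(train_iter=learner.train_iter)
train_data = new_data[0] + new_data[1]  # merge both sides' samples for a shared policy
learner.train(train_data, collector.envstep)
```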
@@ -14,8 +14,8 @@ pong_ppo_config = dict(
    ),
    policy=dict(
        cuda=True,
        # (bool) whether to use on-policy training pipeline(on-policy means behaviour policy and training policy are the same)
        on_policy=False,
        model=dict(
            obs_shape=[4, 84, 84],
            action_shape=6,
......
@@ -120,7 +120,7 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
    main_player = league.get_player_by_id(main_key)
    main_learner = learners[main_key]
    main_collector = collectors[main_key]
    # collect_mode ppo use multinomial sample for selecting action
    evaluator1_cfg = copy.deepcopy(cfg.policy.eval.evaluator)
    evaluator1_cfg.stop_value = cfg.env.stop_value[0]
    evaluator1 = BattleInteractionSerialEvaluator(
......
@@ -87,7 +87,7 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
        tb_logger,
        exp_name=cfg.exp_name
    )
    # collect_mode ppo use multinomial sample for selecting action
    evaluator1_cfg = copy.deepcopy(cfg.policy.eval.evaluator)
    evaluator1_cfg.stop_value = cfg.env.stop_value[0]
    evaluator1 = BattleInteractionSerialEvaluator(
......
from .slime_volley_league_ppo_config import slime_volley_league_ppo_config
from easydict import EasyDict
slime_volley_league_ppo_config = dict(
exp_name="slime_volley_league_ppo",
env=dict(
collector_env_num=8,
evaluator_env_num=10,
n_evaluator_episode=100,
stop_value=0,
# Single-agent env for evaluator; Double-agent env for collector.
# Should be assigned True or False in code.
is_evaluator=None,
manager=dict(shared_memory=False, ),
env_id="SlimeVolley-v0",
),
policy=dict(
cuda=False,
continuous=False,
model=dict(
obs_shape=12,
action_shape=6,
encoder_hidden_size_list=[32, 32],
critic_head_hidden_size=32,
actor_head_hidden_size=32,
share_encoder=False,
),
learn=dict(
update_per_collect=3,
batch_size=32,
learning_rate=0.00001,
value_weight=0.5,
entropy_weight=0.0,
clip_ratio=0.2,
),
collect=dict(
n_episode=128, unroll_len=1, discount_factor=1.0, gae_lambda=1.0, collector=dict(get_train_sample=True, )
),
other=dict(
league=dict(
player_category=['default'],
path_policy="slime_volley_league_ppo/policy",
active_players=dict(
main_player=1,
main_exploiter=1,
league_exploiter=1,
),
main_player=dict(
one_phase_step=200,
branch_probs=dict(
pfsp=0.5,
sp=1.0,
),
strong_win_rate=0.7,
),
main_exploiter=dict(
one_phase_step=200,
branch_probs=dict(main_players=1.0, ),
strong_win_rate=0.7,
min_valid_win_rate=0.3,
),
league_exploiter=dict(
one_phase_step=200,
branch_probs=dict(pfsp=1.0, ),
strong_win_rate=0.7,
mutate_prob=0.0,
),
use_pretrain=False,
use_pretrain_init_historical=False,
payoff=dict(
type='battle',
decay=0.99,
min_win_rate_games=8,
)
),
),
),
)
slime_volley_league_ppo_config = EasyDict(slime_volley_league_ppo_config)
from easydict import EasyDict
from ding.entry import serial_pipeline_onpolicy
slime_volley_ppo_config = dict(
exp_name='slime_volley_ppo',
env=dict(
collector_env_num=8,
evaluator_env_num=5,
n_evaluator_episode=5,
agent_vs_agent=False,
stop_value=1000000,
env_id="SlimeVolley-v0",
),
policy=dict(
cuda=True,
on_policy=True,
continuous=False,
model=dict(
obs_shape=12,
action_shape=6,
encoder_hidden_size_list=[64, 64],
critic_head_hidden_size=64,
actor_head_hidden_size=64,
share_encoder=False,
),
learn=dict(
epoch_per_collect=5,
batch_size=64,
learning_rate=3e-4,
value_weight=0.5,
entropy_weight=0.0,
clip_ratio=0.2,
),
collect=dict(
n_sample=4096,
unroll_len=1,
discount_factor=0.99,
gae_lambda=0.95,
),
),
)
slime_volley_ppo_config = EasyDict(slime_volley_ppo_config)
main_config = slime_volley_ppo_config
slime_volley_ppo_create_config = dict(
env=dict(
type='slime_volley',
import_names=['dizoo.slime_volley.envs.slime_volley_env'],
),
    env_manager=dict(type='subprocess'),  # note: saving replays requires the 'base' env manager instead
policy=dict(type='ppo'),
)
slime_volley_ppo_create_config = EasyDict(slime_volley_ppo_create_config)
create_config = slime_volley_ppo_create_config
if __name__ == "__main__":
serial_pipeline_onpolicy([main_config, create_config], seed=0)
import os
import gym
import numpy as np
import copy
import torch
from tensorboardX import SummaryWriter
from functools import partial
from ding.config import compile_config
from ding.worker import BaseLearner, BattleSampleSerialCollector, NaiveReplayBuffer, InteractionSerialEvaluator
from ding.envs import SyncSubprocessEnvManager
from ding.policy import PPOPolicy
from ding.model import VAC
from ding.utils import set_pkg_seed
from dizoo.slime_volley.envs import SlimeVolleyEnv
from dizoo.slime_volley.config.slime_volley_ppo_config import main_config
def main(cfg, seed=0, max_iterations=int(1e10)):
cfg = compile_config(
cfg,
SyncSubprocessEnvManager,
PPOPolicy,
BaseLearner,
BattleSampleSerialCollector,
InteractionSerialEvaluator,
NaiveReplayBuffer,
save_cfg=True
)
collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
collector_env_cfg = copy.deepcopy(cfg.env)
collector_env_cfg.agent_vs_agent = True
evaluator_env_cfg = copy.deepcopy(cfg.env)
evaluator_env_cfg.agent_vs_agent = False
collector_env = SyncSubprocessEnvManager(
env_fn=[partial(SlimeVolleyEnv, collector_env_cfg) for _ in range(collector_env_num)], cfg=cfg.env.manager
)
evaluator_env = SyncSubprocessEnvManager(
env_fn=[partial(SlimeVolleyEnv, evaluator_env_cfg) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
)
collector_env.seed(seed)
evaluator_env.seed(seed, dynamic_seed=False)
set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
model = VAC(**cfg.policy.model)
policy = PPOPolicy(cfg.policy, model=model)
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
learner = BaseLearner(
cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name, instance_name='learner1'
)
collector = BattleSampleSerialCollector(
cfg.policy.collect.collector,
collector_env, [policy.collect_mode, policy.collect_mode],
tb_logger,
exp_name=cfg.exp_name
)
evaluator_cfg = copy.deepcopy(cfg.policy.eval.evaluator)
evaluator_cfg.stop_value = cfg.env.stop_value
evaluator = InteractionSerialEvaluator(
evaluator_cfg,
evaluator_env,
policy.eval_mode,
tb_logger,
exp_name=cfg.exp_name,
instance_name='builtin_ai_evaluator'
)
learner.call_hook('before_run')
for _ in range(max_iterations):
if evaluator.should_eval(learner.train_iter):
stop_flag, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
if stop_flag:
break
new_data, _ = collector.collect(train_iter=learner.train_iter)
train_data = new_data[0] + new_data[1]
learner.train(train_data, collector.envstep)
learner.call_hook('after_run')
if __name__ == "__main__":
main(main_config)
from .slime_volley_env import SlimeVolleyEnv
from namedlist import namedlist
import numpy as np
import gym
from typing import Any, Union, List, Optional
import copy
import slimevolleygym
from ding.envs import BaseEnv, BaseEnvTimestep, BaseEnvInfo
from ding.envs.common.env_element import EnvElement, EnvElementInfo
from ding.utils import ENV_REGISTRY
from ding.torch_utils import to_tensor, to_ndarray
class GymSelfPlayMonitor(gym.wrappers.Monitor):
def step(self, *args, **kwargs):
self._before_step(*args, **kwargs)
observation, reward, done, info = self.env.step(*args, **kwargs)
done = self._after_step(observation, reward, done, info)
return observation, reward, done, info
def _before_step(self, *args, **kwargs):
if not self.enabled:
return
self.stats_recorder.before_step(args[0])
@ENV_REGISTRY.register('slime_volley')
class SlimeVolleyEnv(BaseEnv):
def __init__(self, cfg) -> None:
self._cfg = cfg
self._init_flag = False
self._replay_path = None
# agent_vs_bot env is single-agent env. obs, action, done, info are all single.
# agent_vs_agent env is double-agent env, obs, action, info are double, done is still single.
self._agent_vs_agent = cfg.agent_vs_agent
def seed(self, seed: int, dynamic_seed: bool = True) -> None:
self._seed = seed
self._dynamic_seed = dynamic_seed
np.random.seed(self._seed)
def close(self) -> None:
if self._init_flag:
self._env.close()
self._init_flag = False
def step(self, action: Union[np.ndarray, List[np.ndarray]]):
if self._agent_vs_agent:
assert isinstance(action, list) and isinstance(action[0], np.ndarray)
action1, action2 = action[0], action[1]
else:
assert isinstance(action, np.ndarray)
action1, action2 = action, None
assert isinstance(action1, np.ndarray), type(action1)
        assert action2 is None or isinstance(action2, np.ndarray), type(action2)
if action1.shape == (1, ):
action1 = action1.squeeze() # 0-dim tensor
if action2 is not None and action2.shape == (1, ):
action2 = action2.squeeze() # 0-dim tensor
action1 = SlimeVolleyEnv._process_action(action1)
action2 = SlimeVolleyEnv._process_action(action2)
obs1, rew, done, info = self._env.step(action1, action2)
obs1 = to_ndarray(obs1).astype(np.float32)
self._final_eval_reward += rew
# info ('ale.lives', 'ale.otherLives', 'otherObs', 'state', 'otherState')
if self._agent_vs_agent:
info = [
{
'ale.lives': info['ale.lives'],
'state': info['state']
}, {
'ale.lives': info['ale.otherLives'],
'state': info['otherState'],
'obs': info['otherObs']
}
]
if done:
info[0]['final_eval_reward'] = self._final_eval_reward
info[1]['final_eval_reward'] = -self._final_eval_reward
else:
if done:
info['final_eval_reward'] = self._final_eval_reward
reward = to_ndarray([rew]).astype(np.float32)
if self._agent_vs_agent:
obs2 = info[1]['obs']
obs2 = to_ndarray(obs2).astype(np.float32)
observations = np.stack([obs1, obs2], axis=0)
rewards = to_ndarray([rew, -rew]).astype(np.float32)
rewards = rewards[..., np.newaxis]
return BaseEnvTimestep(observations, rewards, done, info)
else:
return BaseEnvTimestep(obs1, reward, done, info)
def reset(self):
if not self._init_flag:
self._env = gym.make(self._cfg.env_id)
if self._replay_path is not None:
self._env = GymSelfPlayMonitor(
self._env, self._replay_path, video_callable=lambda episode_id: True, force=True
)
self._init_flag = True
if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
np_seed = 100 * np.random.randint(1, 1000)
self._env.seed(self._seed + np_seed)
elif hasattr(self, '_seed'):
self._env.seed(self._seed)
self._final_eval_reward = 0
obs = self._env.reset()
obs = to_ndarray(obs).astype(np.float32)
if self._agent_vs_agent:
obs = np.stack([obs, obs], axis=0)
return obs
else:
return obs
def info(self):
T = EnvElementInfo
return BaseEnvInfo(
agent_num=2,
obs_space=T(
(2, 12) if self._agent_vs_agent else (12, ),
{
'min': [float("-inf") for _ in range(12)],
'max': [float("inf") for _ in range(12)],
'dtype': np.float32,
},
),
# [min, max)
# 6 valid actions:
act_space=T(
(1, ),
{
'min': 0,
'max': 6,
'dtype': int,
},
),
rew_space=T(
(1, ),
{
'min': -5.0,
'max': 5.0,
'dtype': np.float32,
},
),
use_wrappers=None,
)
def __repr__(self):
return "DI-engine Slime Volley Env"
def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
if replay_path is None:
replay_path = './video'
self._replay_path = replay_path
@staticmethod
def _process_action(action: np.ndarray, _type: str = "binary") -> np.ndarray:
if action is None:
return None
action = action.item()
        # The env receives an int action in [0, 5], which can be translated into:
        # 1) "binary" type: e.g. np.array([0, 1, 0])
        # 2) "atari" type: NOOP, LEFT, UPLEFT, UP, UPRIGHT, RIGHT
to_atari_action = {
0: 0, # NOOP
1: 4, # LEFT
2: 7, # UPLEFT
3: 2, # UP
4: 6, # UPRIGHT
5: 3, # RIGHT
}
to_binary_action = {
0: [0, 0, 0], # NOOP
1: [1, 0, 0], # LEFT (forward)
2: [1, 0, 1], # UPLEFT (forward jump)
3: [0, 0, 1], # UP (jump)
4: [0, 1, 1], # UPRIGHT (backward jump)
5: [0, 1, 0], # RIGHT (backward)
}
if _type == "binary":
return to_ndarray(to_binary_action[action])
elif _type == "atari":
return to_atari_action[action]
else:
raise NotImplementedError
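A quick illustration of the action translation implemented in `_process_action` above, following the `to_binary_action` table (a sketch; the binary triple is the format the underlying `slimevolleygym` env consumes):

```python
import numpy as np

# From the mapping above: discrete action 3 ("UP"/jump) -> [0, 0, 1],
# discrete action 0 ("NOOP") -> [0, 0, 0]. `SlimeVolleyEnv` is the class defined above.
jump = SlimeVolleyEnv._process_action(np.array([3]))
noop = SlimeVolleyEnv._process_action(np.array(0))
assert list(jump) == [0, 0, 1]
assert list(noop) == [0, 0, 0]
```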
import pytest
import numpy as np
from easydict import EasyDict
from dizoo.slime_volley.envs.slime_volley_env import SlimeVolleyEnv
@pytest.mark.envtest
class TestSlimeVolley:
@pytest.mark.parametrize('agent_vs_agent', [True, False])
def test_slime_volley(self, agent_vs_agent):
total_rew = 0
env = SlimeVolleyEnv(EasyDict({'env_id': 'SlimeVolley-v0', 'agent_vs_agent': agent_vs_agent}))
# env.enable_save_replay('replay_video')
obs1 = env.reset()
done = False
print(env._env.observation_space)
print('observation is like:', obs1)
done = False
        while not done:
            if agent_vs_agent:
                action1 = np.random.randint(0, 2, (1, ))
                action2 = np.random.randint(0, 2, (1, ))
                action = [action1, action2]
            else:
                action = np.random.randint(0, 2, (1, ))
            observations, rewards, done, infos = env.step(action)
            total_rew += rewards[0]
            if agent_vs_agent:
                # Per-agent observations and infos only exist in agent_vs_agent mode.
                obs1, obs2 = observations[0], observations[1]
                assert obs1.shape == obs2.shape, (obs1.shape, obs2.shape)
                agent_lives, opponent_lives = infos[0]['ale.lives'], infos[1]['ale.lives']
        if agent_vs_agent:
            assert agent_lives == 0 or opponent_lives == 0, (agent_lives, opponent_lives)
print("total reward is:", total_rew)
@@ -136,6 +136,10 @@ setup(
        'whichcraft',
        'joblib',
    ],
    'slimevolleygym_env': [
        'slimevolleygym',
    ],
    'k8s': [
        'kubernetes',
    ]
......
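With the new `slimevolleygym_env` extra declared in `setup.py`, the dependency can presumably be pulled in from a source checkout with `pip install .[slimevolleygym_env]` (extra name exactly as declared above).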