subprocess_env_manager.py 33.8 KB
Newer Older
1 2
from typing import Any, Union, List, Tuple, Dict, Callable, Optional
from multiprocessing import Pipe, connection, get_context, Array
N
v0.1.0  
niuyazhe 已提交
3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
from collections import namedtuple
import logging
import platform
import time
import copy
import traceback
import numpy as np
import torch
import ctypes
import pickle
import cloudpickle
from easydict import EasyDict
from types import MethodType

from ding.utils import PropagatingThread, LockContextType, LockContext, ENV_MANAGER_REGISTRY
18 19
from .base_env_manager import BaseEnvManager, EnvState, timeout_wrapper
from ding.envs.env.base_env import BaseEnvTimestep
N
v0.1.0  
niuyazhe 已提交
20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35

_NTYPE_TO_CTYPE = {
    np.bool_: ctypes.c_bool,
    np.uint8: ctypes.c_uint8,
    np.uint16: ctypes.c_uint16,
    np.uint32: ctypes.c_uint32,
    np.uint64: ctypes.c_uint64,
    np.int8: ctypes.c_int8,
    np.int16: ctypes.c_int16,
    np.int32: ctypes.c_int32,
    np.int64: ctypes.c_int64,
    np.float32: ctypes.c_float,
    np.float64: ctypes.c_double,
}


36 37 38 39 40 41 42 43 44
def is_abnormal_timestep(timestep: namedtuple) -> bool:
    """
    Overview:
        Tell whether a timestep is flagged as abnormal through the ``'abnormal'`` key of its ``info`` field.
    Arguments:
        - timestep (:obj:`namedtuple`): Timestep exposing an ``info`` attribute, which is either a dict or a \
            list/tuple of dicts.
    Returns:
        - abnormal (:obj:`bool`): For a dict, the value stored under ``'abnormal'`` (``False`` when missing); \
            for a list/tuple, ``True`` if any entry carries a truthy ``'abnormal'`` value.
    Raises:
        - TypeError: If ``timestep.info`` is neither a dict nor a list/tuple.
    """
    info = timestep.info
    if isinstance(info, dict):
        return info.get('abnormal', False)
    if isinstance(info, (list, tuple)):
        # Look at every entry instead of hard-coding exactly two of them, so that sequences of
        # any length (including a single entry) are handled.
        return any(item.get('abnormal', False) for item in info)
    raise TypeError("invalid env timestep type: {}".format(type(info)))


N
v0.1.0  
niuyazhe 已提交
45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166
class ShmBuffer():
    """
    Overview:
        Shared memory buffer to store numpy array.
    """

    def __init__(self, dtype: Union[type, np.dtype], shape: Tuple[int]) -> None:
        """
        Overview:
            Initialize the buffer.
        Arguments:
            - dtype (:obj:`Union[type, np.dtype]`): dtype of the data to limit the size of the buffer. Anything \
                accepted by ``np.dtype`` is allowed, as long as its scalar type is a key of ``_NTYPE_TO_CTYPE``.
            - shape (:obj:`Tuple[int]`): shape of the data to limit the size of the buffer.
        """
        # Normalize first: the ctypes lookup relies on ``dtype.type``, which a plain numpy scalar
        # class (e.g. ``np.float32``) does not provide.
        dtype = np.dtype(dtype)
        self.buffer = Array(_NTYPE_TO_CTYPE[dtype.type], int(np.prod(shape)))
        self.dtype = dtype
        self.shape = shape

    def fill(self, src_arr: np.ndarray) -> None:
        """
        Overview:
            Fill the shared memory buffer with a numpy array. (Replace the original one.)
        Arguments:
            - src_arr (:obj:`np.ndarray`): array to fill the buffer.
        """
        assert isinstance(src_arr, np.ndarray), type(src_arr)
        # ``np.frombuffer`` builds a view on the shared memory, so ``copyto`` writes in place.
        dst_arr = np.frombuffer(self.buffer.get_obj(), dtype=self.dtype).reshape(self.shape)
        with self.buffer.get_lock():
            np.copyto(dst_arr, src_arr)

    def get(self) -> np.ndarray:
        """
        Overview:
            Get the array stored in the buffer.
        Return:
            - copy_data (:obj:`np.ndarray`): A copy of the data stored in the buffer.
        """
        # Hold the same lock as ``fill`` while copying, so that a concurrent write cannot be
        # observed half-way through.
        with self.buffer.get_lock():
            arr = np.frombuffer(self.buffer.get_obj(), dtype=self.dtype).reshape(self.shape)
            return arr.copy()


class ShmBufferContainer(object):
    """
    Overview:
        Wrapper around one shared memory buffer, or around a (possibly nested) dict of them
        in which every key names one buffer.
    """

    def __init__(self, dtype: np.generic, shape: Union[Dict[Any, tuple], tuple]) -> None:
        """
        Overview:
            Build the underlying buffer(s) according to ``shape``.
        Arguments:
            - dtype (:obj:`np.generic`): dtype shared by all the buffers.
            - shape (:obj:`Union[Dict[Any, tuple], tuple]`): A dict creates one nested container per key; \
                a tuple (or list) creates a single ``ShmBuffer``.
        """
        if isinstance(shape, dict):
            containers = {}
            for name, sub_shape in shape.items():
                containers[name] = ShmBufferContainer(dtype, sub_shape)
            self._data = containers
        elif isinstance(shape, (tuple, list)):
            self._data = ShmBuffer(dtype, shape)
        else:
            raise RuntimeError("not support shape: {}".format(shape))
        self._shape = shape

    def fill(self, src_arr: Union[Dict[Any, np.ndarray], np.ndarray]) -> None:
        """
        Overview:
            Write ``src_arr`` into the buffer(s).
        Arguments:
            - src_arr (:obj:`Union[Dict[Any, np.ndarray], np.ndarray]`): A dict indexed by the same keys as \
                the shape dict, or a single array for a single buffer.
        """
        if isinstance(self._shape, dict):
            for name in self._shape:
                self._data[name].fill(src_arr[name])
            return
        if isinstance(self._shape, (tuple, list)):
            self._data.fill(src_arr)

    def get(self) -> Union[Dict[Any, np.ndarray], np.ndarray]:
        """
        Overview:
            Read back the content of the buffer(s).
        Return:
            - data (:obj:`Union[Dict[Any, np.ndarray], np.ndarray]`): A dict of arrays (same keys as the \
                shape dict) or a single array.
        """
        if isinstance(self._shape, dict):
            arrays = {}
            for name in self._shape:
                arrays[name] = self._data[name].get()
            return arrays
        if isinstance(self._shape, (tuple, list)):
            return self._data.get()


class CloudPickleWrapper:
    """
    Overview:
        Wrapper whose pickled state is produced by ``cloudpickle``, so that objects the standard
        pickler rejects (e.g. an object with a lambda expression) can be sent to another process.
    """

    def __init__(self, data: Any) -> None:
        """
        Arguments:
            - data (:obj:`Any`): The object to carry across the process boundary.
        """
        self.data = data

    def __getstate__(self) -> bytes:
        """
        Returns:
            - state (:obj:`bytes`): ``self.data`` serialized by ``cloudpickle.dumps``.
        """
        return cloudpickle.dumps(self.data)

    def __setstate__(self, data: bytes) -> None:
        """
        Arguments:
            - data (:obj:`bytes`): The byte string produced by ``__getstate__``.
        """
        # The state is always the bytes emitted by ``__getstate__``; cloudpickle writes a regular
        # pickle stream, so the standard loader restores it. The former tuple/list/ndarray check
        # could never match a bytes state and would have failed inside ``pickle.loads`` if it had.
        self.data = pickle.loads(data)


@ENV_MANAGER_REGISTRY.register('async_subprocess')
class AsyncSubprocessEnvManager(BaseEnvManager):
    """
    Overview:
        Create an AsyncSubprocessEnvManager to manage multiple environments.
        Each Environment is run by a respective subprocess.
    Interfaces:
        seed, launch, ready_obs, step, reset, env_info,active_env
    """

    # Default configuration of the manager.
    # NOTE(review): the first group of keys is not read in this class' ``__init__``; presumably the
    # base class turns them into ``self._max_retry``, ``self._retry_type`` etc. - confirm there.
    config = {
        'episode_num': float("inf"),
        'max_retry': 5,
        'step_timeout': None,
        'auto_reset': True,
        'retry_type': 'reset',
        'reset_timeout': None,
        'retry_waiting_time': 0.1,
        # subprocess specified args
        'shared_memory': True,
        'context': 'spawn' if platform.system().lower() == 'windows' else 'fork',
        # ``wait_num`` and ``step_wait_timeout`` end up in ``self._async_args['step']``
        'wait_num': 2,
        'step_wait_timeout': 0.01,
        # timeout handed to ``poll`` while waiting for the reply of a reset command
        'connect_timeout': 60,
    }

    def __init__(
            self,
            env_fn: List[Callable],
            cfg: EasyDict = EasyDict({}),
    ) -> None:
        """
        Overview:
            Initialize the AsyncSubprocessEnvManager.
        Arguments:
            - env_fn (:obj:`List[Callable]`): The functions used to create the environments.
            - cfg (:obj:`EasyDict`): Config

        .. note::

            - wait_num: the minimum number of env replies to gather in each ``step`` call \
                (capped by the number of environments).
            - step_wait_timeout: the timeout passed to ``wait`` in each ``step`` call.
        """
        super().__init__(env_fn, cfg)
        manager_cfg = self._cfg
        self._shared_memory = manager_cfg.shared_memory
        self._context = manager_cfg.context
        self._wait_num = manager_cfg.wait_num
        self._step_wait_timeout = manager_cfg.step_wait_timeout

        self._lock = LockContext(LockContextType.THREAD_LOCK)
        self._connect_timeout = manager_cfg.connect_timeout
        # Arguments of the asynchronous gathering, indexed by command name.
        step_async_args = {
            'wait_num': min(self._wait_num, self._env_num),
            'timeout': self._step_wait_timeout
        }
        self._async_args = {'step': step_async_args}
N
v0.1.0  
niuyazhe 已提交
211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273

    def _create_state(self) -> None:
        r"""
        Overview:
            Initialize the per-env bookkeeping, allocate the observation buffers when shared memory is
            enabled, then start one sub-process per env (through ``_create_env_subprocess``) together with
            the pipes used to talk to it.
        """
        all_env_ids = range(self.env_num)
        self._env_episode_count = dict.fromkeys(all_env_ids, 0)
        self._ready_obs = dict.fromkeys(all_env_ids)
        # Reference env living in the main process; its ``info()`` describes the observation space.
        self._env_ref = self._env_fn[0]()
        self._reset_param = {env_id: {} for env_id in all_env_ids}
        if not self._shared_memory:
            self._obs_buffers = dict.fromkeys(all_env_ids)
        else:
            obs_space = self._env_ref.info().obs_space
            shape = obs_space.shape
            if obs_space.value is None:
                # no dtype declared by the env, fall back to float32
                dtype = np.dtype(np.float32)
            else:
                dtype = np.dtype(obs_space.value['dtype'])
            self._obs_buffers = {env_id: ShmBufferContainer(dtype, shape) for env_id in all_env_ids}
        self._pipe_parents, self._pipe_children = {}, {}
        self._subprocesses = {}
        for env_id in all_env_ids:
            self._create_env_subprocess(env_id)
        self._waiting_env = {'step': set()}
        self._closed = False

    def _create_env_subprocess(self, env_id):
        """
        Overview:
            Start a new sub-process running ``worker_fn_robust`` for ``env_id`` and register its pipe ends,
            process handle and state.
        Arguments:
            - env_id (:obj:`int`): Id of the env to (re)create.
        """
        parent_conn, child_conn = Pipe()
        self._pipe_parents[env_id] = parent_conn
        self._pipe_children[env_id] = child_conn
        worker_args = (
            parent_conn,
            child_conn,
            CloudPickleWrapper(self._env_fn[env_id]),
            self._obs_buffers[env_id],
            self.method_name_list,
            self._reset_timeout,
            self._step_timeout,
        )
        process = get_context(self._context).Process(
            target=self.worker_fn_robust,
            args=worker_args,
            daemon=True,
            name='subprocess_env_manager{}_{}'.format(env_id, time.time())
        )
        self._subprocesses[env_id] = process
        process.start()
        # The child end is only used inside the sub-process, so close this process' copy.
        child_conn.close()
        self._env_states[env_id] = EnvState.INIT

        if self._env_replay_path is not None:
            parent_conn.send(['enable_save_replay', [self._env_replay_path[env_id]], {}])
            # wait for the reply before going on
            parent_conn.recv()

    @property
    def ready_env(self) -> List[int]:
        """
        Overview:
            Ids of the active envs that are not waiting for the reply of a pending ``step``.
        """
        return list(filter(lambda env_id: env_id not in self._waiting_env['step'], self.active_env))

    @property
    def ready_obs(self) -> Dict[int, Any]:
        """
        Overview:
            Get the next observations. Blocks (polling every millisecond) until at least one of the envs
            that were not done when the call started is in the ``RUN`` state.
        Return:
            A dictionary mapping every ready env id to its latest observation.
        Example:
            >>>     obs_dict = env_manager.ready_obs
            >>>     actions_dict = {env_id: model.forward(obs) for env_id, obs in obs_dict.items()}
        """
        unfinished_ids = [env_id for env_id, state in self._env_states.items() if state != EnvState.DONE]
        polls = 0
        while not any(self._env_states[env_id] == EnvState.RUN for env_id in unfinished_ids):
            # warn on the first poll and then every 1000 polls
            if polls % 1000 == 0:
                logging.warning(
                    'VEC_ENV_MANAGER: all the not done envs are resetting, sleep {} times'.format(polls)
                )
            time.sleep(0.001)
            polls += 1
        return {env_id: self._ready_obs[env_id] for env_id in self.ready_env}

    def launch(self, reset_param: Optional[Dict] = None) -> None:
        """
        Overview:
            Create the sub-processes and their state, then reset the environments.
        Arguments:
            - reset_param (:obj:`Optional[Dict]`): Dict of reset parameters for each environment, key is the \
                env_id, value is the corresponding reset parameters. When given, it must hold one entry per \
                env function.
        """
        assert self._closed, "please first close the env manager"
        assert reset_param is None or len(reset_param) == len(self._env_fn)
        self._create_state()
        self.reset(reset_param)

    def reset(self, reset_param: Optional[Dict] = None) -> None:
        """
        Overview:
            Reset the environments and, optionally, update their reset parameters. The resets run in
            parallel threads and this method returns once all of them have finished.
        Arguments:
            - reset_param (:obj:`Optional[Dict]`): Dict of reset parameters for each environment, key is the \
                env_id, value is the corresponding reset parameters. If ``None``, every env is reset with \
                its previously stored parameters.
        """
        self._check_closed()
        # clear previous info: consume the replies of the steps that are still pending
        for env_id in self._waiting_env['step']:
            self._pipe_parents[env_id].recv()
        self._waiting_env['step'].clear()

        if reset_param is None:
            reset_env_list = list(range(self._env_num))
        else:
            reset_env_list = list(reset_param.keys())
            for env_id in reset_env_list:
                self._reset_param[env_id] = reset_param[env_id]

        # Wait until none of the selected envs is still in the RESET state from an earlier reset.
        sleep_count = 0
        while any(self._env_states[i] == EnvState.RESET for i in reset_env_list):
            if sleep_count % 1000 == 0:
                logging.warning(
                    'VEC_ENV_MANAGER: not all the envs finish resetting, sleep {} times'.format(sleep_count)
                )
            time.sleep(0.001)
            sleep_count += 1

        # reset env
        reset_thread_list = []
        for env_id in reset_env_list:
            self._env_states[env_id] = EnvState.RESET
            # set seed
            if self._env_seed[env_id] is not None:
                try:
                    seed_args = [self._env_seed[env_id]]
                    if self._env_dynamic_seed is not None:
                        seed_args.append(self._env_dynamic_seed)
                    self._pipe_parents[env_id].send(['seed', seed_args, {}])
                    ret = self._pipe_parents[env_id].recv()
                    # ``close=False``: a seeding failure is only logged below, so it must not shut the
                    # whole manager down before the reset that follows.
                    self._check_data({env_id: ret}, close=False)
                    self._env_seed[env_id] = None  # seed only use once
                except Exception:
                    # ``_check_data`` marks the env as ERROR before raising; the env is about to be
                    # reset, so restore the RESET state.
                    self._env_states[env_id] = EnvState.RESET
                    logging.warning("subprocess reset set seed failed, ignore and continue...", exc_info=True)
            reset_thread = PropagatingThread(target=self._reset, args=(env_id, ))
            reset_thread.daemon = True
            reset_thread_list.append(reset_thread)

        for t in reset_thread_list:
            t.start()
        for t in reset_thread_list:
            t.join()

    def _reset(self, env_id: int) -> None:
        """
        Overview:
            Reset one env through its pipe, retrying up to ``self._max_retry`` times. When all the attempts
            fail, the manager is closed and a ``RuntimeError`` wrapping the latest exception is raised,
            unless the manager has already been closed.
        Arguments:
            - env_id (:obj:`int`): Id of the env to reset.
        """

        def attempt_reset():
            # Look the pipe up at every attempt: a 'renew' retry replaces it.
            pipe = self._pipe_parents[env_id]
            if pipe.poll():
                leftover = pipe.recv()
                raise RuntimeError("unread data left before sending to the pipe: {}".format(repr(leftover)))
            # if self._reset_param[env_id] is None, just reset specific env, not pass reset param
            param = self._reset_param[env_id]
            if param is None:
                reset_kwargs = {}
            else:
                assert isinstance(param, dict), type(param)
                reset_kwargs = param
            pipe.send(['reset', [], reset_kwargs])

            if not pipe.poll(self._connect_timeout):
                raise ConnectionError("env reset connection timeout")  # Leave it to try again

            obs = pipe.recv()
            self._check_data({env_id: obs}, close=False)
            if self._shared_memory:
                # with shared memory the observation is read from the buffer, not from the pipe
                obs = self._obs_buffers[env_id].get()
            # Because each thread updates the corresponding env_id value, they won't lead to a thread-safe problem.
            self._env_states[env_id] = EnvState.RUN
            self._ready_obs[env_id] = obs

        exceptions = []
        for _ in range(self._max_retry):
            try:
                attempt_reset()
            except BaseException as e:
                need_renew = self._retry_type == 'renew' or isinstance(e, pickle.UnpicklingError)
                if need_renew:
                    # Drop the current pipe and process and start a fresh sub-process.
                    self._pipe_parents[env_id].close()
                    if self._subprocesses[env_id].is_alive():
                        self._subprocesses[env_id].terminate()
                    self._create_env_subprocess(env_id)
                exceptions.append(e)
                time.sleep(self._retry_waiting_time)
            else:
                return

        logging.error("Env {} reset has exceeded max retries({})".format(env_id, self._max_retry))
        latest = exceptions[-1]
        runtime_error = RuntimeError(
            "Env {} reset has exceeded max retries({}), and the latest exception is: {}".format(
                env_id, self._max_retry, repr(latest)
            )
        )
        runtime_error.__traceback__ = latest.__traceback__
        if self._closed:  # exception caused by main thread closing parent_remote
            return
        self.close()
        raise runtime_error
N
v0.1.0  
niuyazhe 已提交
409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452

    def step(self, actions: Dict[int, Any]) -> Dict[int, namedtuple]:
        """
        Overview:
            Send one action to each requested env and gather the timesteps that are available.
            An env that is done is reset in a background thread when auto reset applies.
        Arguments:
            - actions (:obj:`Dict[int, Any]`): {env_id: action}
        Returns:
            - timesteps (:obj:`Dict[int, namedtuple]`): {env_id: timestep}. Timestep is a \
                ``BaseEnvTimestep`` tuple with observation, reward, done, env_info.
        Example:
            >>>     actions_dict = {env_id: model.forward(obs) for env_id, obs in obs_dict.items()}
            >>>     timesteps = env_manager.step(actions_dict)
            >>>     for env_id, timestep in timesteps.items():
            >>>         pass

        .. note::

            - Envs whose reply is not gathered by this call are recorded as waiting and are listened to \
                again by the next call.
            - Each environment is run by a subprocess separately. Once an environment is done, it is reset \
                immediately.
            - Async subprocess env manager use ``connection.wait`` to poll.
        """
        self._check_closed()
        env_ids = list(actions.keys())
        assert all(self._env_states[env_id] == EnvState.RUN for env_id in env_ids), \
            'current env state are: {}, please check whether the requested env is in reset or done'.format(
                {env_id: self._env_states[env_id] for env_id in env_ids}
            )

        # Dispatch all the actions first, the replies are gathered below.
        for env_id, act in actions.items():
            self._pipe_parents[env_id].send(['step', [act], {}])

        timesteps = {}
        step_args = self._async_args['step']
        wait_num = min(step_args['wait_num'], len(env_ids))
        timeout = step_args['timeout']
        # Listen to the envs stepped right now plus the ones still pending from former calls.
        rest_env_ids = list(set(env_ids).union(self._waiting_env['step']))
        ready_env_ids = []
        cur_rest_env_ids = copy.deepcopy(rest_env_ids)
        while True:
            rest_conn = [self._pipe_parents[env_id] for env_id in cur_rest_env_ids]
            ready_conn, ready_ids = AsyncSubprocessEnvManager.wait(rest_conn, min(wait_num, len(rest_conn)), timeout)
            # ``ready_ids`` are positions inside ``rest_conn``; map them back to env ids
            cur_ready_env_ids = [cur_rest_env_ids[idx] for idx in ready_ids]
            assert len(cur_ready_env_ids) == len(ready_conn)
            for env_id, conn in zip(cur_ready_env_ids, ready_conn):
                try:
                    timesteps[env_id] = conn.recv()
                except pickle.UnpicklingError:
                    # The reply cannot be decoded: report an abnormal timestep and renew the sub-process.
                    timesteps[env_id] = BaseEnvTimestep(None, None, None, {'abnormal': True})
                    self._pipe_parents[env_id].close()
                    if self._subprocesses[env_id].is_alive():
                        self._subprocesses[env_id].terminate()
                    self._create_env_subprocess(env_id)
            self._check_data(timesteps)
            ready_env_ids.extend(cur_ready_env_ids)
            cur_rest_env_ids = list(set(cur_rest_env_ids).difference(set(cur_ready_env_ids)))
            # At least one not done env timestep, or all envs' steps are finished
            everything_received = len(ready_conn) == len(rest_conn)
            if any(not t.done for t in timesteps.values()) or everything_received:
                break

        # Update the set of envs whose reply is still pending.
        waiting = self._waiting_env['step']
        for env_id in rest_env_ids:
            if env_id in ready_env_ids:
                waiting.discard(env_id)
            else:
                waiting.add(env_id)

        if self._shared_memory:
            # the observation is read from the shared buffer of the env
            for env_id in list(timesteps):
                timesteps[env_id] = timesteps[env_id]._replace(obs=self._obs_buffers[env_id].get())

        for env_id, timestep in timesteps.items():
            if is_abnormal_timestep(timestep):
                self._env_states[env_id] = EnvState.ERROR
                continue
            if not timestep.done:
                self._ready_obs[env_id] = timestep.obs
                continue
            self._env_episode_count[env_id] += 1
            if self._env_episode_count[env_id] < self._episode_num and self._auto_reset:
                self._env_states[env_id] = EnvState.RESET
                reset_thread = PropagatingThread(target=self._reset, args=(env_id, ), name='regular_reset')
                reset_thread.daemon = True
                reset_thread.start()
            else:
                self._env_states[env_id] = EnvState.DONE
        return timesteps

    # This method must be staticmethod, otherwise there will be some resource conflicts(e.g. port or file)
    # Env must be created in worker, which is a trick of avoiding env pickle errors.
    # A more robust version is used by default. But this one is also preserved.
    @staticmethod
    def worker_fn(
            p: connection.Connection,
            c: connection.Connection,
            env_fn_wrapper: 'CloudPickleWrapper',
            obs_buffer: Optional['ShmBufferContainer'],
            method_name_list: list,
            reset_timeout=None,
            step_timeout=None,
    ) -> None:  # noqa
        """
        Overview:
            Subprocess's target function to run: create the env, then serve the commands received on
            the child end of the pipe until a ``'close'`` command arrives or the pipe is closed.
        Arguments:
            - p (:obj:`connection.Connection`): Parent end of the pipe; closed right away in the worker.
            - c (:obj:`connection.Connection`): Child end of the pipe used to receive commands and send replies.
            - env_fn_wrapper (:obj:`CloudPickleWrapper`): Wrapper whose ``data`` is the env creating function.
            - obs_buffer (:obj:`Optional[ShmBufferContainer]`): If not ``None``, observations are written \
                into it and replaced by ``None`` in the reply.
            - method_name_list (:obj:`list`): Names of the env methods that may be called remotely.
            - reset_timeout: Accepted only so that this function can be called with the same arguments as \
                ``worker_fn_robust``; it is ignored here.
            - step_timeout: Same as ``reset_timeout``: accepted and ignored.
        """
        torch.set_num_threads(1)
        env_fn = env_fn_wrapper.data
        env = env_fn()
        p.close()
        try:
            while True:
                try:
                    cmd, args, kwargs = c.recv()
                except EOFError:  # for the case when the pipe has been closed
                    c.close()
                    break
                try:
                    if cmd == 'getattr':
                        ret = getattr(env, args[0])
                    elif cmd in method_name_list:
                        if cmd == 'step':
                            timestep = env.step(*args, **kwargs)
                            # an abnormal timestep is sent back untouched
                            if not is_abnormal_timestep(timestep) and obs_buffer is not None:
                                obs_buffer.fill(timestep.obs)
                                timestep = timestep._replace(obs=None)
                            ret = timestep
                        elif cmd == 'reset':
                            ret = env.reset(*args, **kwargs)  # obs
                            if obs_buffer is not None:
                                obs_buffer.fill(ret)
                                ret = None
                        elif args is None and kwargs is None:
                            ret = getattr(env, cmd)()
                        else:
                            ret = getattr(env, cmd)(*args, **kwargs)
                    else:
                        raise KeyError("not support env cmd: {}".format(cmd))
                    c.send(ret)
                except Exception as e:
                    # when there are some errors in env, worker_fn will send the errors to env manager
                    # directly send error to another process will lose the stack trace, so we create a new Exception
                    c.send(
                        e.__class__(
                            '\nEnv Process Exception:\n' + ''.join(traceback.format_tb(e.__traceback__)) + repr(e)
                        )
                    )
                if cmd == 'close':
                    c.close()
                    break
        except KeyboardInterrupt:
            c.close()

    @staticmethod
    def worker_fn_robust(
            parent,
            child,
            env_fn_wrapper,
            obs_buffer,
            method_name_list,
            reset_timeout=None,
            step_timeout=None,
    ) -> None:
        """
        Overview:
            A more robust version of subprocess's target function to run. Used by default.
            ``step`` and ``reset`` are wrapped by ``timeout_wrapper``, and any exception raised while a
            command is served is sent back through the pipe instead of ending the worker.
        Arguments:
            - parent: Parent end of the pipe; closed right away in the worker.
            - child: Child end of the pipe used to receive commands and send replies.
            - env_fn_wrapper: Wrapper whose ``data`` is the env creating function.
            - obs_buffer: If not ``None``, observations are written into it and replaced by ``None`` in the reply.
            - method_name_list: Names of the env methods that may be called remotely.
            - reset_timeout: ``timeout`` argument given to ``timeout_wrapper`` for ``reset``.
            - step_timeout: ``timeout`` argument given to ``timeout_wrapper`` for ``step``.
        """
        torch.set_num_threads(1)
        env = env_fn_wrapper.data()
        parent.close()
        use_buffer = obs_buffer is not None

        @timeout_wrapper(timeout=step_timeout)
        def step_fn(*args, **kwargs):
            timestep = env.step(*args, **kwargs)
            # an abnormal timestep is sent back untouched
            if not is_abnormal_timestep(timestep) and use_buffer:
                obs_buffer.fill(timestep.obs)
                return timestep._replace(obs=None)
            return timestep

        @timeout_wrapper(timeout=reset_timeout)
        def reset_fn(*args, **kwargs):
            try:
                obs = env.reset(*args, **kwargs)
                if use_buffer:
                    obs_buffer.fill(obs)
                    return None
                return obs
            except BaseException as e:
                # close the env before reporting a failed reset
                env.close()
                raise e

        while True:
            try:
                cmd, args, kwargs = child.recv()
            except EOFError:  # for the case when the pipe has been closed
                child.close()
                break
            try:
                if cmd == 'getattr':
                    ret = getattr(env, args[0])
                elif cmd not in method_name_list:
                    raise KeyError("not support env cmd: {}".format(cmd))
                elif cmd == 'step':
                    ret = step_fn(*args, **kwargs)
                elif cmd == 'reset':
                    ret = reset_fn(*args, **kwargs)
                elif args is None and kwargs is None:
                    ret = getattr(env, cmd)()
                else:
                    ret = getattr(env, cmd)(*args, **kwargs)
                child.send(ret)
            except BaseException as e:
                logging.debug("Sub env '{}' error when executing {}".format(str(env), cmd))
                # when there are some errors in env, worker_fn will send the errors to env manager
                # directly send error to another process will lose the stack trace, so we create a new Exception
                report = '\nEnv Process Exception:\n' + ''.join(traceback.format_tb(e.__traceback__)) + repr(e)
                child.send(e.__class__(report))
            if cmd == 'close':
                child.close()
                break

    def _check_data(self, data: Dict, close: bool = True) -> None:
        """
        Overview:
            Look for exceptions among the replies received from the sub-processes. Every env that sent
            one is marked as ``ERROR``, then the first exception found is raised.
        Arguments:
            - data (:obj:`Dict`): {env_id: reply received from the sub-process}
            - close (:obj:`bool`): Whether to close the env manager before raising.
        """
        failures = []
        for env_id, reply in data.items():
            if not isinstance(reply, BaseException):
                continue
            self._env_states[env_id] = EnvState.ERROR
            failures.append(reply)
        if not failures:
            return
        # when receiving env Exception, env manager will safely close and raise this Exception to caller
        if close:
            self.close()
        raise failures[0]

    # override
    def __getattr__(self, key: str) -> Any:
        """
        Overview:
            Fetch the attribute ``key`` from every sub-process env.
        Arguments:
            - key (:obj:`str`): Name of the attribute.
        Returns:
            - ret (:obj:`List[Any]`): The replies of all the envs, in the order of ``self._pipe_parents``.
        Raises:
            - AttributeError: If the reference env does not have the attribute.
            - RuntimeError: If the attribute is a method that is not listed in ``method_name_list``.
        """
        self._check_closed()
        # we suppose that all the envs has the same attributes, if you need different envs, please
        # create different env managers.
        if not hasattr(self._env_ref, key):
            raise AttributeError("env `{}` doesn't have the attribute `{}`".format(type(self._env_ref), key))
        is_method = isinstance(getattr(self._env_ref, key), MethodType)
        if is_method and key not in self.method_name_list:
            raise RuntimeError("env getattr doesn't supports method({}), please override method_name_list".format(key))
        pipes = self._pipe_parents
        # send all the requests first, then collect the replies
        for pipe in pipes.values():
            pipe.send(['getattr', [key], {}])
        data = {}
        for env_id, pipe in pipes.items():
            data[env_id] = pipe.recv()
        self._check_data(data)
        return [data[env_id] for env_id in pipes.keys()]

    # override
    def enable_save_replay(self, replay_path: Union[List[str], str]) -> None:
        """
        Overview:
            Store the replay save path of each env.
        Arguments:
            - replay_path (:obj:`Union[List[str], str]`): One path per environment, or a single path that \
                is repeated for all the environments.
        """
        self._env_replay_path = [replay_path] * self.env_num if isinstance(replay_path, str) else replay_path

    # override
    def close(self) -> None:
        """
        Overview:
            Close the env manager and release all related resources: close the reference env, ask every
            sub-process to close, then terminate the sub-processes and close the pipes. Calling it on an
            already closed manager does nothing.
        """
        if self._closed:
            return
        self._closed = True
        self._env_ref.close()
        # A sub-process may already be gone. Its pipe must not abort the shutdown, otherwise the
        # remaining processes would never be terminated (``_closed`` is already set).
        notified = []
        for env_id, p in self._pipe_parents.items():
            try:
                p.send(['close', None, None])
            except OSError as e:  # includes BrokenPipeError and an already closed handle
                logging.warning("Env {} close command can not be sent: {}".format(env_id, repr(e)))
            else:
                notified.append((env_id, p))
        for env_id, p in notified:
            try:
                # wait at most 5 seconds for the reply of each env
                if p.poll(5):
                    p.recv()
            except (OSError, EOFError, pickle.UnpicklingError) as e:
                logging.warning("Env {} close reply can not be received: {}".format(env_id, repr(e)))
        for i in range(self._env_num):
            self._env_states[i] = EnvState.VOID
        # disable process join for avoiding hang
        for _, p in self._subprocesses.items():
            p.terminate()
        for _, p in self._pipe_parents.items():
            p.close()

    @staticmethod
    def wait(rest_conn: list, wait_num: int, timeout: Optional[float] = None) -> Tuple[list, list]:
        """
        Overview:
            Wait until connections are ready to be read.
            The loop only stops early when ``timeout`` is truthy, at least ``wait_num`` connections are
            ready and more than ``timeout`` seconds have elapsed; otherwise it goes on until every
            connection is ready (sync mode when ``timeout`` is None).
        Arguments:
            - rest_conn (:obj:`list`): The connections to wait for.
            - wait_num (:obj:`int`): Minimum number of ready connections, between 1 and ``len(rest_conn)``.
            - timeout (:obj:`Optional[float]`): Timeout in seconds, also given to each ``connection.wait`` call.
        Returns:
            - ready (:obj:`Tuple[list, list]`): The ready connections and their positions in ``rest_conn``.
        """
        assert 1 <= wait_num <= len(rest_conn), \
            'please indicate proper wait_num: <wait_num: {}, rest_conn_num: {}>'.format(wait_num, len(rest_conn))
        pending = set(rest_conn)
        ready_conn = set()
        begin = time.time()
        while len(pending) > 0:
            enough_ready = len(ready_conn) >= wait_num
            if enough_ready and timeout and (time.time() - begin) >= timeout:
                break
            finished = set(connection.wait(pending, timeout=timeout))
            ready_conn = ready_conn.union(finished)
            pending = pending.difference(finished)
        ready_list = list(ready_conn)
        ready_ids = []
        for conn in ready_conn:
            ready_ids.append(rest_conn.index(conn))
        return ready_list, ready_ids


@ENV_MANAGER_REGISTRY.register('subprocess')
class SyncSubprocessEnvManager(AsyncSubprocessEnvManager):
    """
    Overview:
        Subprocess env manager registered under the name ``'subprocess'``. It inherits from
        ``AsyncSubprocessEnvManager`` and overrides ``step`` so that the result of every requested
        environment is received (blocking) before the call returns.
    """
    config = dict(
        episode_num=float("inf"),
        max_retry=5,
        step_timeout=None,
        auto_reset=True,
        reset_timeout=None,
        retry_type='reset',
        retry_waiting_time=0.1,
        # subprocess specified args
        shared_memory=True,
        context='spawn' if platform.system().lower() == 'windows' else 'fork',
        wait_num=float("inf"),  # inf mean all the environments
        step_wait_timeout=None,
        connect_timeout=60,
        force_reproducibility=False,
    )

    def __init__(
            self,
            env_fn: List[Callable],
            cfg: EasyDict = EasyDict({}),
    ) -> None:
        """
        Overview:
            Initialize through the parent class, then cache the ``force_reproducibility`` option.
        Arguments:
            - env_fn (:obj:`List[Callable]`): Callables forwarded to the parent constructor.
            - cfg (:obj:`EasyDict`): Manager config forwarded to the parent constructor.
        """
        # NOTE(review): the default ``EasyDict({})`` is created once at definition time and shared by all
        # calls that omit ``cfg``; kept unchanged here to preserve the signature - confirm it is never mutated.
        super(SyncSubprocessEnvManager, self).__init__(env_fn, cfg)
        self._force_reproducibility = self._cfg.force_reproducibility

    def step(self, actions: Dict[int, Any]) -> Dict[int, namedtuple]:
        """
        Overview:
            Send one action to each requested environment, then block until every one of them has \
                answered. Environments that finish an episode may be reset in a background thread.
        Arguments:
            - actions (:obj:`Dict[int, Any]`): {env_id: action}
        Returns:
            - timesteps (:obj:`Dict[int, namedtuple]`): {env_id: timestep}. Timestep is a \
                ``BaseEnvTimestep`` tuple with observation, reward, done, env_info.
        Example:
            >>>     actions_dict = {env_id: model.forward(obs) for env_id, obs in obs_dict.items()}
            >>>     timesteps = env_manager.step(actions_dict)
            >>>     for env_id, timestep in timesteps.items():
            >>>         pass

        .. note::

            - The env_id that appears in ``actions`` will also be returned in ``timesteps``.
            - Each environment is run by a subprocess separately.
            - A done environment is reset only when ``auto_reset`` is enabled and its episode count is \
                still below ``episode_num``; otherwise its state becomes ``EnvState.DONE``.
            - The reset runs in a daemon thread; ``step`` waits for it only when \
                ``force_reproducibility`` is enabled.
            - If a reply cannot be unpickled, the env gets a timestep whose info is ``{'abnormal': True}``, \
                its pipe is closed, its subprocess is terminated if still alive and then re-created, and \
                its state is set to ``EnvState.ERROR``.
        """
        self._check_closed()
        env_ids = list(actions.keys())
        assert all(self._env_states[env_id] == EnvState.RUN for env_id in env_ids), (
            'current env state are: {}, please check whether the requested env is in reset or done'.format(
                {env_id: self._env_states[env_id]
                 for env_id in env_ids}
            )
        )
        for env_id, act in actions.items():
            self._pipe_parents[env_id].send(['step', [act], {}])

        # ===     This part is different from async one.     ===
        # === Because operate in this way is more efficient. ===
        timesteps = {}
        for env_id in env_ids:
            try:
                timesteps[env_id] = self._pipe_parents[env_id].recv()
            except pickle.UnpicklingError:
                # The reply is unusable: report an abnormal timestep and replace the subprocess.
                timesteps[env_id] = BaseEnvTimestep(None, None, None, {'abnormal': True})
                self._pipe_parents[env_id].close()
                if self._subprocesses[env_id].is_alive():
                    self._subprocesses[env_id].terminate()
                self._create_env_subprocess(env_id)
        self._check_data(timesteps)
        # ======================================================

        if self._shared_memory:
            # Only values of existing keys are replaced, so iterating ``items()`` here is safe.
            for env_id, timestep in timesteps.items():
                timesteps[env_id] = timestep._replace(obs=self._obs_buffers[env_id].get())
        for env_id, timestep in timesteps.items():
            if is_abnormal_timestep(timestep):
                self._env_states[env_id] = EnvState.ERROR
                continue
            if timestep.done:
                self._env_episode_count[env_id] += 1
                if self._env_episode_count[env_id] < self._episode_num and self._auto_reset:
                    self._env_states[env_id] = EnvState.RESET
                    reset_thread = PropagatingThread(target=self._reset, args=(env_id, ), name='regular_reset')
                    reset_thread.daemon = True
                    reset_thread.start()
                    if self._force_reproducibility:
                        # Wait for the reset so that envs are reset one after another.
                        reset_thread.join()
                else:
                    self._env_states[env_id] = EnvState.DONE
            else:
                self._ready_obs[env_id] = timestep.obs
        return timesteps