spawn.py 24.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function, division

import multiprocessing
import os
import signal
import six
import sys
import warnings

R
Roc 已提交
24
from paddle.distributed.utils.launch_utils import _print_arguments, _prepare_trainer_env, get_host_name_ip
X
xiongkun 已提交
25 26
from paddle.distributed.cloud_utils import get_cluster_and_pod, _get_trainers_num
from paddle.distributed.fleet.launch import get_cluster_from_args
27
from paddle.distributed.fleet.cloud_utils import use_paddlecloud
X
xiongkun 已提交
28
from paddle.distributed.fleet.launch_utils import DeviceMode, check_backend, block_windows_and_macos
29 30 31 32
from paddle.device import get_device

# deprecated module import
from paddle.fluid import core
33
from paddle.fluid.framework import _cpu_num, set_flags
34

35 36
# Intentionally empty: `from <this module> import *` exports no names.
__all__ = []

37 38

class ParallelEnvArgs(object):
    """Plain container for the launch settings of a ``spawn`` job.

    Every attribute starts unset (``None``) except ``print_config``;
    ``_get_subprocess_env_list`` fills them in from the ``spawn`` options
    before building the per-process environments.
    """

    def __init__(self):
        # Comma-separated IPs of the Paddle cluster nodes,
        # e.g. "192.168.0.16,192.168.0.17".
        self.cluster_node_ips = None

        # IP of the node this code runs on.
        self.node_ip = None

        # Whether the multi-process job runs on the PaddleCloud platform.
        # Only needs to be set when PaddleCloud is actually used.
        self.use_paddlecloud = None

        # First port assigned to the trainers of a single node.
        self.started_port = None

        # Whether to print the collected configuration.
        self.print_config = True

        # Devices used for training, one card bound to each process
        # (e.g. "0,1,2,3"). When left unset, all cards are used.
        self.selected_devices = None
62 63 64 65 66 67 68 69 70 71 72


def _py_supported_check():
    """Ensure the interpreter is new enough for ``paddle.distributed.spawn``.

    Raises:
        RuntimeError: when running on a Python older than 3.4.
    """
    if sys.version_info < (3, 4):
        raise RuntimeError(
            "Use `paddle.distributed.spawn` to start parallel training "
            "requires python version greater than 3.4, if your python "
            "is lower than this version, please use "
            "`paddle.distributed.launch` instead.")


73
def _options_valid_check(options):
    """Validate the keyword options given to ``paddle.distributed.spawn``.

    Options are checked in iteration order. Supported options pass silently,
    deprecated ones emit a ``DeprecationWarning``, anything else is rejected.

    Args:
        options (dict): the ``**options`` passed to ``spawn``.

    Raises:
        ValueError: on the first option that is neither supported nor
            deprecated.
    """
    # `print_config` is kept as a debug-only option and is not shown to users.
    supported = ('start_method', 'ips', 'gpus', 'xpus', 'mlus', 'print_config',
                 'backend')
    deprecated = ('selected_devices', 'started_port', 'cluster_node_ips',
                  'node_ip', 'use_paddlecloud')
    for option_name in options:
        if option_name in supported:
            continue
        if option_name not in deprecated:
            raise ValueError(
                "The config option (%s) of `paddle.distributed.spawn` is not supported."
                % option_name)
        warnings.warn(
            "The config option (%s) of `paddle.distributed.spawn` is deprecated. "
            "Please use the latest config options stated in the `spawn` API documentation."
            % option_name, DeprecationWarning)


95 96 97 98 99 100
def _get_default_nprocs():
    """Return how many processes ``spawn`` starts when ``nprocs == -1``.

    That is the number of devices of the type reported by ``get_device()``
    (GPU, XPU, MLU) or, on CPU, the number of CPU cores.

    Raises:
        RuntimeError: if the current device type is not supported.
    """
    device = get_device()
    # Order matters: the first device keyword contained in `device` wins.
    # The counters are wrapped so a `core` function is only looked up for the
    # device type that actually matched.
    device_counters = (
        ('gpu', lambda: core.get_cuda_device_count()),
        ('xpu', lambda: core.get_xpu_device_count()),
        ('mlu', lambda: core.get_mlu_device_count()),
        ('cpu', multiprocessing.cpu_count),
    )
    for device_keyword, count_devices in device_counters:
        if device_keyword in device:
            return count_devices()
    raise RuntimeError(
        "`paddle.distributed.spawn` does not support parallel training on device `{}` now."
        .format(device))
X
xiongkun 已提交
109 110 111 112 113 114 115 116


def _get_default_backend():
    """Pick the communication backend matching the current device.

    Returns:
        str: ``'nccl'`` for GPU, ``'bkcl'`` for XPU, ``'cncl'`` for MLU and
        ``'gloo'`` for CPU.

    Raises:
        RuntimeError: if the device reported by ``get_device()`` is none of
            the above.
    """
    device = get_device()
    # Order matters: the first device keyword contained in `device` wins.
    backend_by_device = (
        ('gpu', 'nccl'),
        ('xpu', 'bkcl'),
        ('mlu', 'cncl'),
        ('cpu', 'gloo'),
    )
    for device_keyword, backend in backend_by_device:
        if device_keyword in device:
            return backend
    raise RuntimeError(
        "`paddle.distributed.spawn` does not support parallel training on device `{}` now."
        .format(device))
125 126


127 128 129 130 131 132 133 134
def _get_node_ip(ips):
    """Return the IP address to use for the current node.

    Args:
        ips (str): comma-separated cluster node IPs.

    Returns:
        str: the single listed IP (whitespace-stripped) when exactly one is
        given; otherwise the local IP reported by ``get_host_name_ip()``.
    """
    candidates = [ip.strip() for ip in ips.split(',')]
    if len(candidates) == 1:
        return candidates[0]
    # Several nodes listed: ask the host for its own IP instead.
    _, local_ip = get_host_name_ip()
    return local_ip
135 136


137
def _get_selected_devices(options, nprocs, device_type, env_name,
                          get_device_count):
    """Resolve the comma-separated absolute card ids used by the processes.

    Shared by the GPU (``nccl``), XPU (``bkcl``) and MLU (``cncl``) backends.

    Args:
        options (dict): the ``spawn`` options; ``<device_type>s`` (e.g.
            ``gpus``) is read first, then the deprecated ``selected_devices``.
        nprocs (int): number of processes to spawn.
        device_type (str): ``'gpu'``, ``'xpu'`` or ``'mlu'``; used to build
            the option name and the error messages.
        env_name (str): the visible-devices environment variable, e.g.
            ``CUDA_VISIBLE_DEVICES``.
        get_device_count (callable): zero-argument callable returning the
            device count; only called when `env_name` is unset or empty.

    Returns:
        str: the user-selected devices unchanged, or else the first `nprocs`
        visible devices joined by commas.

    Raises:
        RuntimeError: no devices were selected and fewer than `nprocs`
            devices are visible.
        ValueError: the number of selected devices differs from `nprocs`, or a
            selected card is not among the visible ones.
    """
    option_name = device_type + 's'
    selected_devices = options.get(option_name, None)
    if selected_devices is None:
        selected_devices = options.get('selected_devices', None)

    env_devices = os.getenv(env_name, None)
    if env_devices is None or env_devices == "":
        env_devices_list = [str(x) for x in range(get_device_count())]
    else:
        env_devices_list = env_devices.split(',')

    if selected_devices is None:
        if len(env_devices_list) < nprocs:
            raise RuntimeError(
                "the number of visible devices(%d) is less than the number "
                "of spawn processes(%d), please ensure that the correct "
                "`nprocs` argument is passed or the environment variable "
                "`%s` is correctly configured." %
                (len(env_devices_list), nprocs, env_name))
        return ",".join(
            [str(env_devices_list[x]) for x in range(0, nprocs)])

    selected_device_list = selected_devices.split(',')
    if len(selected_device_list) != nprocs:
        raise ValueError(
            "The number of selected devices(%s) is not equal to "
            "the number of spawn processes(%d), please ensure that the "
            "correct `nprocs` and `%s` arguments are passed." %
            (len(selected_device_list), nprocs, option_name))
    for card_id in selected_device_list:
        if card_id not in env_devices_list:
            raise ValueError("The selected %s card %s cannot found in "
                             "%s (%s)." %
                             (device_type, card_id, env_name,
                              ",".join(env_devices_list)))
    return selected_devices


def _get_subprocess_env_list(nprocs, options):
    """Build the environment-variable dict for every process to be spawned.

    Side effect: ``options['backend']`` is filled in (derived from the
    current device when missing or ``'auto'``); ``spawn`` reads it back
    after this call.

    Args:
        nprocs (int): number of processes to spawn.
        options (dict): the ``**options`` passed to ``spawn``.

    Returns:
        list: one env dict per trainer of the local pod, as produced by
        ``_prepare_trainer_env``.
    """
    # NOTE (xiongkun03) Why put backend deduction here?
    # Because _get_subprocess_env_list is used by many testcases,
    # so for compatibility the backend deduction lives here.
    if 'backend' not in options or options['backend'] == 'auto':
        options['backend'] = _get_default_backend()
    check_backend(options['backend'])
    block_windows_and_macos(options['backend'])
    backend = options['backend']

    args = ParallelEnvArgs()

    # deal with `ips` (falling back to the deprecated `cluster_node_ips`)
    args.cluster_node_ips = options.get('ips', None)
    if args.cluster_node_ips is None:
        args.cluster_node_ips = options.get('cluster_node_ips', None)
        if args.cluster_node_ips is None:
            args.cluster_node_ips = "127.0.0.1"

    # deal with `gpus` / `xpus` / `mlus`; by default, with nprocs == 4 the
    # selected devices are the first 4 visible cards, e.g. "0,1,2,3".
    # NOTE(chenweihang): [ why not use FLAGS_selected_gpus or FLAGS_selected_xpus directly? ]
    # Those FLAGS may be used elsewhere; setting them to e.g. `0,1,2,3` here
    # may cause errors when using `ParallelEnv`.
    # NOTE(chenweihang): absolute card ids are used.
    if backend == 'nccl':
        args.selected_devices = _get_selected_devices(
            options, nprocs, 'gpu', "CUDA_VISIBLE_DEVICES",
            lambda: core.get_cuda_device_count())
    elif backend == 'bkcl':
        args.selected_devices = _get_selected_devices(
            options, nprocs, 'xpu', "XPU_VISIBLE_DEVICES",
            lambda: core.get_xpu_device_count())
    elif backend == 'cncl':
        args.selected_devices = _get_selected_devices(
            options, nprocs, 'mlu', "MLU_VISIBLE_DEVICES",
            lambda: core.get_mlu_device_count())
    elif backend == 'gloo':
        # TODO check gpu / xpu flag must not exist
        warnings.warn(
            "Your model will be trained under CPUONLY mode by using GLOO, "
            "because CPUPlace is specified manually or your installed PaddlePaddle only support CPU Device."
        )
        args.paddle_cpuonly = True
        args.selected_devices = None
        args.ips = args.cluster_node_ips
        assert options.get(
            'use_paddlecloud',
            None) is None, "CPUONLY spawn doesn't support use paddle cloud"
        assert len(args.cluster_node_ips.split(',')) <= 1, (
            "CPUONLY spawn only support single trainer, that is len(ips)=1, "
            "but got %s." % args.cluster_node_ips)
        assert _get_trainers_num(
        ) == 1, "CPUONLY spawn doesn't support multi-trainer"

    # set other inner args
    args.node_ip = options.get('node_ip', None)
    if args.node_ip is None:
        args.node_ip = _get_node_ip(args.cluster_node_ips)

    args.started_port = options.get('started_port', None)

    args.use_paddlecloud = options.get('use_paddlecloud', None)
    if args.use_paddlecloud is None:
        args.use_paddlecloud = use_paddlecloud()

    # get cluster and pod config
    if backend == 'gloo':
        devices_per_proc = list(range(nprocs))
        cluster, pod = get_cluster_from_args(args, DeviceMode.CPU,
                                             devices_per_proc)
    else:
        cluster, pod = get_cluster_and_pod(args)

    # prepare subprocess env list, one entry per local trainer
    processes_env_list = [
        _prepare_trainer_env(cluster, trainer, backend)
        for trainer in pod.trainers
    ]

    # [Debug] print config
    args.print_config = options.get('print_config', False)
    if args.print_config:
        _print_arguments(args)

    return processes_env_list


def _remove_risky_env():
    """Drop the lowercase proxy variables from this process's environment.

    Each process owns its own copy of ``os.environ``, so this only affects
    the calling process. Missing variables are ignored.
    """
    # NOTE(review): presumably removed because proxies can interfere with
    # the trainers' network communication — confirm.
    for proxy_var in ("http_proxy", "https_proxy"):
        os.environ.pop(proxy_var, None)


X
xiongkun 已提交
329
def _set_trainer_env(env_dict, backend):
    """Apply a trainer's environment inside the spawned child process.

    Args:
        env_dict (dict): environment variables for this trainer; for the
            ``nccl``/``bkcl``/``cncl`` backends it must contain the matching
            ``FLAGS_selected_*`` key.
        backend (str): the resolved communication backend.
    """
    # NOTE(chenweihang): [ Why set FLAGS_selected_gpus / FLAGS_selected_xpus here? ]
    # A child process inherits the main process configuration and sets the
    # FLAGS once, before these environment variables exist, so the flag would
    # stay equal to the main process value (usually empty). Update it here.
    # NOTE(xiongkun): the backend decides which flag to set; for gloo no
    # FLAGS_selected_* flag must be set.
    flag_by_backend = {
        'nccl': 'FLAGS_selected_gpus',
        'bkcl': 'FLAGS_selected_xpus',
        'cncl': 'FLAGS_selected_mlus',
    }
    flag_name = flag_by_backend.get(backend)
    if flag_name is not None:
        set_flags({flag_name: env_dict[flag_name]})
    # NOTE(xiongkun): any other backend (CPU parallel, e.g. when paddle is
    # not compiled with CUDA/XPU) is deliberately not an error; nothing to set.

    os.environ.update(env_dict)


X
xiongkun 已提交
355
def _func_wrapper(func, args, error_queue, return_queue, env_dict, backend):
    """Run ``func(*args)`` inside a spawned child process.

    The child's environment is prepared first. The return value goes to
    `return_queue`. On an exception, the formatted traceback goes to
    `error_queue` and the process exits with code 1. A ``KeyboardInterrupt``
    ends the function quietly.
    """
    try:
        _remove_risky_env()
        _set_trainer_env(env_dict, backend)
        return_queue.put(func(*args))
    except KeyboardInterrupt:
        return
    except Exception:
        import traceback
        error_queue.put(traceback.format_exc())
        sys.exit(1)


class MultiprocessContext(object):
    """Handle on the processes started by ``paddle.distributed.spawn``.

    Args:
        processes (list): the started ``multiprocessing`` process objects.
        error_queues (list): one queue per process; a failing child puts its
            formatted traceback there.
        return_queues (list): one queue per process holding the value returned
            by the spawned function.
    """

    def __init__(self, processes, error_queues, return_queues):
        _py_supported_check()
        self.error_queues = error_queues
        # NOTE(chenweihang): `spawn` mainly wraps the outermost function of a
        # parallel program, whose return value is usually ignored. Users who
        # need it can read each process's result from context.return_queues.
        self.return_queues = return_queues
        self.processes = processes
        # sentinel handle -> process index, for processes not yet joined
        self.sentinels = {
            process.sentinel: index
            for index, process in enumerate(processes)
        }

    def join(self, timeout=None):
        """Wait for processes to exit.

        Args:
            timeout (float, optional): seconds to wait for at least one
                process to exit; ``None`` blocks until one does.

        Returns:
            bool: ``True`` once every process has exited successfully,
            ``False`` while some are still running.

        Raises:
            Exception: if a process exited with a non-zero code; all other
                processes are terminated and joined first.
        """
        # `import multiprocessing` alone does not guarantee the
        # `multiprocessing.connection` submodule is loaded.
        import multiprocessing.connection

        if len(self.sentinels) == 0:
            return True

        ready = multiprocessing.connection.wait(self.sentinels.keys(),
                                                timeout=timeout)

        error_index = None
        for sentinel in ready:
            index = self.sentinels.pop(sentinel)
            process = self.processes[index]
            process.join()
            if process.exitcode != 0:
                error_index = index
                break

        if error_index is None:
            return len(self.sentinels) == 0

        # One process failed: bring down the rest before reporting.
        for process in self.processes:
            if process.is_alive():
                process.terminate()
            process.join()

        self._throw_exception(error_index)

    def _throw_exception(self, error_index):
        """Raise an ``Exception`` describing why process `error_index` failed.

        The traceback the child put on its error queue is preferred.
        Otherwise the terminating signal (negative exit code) or the exit
        code is reported.
        """
        if self.error_queues[error_index].empty():
            exitcode = self.processes[error_index].exitcode
            if exitcode < 0:
                try:
                    name = signal.Signals(-exitcode).name
                except ValueError:
                    # Signal number unknown to this platform's `signal`
                    # module: report the raw number instead of crashing here.
                    name = str(-exitcode)
                raise Exception("Process %d terminated with signal %s." %
                                (error_index, name))
            else:
                raise Exception("Process %d terminated with exit code %d." %
                                (error_index, exitcode))

        original_trace = self.error_queues[error_index].get()
        msg = "\n\n----------------------------------------------\n" \
              "Process %d terminated with the following error:\n" \
              "----------------------------------------------\n\n" % error_index
        msg += original_trace
        raise Exception(msg)


def spawn(func, args=(), nprocs=-1, join=True, daemon=False, **options):
    """
    Start multiple processes with ``spawn`` method for parallel training.

    .. note::
        ``spawn`` now only supports GPU or XPU or MLU collective mode, plus
        single-node CPU training with the ``gloo`` backend. The collective
        mode of GPU and XPU and MLU cannot be started at the same time, so the
        options `gpus`, `xpus` and `mlus` cannot be configured at the same time.

    Args:
        func (function): The target function called by each spawned process.
            This function needs to be picklable, so it must be defined
            at the top level of a module.
        args (list|tuple, optional): Arguments passed to ``func``.
        nprocs (int, optional): Number of processes to start. Default: -1.
            When nprocs is -1, the available devices will be obtained from
            the environment when the model is executed: If use GPU,
            the currently available device ID is obtained from the environment
            variable CUDA_VISIBLE_DEVICES; If use XPU, the currently available
            device ID is obtained from the environment variable XPU_VISIBLE_DEVICES;
            If use MLU, the currently available device ID is obtained from the
            environment variable MLU_VISIBLE_DEVICES; If use CPU, the number of
            CPU cores is used.
        join (bool, optional): Perform a blocking join on all spawned processes.
            Default: True.
        daemon (bool, optional): The spawned processes' daemon flag. Default: False.
        **options(dict, optional): Other initial parallel execution environment
            configuration options. The following options are currently supported:
            (1) start_method (string): the way to start a process.
            The start method can be ``spawn`` , ``fork`` , ``forkserver`` .
            Because the CUDA runtime does not support the ``fork`` start method,
            when use CUDA in subprocesses, we should start process by ``spawn``
            or ``forkserver`` method. Default: "spawn" ;
            (2) gpus (string): The training process will run on the
            selected gpus, such as "0,1,2,3". Default: None;
            (3) xpus (string): The training process will run on the
            selected xpus, such as "0,1,2,3". Default: None;
            (4) mlus (string): The training process will run on the
            selected mlus, such as "0,1,2,3". Default: None;
            (5) ips (string): Paddle cluster nodes ips, such as
            "192.168.0.16,192.168.0.17". Default: "127.0.0.1" ;
            (6) backend (string): the communication backend, one of
            ``nccl``, ``bkcl``, ``cncl``, ``gloo`` or ``auto``. When it is not
            set or is ``auto``, the backend is chosen from the current device
            (``nccl`` for GPU, ``bkcl`` for XPU, ``cncl`` for MLU, ``gloo``
            for CPU). Default: None.

    Returns:
        ``MultiprocessContext`` object, it hold the spawned processes.

    Examples:
        .. code-block:: python

            from __future__ import print_function

            import paddle
            import paddle.nn as nn
            import paddle.optimizer as opt
            import paddle.distributed as dist

            class LinearNet(nn.Layer):
                def __init__(self):
                    super(LinearNet, self).__init__()
                    self._linear1 = nn.Linear(10, 10)
                    self._linear2 = nn.Linear(10, 1)

                def forward(self, x):
                    return self._linear2(self._linear1(x))

            def train(print_result=False):
                # 1. initialize parallel environment
                group = dist.init_parallel_env()
                process_group = group.process_group if group else None

                # 2. create data parallel layer & optimizer
                layer = LinearNet()
                dp_layer = paddle.DataParallel(layer, group = process_group)

                loss_fn = nn.MSELoss()
                adam = opt.Adam(
                    learning_rate=0.001, parameters=dp_layer.parameters())

                # 3. run layer
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)

                if print_result is True:
                    print("loss:", loss.numpy())

                loss.backward()

                adam.step()
                adam.clear_grad()

            # Usage 1: only pass function.
            # If your training method no need any argument, and
            # use all visible devices for parallel training.
            if __name__ == '__main__':
                dist.spawn(train)

            # Usage 2: pass function and arguments.
            # If your training method need some arguments, and
            # use all visible devices for parallel training.
            if __name__ == '__main__':
                dist.spawn(train, args=(True,))

            # Usage 3: pass function, arguments and nprocs.
            # If your training method need some arguments, and
            # only use part of visible devices for parallel training.
            # If your machine hold 8 cards {0,1,2,3,4,5,6,7},
            # this case will use cards {0,1}; If you set
            # CUDA_VISIBLE_DEVICES=4,5,6,7, this case will use
            # cards {4,5}
            if __name__ == '__main__':
                dist.spawn(train, args=(True,), nprocs=2)

            # Usage 4: pass function, arguments, nprocs and gpus.
            # If your training method need some arguments, and
            # only use part of visible devices for parallel training,
            # but you can't set your machine's environment variable
            # CUDA_VISIBLE_DEVICES, such as it is None or all cards
            # {0,1,2,3,4,5,6,7}, you can pass `gpus` to
            # select the GPU cards you want to use. For example,
            # this case will use cards {4,5} if your machine hold 8 cards.
            if __name__ == '__main__':
                dist.spawn(train, args=(True,), nprocs=2, gpus='4,5')
    """
    # NOTE(chenweihang): [ why only supports python3.4+ ? ]
    # Python supports setting the child process start method since 3.4.
    # Earlier versions can only use the default start method, which on Unix
    # is fork, and fork cannot support the CUDA runtime in multiple processes.
    _py_supported_check()

    # Give an error hint when the users enter a configuration option
    # that does not exist
    _options_valid_check(options)

    # get default nprocs
    if nprocs == -1:
        nprocs = _get_default_nprocs()

    # NOTE(chenweihang): [ why need get cluster info before run? ]
    # when using `paddle.distributed.spawn` to start parallel training,
    # we should get cluster info before starting subprocesses, and pass the
    # correct info to each subprocess.
    # This call also fills in options['backend'], which is read below.
    procs_env_list = _get_subprocess_env_list(nprocs, options)

    # start processes
    # NOTE(chenweihang): [ why default start method is spawn? ]
    # The CUDA runtime does not support the fork start method,
    # either the spawn or forkserver start method are required
    # to use CUDA in subprocesses.
    start_method = options.get('start_method', None)
    if start_method is None:
        start_method = 'spawn'
    mp = multiprocessing.get_context(start_method)

    error_queues = []
    return_queues = []
    processes = []
    try:
        for i in range(nprocs):
            error_queue = mp.SimpleQueue()
            return_queue = mp.SimpleQueue()
            process = mp.Process(target=_func_wrapper,
                                 args=(func, args, error_queue, return_queue,
                                       procs_env_list[i], options['backend']))
            process.daemon = daemon
            process.start()
            error_queues.append(error_queue)
            return_queues.append(return_queue)
            processes.append(process)
    except BaseException:
        # If a later process fails to start (or the user interrupts), the
        # processes already started would otherwise be orphaned, likely
        # waiting forever for their missing peers. Clean them up, then
        # re-raise the original error.
        for started in processes:
            if started.is_alive():
                started.terminate()
            started.join()
        raise

    context = MultiprocessContext(processes, error_queues, return_queues)
    if not join:
        return context

    # loop until all process end
    while not context.join():
        pass

    # finally return context
    return context