launcher.py
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import multiprocessing as mp
import os
import queue

from .. import _exit
from ..core._imperative_rt.core2 import full_sync
from ..device import get_device_count
from ..logger import get_logger
from .group import _set_machine_ranks, group_barrier, init_process_group
from .helper import _check_device_initialized
from .server import Client, Server

WARN_SUBPROCESS_EXIT_WITHOUT_RETURN = (
    "subprocess exited with code 0 but did not return a value"
)


def _run_wrapped(
    func,
    is_multimachine,
    master_ip,
    port,
    world_size,
    rank,
    dev,
    device_type,
    args,
    kwargs,
    backend,
    queue: mp.Queue,
    machine_ranks: list,
):
    """Init distributed process group and run wrapped function."""
    _check_device_initialized(device_type, dev)
    init_process_group(
        master_ip=master_ip,
        port=port,
        world_size=world_size,
        rank=rank,
        device=dev,
        backend=backend,
        device_type=device_type,
    )
    # set NCCL_LAUNCH_MODE to avoid deadlock
    os.environ["NCCL_LAUNCH_MODE"] = "PARALLEL"
    _set_machine_ranks(machine_ranks)
    if is_multimachine:
        group_barrier()
    ret = func(*args, **kwargs)
    queue.put((dev, ret))
    full_sync()
    if is_multimachine:
        group_barrier()
    _exit(0)


class launcher:
    """Decorator for launching multiple processes in single-machine multi-gpu training.
68

69 70 71 72 73 74
    :param func: the function you want to launch in distributed mode.
    :param n_gpus: how many devices each node.
    :param world_size: how many devices totally.
    :param rank_start: start number for rank.
    :param master_ip: ip address for master node (where the rank 0 is).
    :param port: server port for distributed server.
75
    :param backend: set default collective communication backend.
76
    """
77

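    # ``__new__`` lets the class act both as a bare decorator (``@launcher``)
    # and as a parameterized one (``@launcher(n_gpus=...)``): when called with
    # keyword arguments only, it returns a partial that is applied to the
    # wrapped function afterwards.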
    def __new__(cls, *args, **kwargs):
        if not args:
            return functools.partial(cls, **kwargs)
        return super().__new__(cls)

    def __init__(
        self,
        func,
        n_gpus=None,
        world_size=None,
        rank_start=0,
        master_ip="localhost",
        port=0,
        device_type="xpu",
        backend="auto",
    ):
        self.func = func
        self.n_gpus = n_gpus if n_gpus is not None else get_device_count(device_type)
        self.world_size = world_size if world_size is not None else self.n_gpus
        self.rank_start = rank_start
        self.master_ip = master_ip
        self.port = port
        self.device_type = device_type
        self.backend = backend
        # master node creates the server
        if self.rank_start == 0:
            self.server = Server(self.port)
            self.port = self.server.py_server_port
        else:
            assert self.port != 0, "you have to assign a port for distributed server"

    def __call__(self, *args, **kwargs):
        procs = []
        queue = mp.Queue(self.n_gpus)
        results = [None] * self.n_gpus
        machine_ranks = [i + self.rank_start for i in range(self.n_gpus)]
        for dev in range(self.n_gpus):
            p = mp.Process(
                target=_run_wrapped,
                args=(
                    self.func,
                    self.world_size > self.n_gpus,
                    self.master_ip,
                    self.port,
                    self.world_size,
                    dev + self.rank_start,
                    dev,
                    self.device_type,
                    args,
                    kwargs,
                    self.backend,
                    queue,
                    machine_ranks,
                ),
            )
            p.start()
            procs.append(p)

        devs = list(range(self.n_gpus))

        def terminate():
            for dev in devs:
                procs[dev].terminate()
            devs.clear()

        result_count = 0
        while len(devs) > 0:
            left = []
            # check all remaining processes within about one second
            time_to_wait = 1.0 / len(devs)
            for dev in devs:
                procs[dev].join(time_to_wait)
                code = procs[dev].exitcode
                # terminate processes if one of them has failed
                if code != 0 and code is not None:
                    terminate()
                assert (
                    code == 0 or code is None
                ), "subprocess {} exited with code {}".format(dev + self.rank_start, code)
                if code is None:
                    left.append(dev)

                # DO NOT delete this: multiprocessing.Queue has a small buffer,
                # so fetch results early to avoid deadlock
                if not queue.empty():
                    result_count += 1
                    dev, ret = queue.get_nowait()
                    results[dev] = ret
            devs = left

        while not queue.empty():
            result_count += 1
            dev, ret = queue.get_nowait()
            results[dev] = ret

        if result_count < self.n_gpus:
            get_logger().warning(WARN_SUBPROCESS_EXIT_WITHOUT_RETURN)

        return results
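

# Illustrative usage sketch (an assumption-laden example, not taken from this
# file): ``launcher`` is typically applied as a decorator to a per-process
# worker function; calling the decorated function spawns one subprocess per
# device and returns the list of worker return values indexed by local device
# id. It assumes ``launcher`` and ``get_rank`` are exported as
# ``megengine.distributed.launcher`` / ``megengine.distributed.get_rank``; the
# ``worker`` function and its inputs are hypothetical.
#
#     import megengine.distributed as dist
#
#     @dist.launcher(n_gpus=2)
#     def worker(base):
#         return base + dist.get_rank()
#
#     results = worker(10)   # expected to gather e.g. [10, 11]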