util.py 4.4 KB
Newer Older
1 2 3 4 5 6 7 8 9
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
10
import socket
11
from typing import Callable, List, Optional
12 13 14 15 16 17 18 19 20 21

import megengine._internal as mgb

from ..core import set_default_device

# Module-level state recorded by init_process_group() and read by the getters
# below. The defaults describe a process that has not joined a group yet.
_master_ip: Optional[str] = None  # IP address of the master node
_master_port: int = 0  # master port; on rank 0 the value returned by the server creation call
_world_size: int = 0  # total number of processes; is_distributed() requires > 1
_rank: int = 0  # rank of the current process
_backend: Optional[str] = None  # communicator backend name, e.g. "nccl"
_group_id: int = 0  # last id handed out by get_group_id(); reset by init_process_group()


def init_process_group(
    master_ip: str,
    master_port: int,
    world_size: int,
    rank: int,
    dev: int,
    backend: Optional[str] = "nccl",
) -> None:
    """Initialize the distributed process group, and also specify the device used in the current process.

    All arguments are validated before any module-level state or the default
    device is modified, so a call that raises ``TypeError`` or ``ValueError``
    leaves the previous configuration untouched.

    :param master_ip: IP address of the master node.
    :param master_port: Port available for all processes to communicate. On rank 0
        the stored port is replaced by the value returned when the server is
        created. On any other rank it must be positive.
    :param world_size: Total number of processes participating in the job.
    :param rank: Rank of the current process.
    :param dev: The GPU device id to bind this process to.
    :param backend: Communicator backend, currently support 'nccl' and 'ucx'.
        Must be a str; ``None`` is rejected with ``TypeError``.
    :raises TypeError: if master_ip, master_port, world_size, rank or backend
        has the wrong type.
    :raises ValueError: if rank is non-zero and master_port is not positive.
    :raises Exception: if rank 0 fails to start the server.
    """
    global _master_ip  # pylint: disable=global-statement
    global _master_port  # pylint: disable=global-statement
    global _world_size  # pylint: disable=global-statement
    global _rank  # pylint: disable=global-statement
    global _backend  # pylint: disable=global-statement
    global _group_id  # pylint: disable=global-statement

    if not isinstance(master_ip, str):
        raise TypeError("Expect type str but got {}".format(type(master_ip)))
    if not isinstance(master_port, int):
        raise TypeError("Expect type int but got {}".format(type(master_port)))
    if not isinstance(world_size, int):
        raise TypeError("Expect type int but got {}".format(type(world_size)))
    if not isinstance(rank, int):
        raise TypeError("Expect type int but got {}".format(type(rank)))
    if not isinstance(backend, str):
        raise TypeError("Expect type str but got {}".format(type(backend)))
    # An explicit raise instead of `assert`: asserts are stripped under `python -O`.
    # Checked up front so a bad call does not leave globals half-initialized.
    if rank != 0 and master_port <= 0:
        raise ValueError("master_port must be specified for non-zero rank")

    _master_ip = master_ip
    _master_port = master_port
    _world_size = world_size
    _rank = rank
    _backend = backend
    # Restart group-id allocation for the newly initialized group.
    _group_id = 0

    set_default_device(mgb.comp_node("gpu" + str(dev)))

    if rank == 0:
        # Rank 0 hosts the server on all IPv4 interfaces; keep the port it reports.
        _master_port = mgb.config.create_mm_server("0.0.0.0", master_port)
        if _master_port == -1:
            raise Exception("Failed to start server on port {}".format(master_port))


def is_distributed() -> bool:
    """Report whether this process belongs to a group of more than one process.

    Returns False while the recorded world size is None or at most 1.
    """
    if _world_size is None:
        return False
    return _world_size > 1


def get_master_ip() -> str:
    """Return the recorded master-node IP address as a string.

    The value is always converted to str, so an unset address yields "None".
    """
    ip = _master_ip
    return "{}".format(ip)


def get_master_port() -> int:
    """Return the recorded port of the server on the master node.

    On rank 0 this is the value returned when the server was created during
    init_process_group(), not necessarily the port originally requested.
    """
    port = _master_port
    return port


def get_world_size() -> int:
    """Return how many processes take part in the job (0 until initialized)."""
    size = _world_size
    return size


def get_rank() -> int:
    """Return this process's rank within the group (0 until initialized)."""
    current = _rank
    return current


def get_backend() -> str:
    """Return the communicator backend name as a string ("None" if unset)."""
    name = _backend
    return "%s" % (name,)


def get_group_id() -> int:
    """Allocate and return a new group id for collective communication.

    Each call increments a module-level counter and returns the new value, so the
    first call after the counter is reset to 0 (init_process_group does this)
    returns 1.

    NOTE(review): ids are only consistent across ranks if every rank calls this
    in the same order — confirm against callers.
    """
    global _group_id  # pylint: disable=global-statement
    _group_id += 1
    return _group_id


def group_barrier() -> None:
    """Block until all ranks in the group reach this barrier.

    Passes the master address, master port, world size and rank recorded by
    init_process_group() to the backend barrier, so it is presumably only
    meaningful after that call has run.
    """
    mgb.config.group_barrier(_master_ip, _master_port, _world_size, _rank)


def synchronized(func: Callable):
    """Decorator. Decorated function will synchronize when finished.

    Specifically, we use this to prevent data race during hub.load.

    If the process group is distributed at call time (is_distributed()), each
    call of the wrapped function is followed by group_barrier(). Otherwise the
    function is simply called. The wrapped function's return value is passed
    through unchanged.

    Note: if ``func`` raises, the barrier is not reached on this rank.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Decide before calling func so that state changes made by func itself
        # cannot affect whether this call synchronizes.
        distributed = is_distributed()
        ret = func(*args, **kwargs)
        if distributed:
            group_barrier()
        return ret

    return wrapper


def get_free_ports(num: int) -> List[int]:
    """Get one or more free ports.

    Binds ``num`` TCP sockets to port 0 so the OS assigns a port to each, records
    the assigned ports, and then closes every socket. The ports are only
    free at the time of the call; another process may grab them before they are
    reused.

    :param num: Number of ports to return; a value <= 0 yields an empty list.
    :return: The assigned port numbers.
    """
    socks = []
    try:
        for _ in range(num):
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            # Track the socket before bind() so it is closed even if bind fails.
            socks.append(sock)
            sock.bind(("", 0))
        # All sockets stay open until every port is read, so the OS cannot
        # hand out the same port twice within one call.
        return [sock.getsockname()[1] for sock in socks]
    finally:
        for sock in socks:
            sock.close()