# distribute.py: set up the XLA distributed runtime (coordination service and client) for multi-process runs.
import atexit
import warnings
from typing import Any, Optional, Sequence, Union

from .lib import xla_client as xc

# ``xc._xla`` exposed under two names: ``xla_extention`` (spelling kept as-is
# because it is a module-level name other code may import) and the short
# alias ``xe`` used by the runtime helpers in this module.
xla_extention = xe = xc._xla


class State:
    """Per-process handles for the XLA distributed runtime.

    Holds the coordination service (created only on process 0), the client
    every process uses to connect to it, and the preemption sync manager
    initialized from that client.
    """

    process_id: int = 0
    ip: Optional[str] = None
    port: Optional[int] = None
    service: Optional[Any] = None
    client: Optional[Any] = None
    preemption_sync_manager: Optional[Any] = None
    # Comma-separated device ids recorded by ``initialize``; "all" until then.
    visible_devices: Optional[str] = "all"

    def initialize(
        self,
        ip: str,
        port: int,
        num_processes: int,
        process_id: int,
        local_device_ids: Optional[Union[int, Sequence[int]]] = None,
    ):
        """Start (on process 0) and connect to the distributed runtime service.

        Args:
            ip: Coordinator host address.
            port: Coordinator port.
            num_processes: Total number of participating processes; passed to
                the service created on process 0.
            process_id: Index of this process; process 0 also hosts the service.
            local_device_ids: Device id(s) for this process. ``None`` means
                ``[process_id]``; currently it must resolve to exactly
                ``[process_id]``.

        Raises:
            AssertionError: If ``local_device_ids`` does not resolve to
                ``[process_id]``.
            RuntimeError: If this state already has a service (process 0) or a
                client, i.e. it was already initialized.
        """
        coordinator_address = ip + ":" + str(port)

        # Normalize the accepted forms (None / int / sequence) to a list.
        if local_device_ids is None:
            local_device_ids = [process_id]
        elif isinstance(local_device_ids, int):
            local_device_ids = [local_device_ids]
        else:
            local_device_ids = list(local_device_ids)

        # Explicit check instead of ``assert`` so it still runs under
        # ``python -O``; AssertionError is kept for existing callers.
        if local_device_ids != [process_id]:
            raise AssertionError(f"{local_device_ids} .vs {process_id}")

        # Reject re-initialization *before* mutating anything, so a failed
        # second call neither overwrites ip/port/process_id nor starts a new
        # service that would never be shut down.
        if process_id == 0 and self.service is not None:
            raise RuntimeError("distributed.initialize should only be called once.")
        if self.client is not None:
            raise RuntimeError("distributed.initialize should only be called once.")

        self.ip = ip
        self.port = port
        self.visible_devices = ",".join(str(x) for x in local_device_ids)
        self.process_id = process_id

        if process_id == 0:
            self.service = xe.get_distributed_runtime_service(
                coordinator_address, num_processes, use_coordination_service=True
            )

        # Set init_timeout to 5 min to leave time for all the processes to connect
        self.client = xe.get_distributed_runtime_client(
            coordinator_address,
            process_id,
            use_coordination_service=True,
            init_timeout=300,
        )
        self.client.connect()
        self.initialize_preemption_sync_manager()

    def shutdown(self):
        """Shut down the client, then the service, and drop the preemption manager.

        Each handle is reset to ``None`` after use, so calling this more than
        once is a no-op.
        """
        if self.client:
            self.client.shutdown()
            self.client = None
        if self.service:
            self.service.shutdown()
            self.service = None
        if self.preemption_sync_manager:
            self.preemption_sync_manager = None

    def initialize_preemption_sync_manager(self):
        """Create the preemption sync manager and initialize it with ``self.client``.

        Raises:
            RuntimeError: If a preemption sync manager already exists.
        """
        if self.preemption_sync_manager is not None:
            raise RuntimeError(
                "Preemption sync manager should only be initialized once."
            )
        self.preemption_sync_manager = xe.create_preemption_sync_manager()
        self.preemption_sync_manager.initialize(self.client)


# Process-wide State instance; the module-level initialize() and shutdown()
# functions in this file read and mutate it.
global_state: State = State()


def initialize(
    ip: str,
    port: int,
    num_processes: int,
    process_id: int,
    local_device_ids: Optional[Union[int, Sequence[int]]] = None,
):
    """Initialize the process-wide distributed runtime held in ``global_state``.

    The first call initializes ``global_state`` and registers ``shutdown`` to
    run at interpreter exit. Any later call leaves the existing runtime in
    place. It returns silently when it targets the same ``ip:port``. It emits a
    ``RuntimeWarning`` when it targets a different address. In both cases
    ``num_processes``, ``process_id`` and ``local_device_ids`` are ignored.

    Args:
        ip: Coordinator host; ``"localhost"`` is normalized to ``"127.0.0.1"``.
        port: Coordinator port.
        num_processes: Total number of participating processes.
        process_id: Index of this process.
        local_device_ids: Forwarded unchanged to ``global_state.initialize``.
    """
    # Normalize so "localhost" and "127.0.0.1" compare equal against the
    # address stored on global_state.
    ip = "127.0.0.1" if ip == "localhost" else ip

    if global_state.service is None and global_state.client is None:
        global_state.initialize(ip, port, num_processes, process_id, local_device_ids)
        atexit.register(shutdown)
        return

    # A service without a client would indicate a half-initialized runtime.
    assert (
        global_state.client is not None
    ), "global_state.client should not be None if server is created"

    if global_state.ip == ip and global_state.port == port:
        return

    msg = (
        f"xla distribute server/client have been created on {global_state.ip}:{global_state.port}. "
        f"so ignore the request to create on {ip}:{port}"
    )
    # stacklevel=2 attributes the warning to the caller, not to this module.
    warnings.warn(msg, category=RuntimeWarning, stacklevel=2)


def shutdown():
    """Tear down the runtime handles held by the module-level ``global_state``.

    Delegates to ``global_state.shutdown()``; this is the function that
    ``initialize`` registers with ``atexit``.
    """
    state = global_state
    state.shutdown()