# system_helper.py
"""
Copyright 2020 OpenDILab. All Rights Reserved
"""
import os
import socket
import time
import uuid
from typing import Optional, Any
from threading import Thread
from contextlib import closing


def get_ip() -> str:
    r"""
    Overview:
        Resolve an IPv4 address for this machine by looking up its fully qualified host name.
    Returns:
        - ip (:obj:`str`): the address that the local FQDN resolves to
    Note:
        On some slurm nodes the host name resolves to the loopback interface, so ``127.0.0.1`` may be returned.
    """
    hostname = socket.gethostname()
    fqdn = socket.getfqdn(hostname)
    return socket.gethostbyname(fqdn)


def get_pid() -> int:
    r"""
    Overview:
        Return the id of the current process (thin wrapper around ``os.getpid``).
    Returns:
        - pid (:obj:`int`): the current process id
    """
    current_pid = os.getpid()
    return current_pid


def get_task_uid() -> str:
    r"""
    Overview:
        Build an identifier for the current task of the form ``<job>_<timestamp>``.
        ``<job>`` is the value of ``$SLURM_JOB_ID`` when that variable is set; otherwise it is
        ``PID<pid>UUID<uuid1>``. ``<timestamp>`` is ``str(time.time())``.
    Returns:
        - task_uid (:obj:`str`): the composed identifier
    """
    # The PID/UUID fallback is always built, even when SLURM_JOB_ID is present.
    pid_part = str(get_pid())
    uuid_part = str(uuid.uuid1())
    fallback = 'PID{pid}UUID{uuid}'.format(pid=pid_part, uuid=uuid_part)
    job_part = os.getenv('SLURM_JOB_ID', fallback)
    return '_'.join((job_part, str(time.time())))


class PropagatingThread(Thread):
    """
    Overview:
        Subclass of Thread that propagates an exception raised by ``target`` to the caller of ``join``,
        and returns ``target``'s return value from ``join``.
    Examples:
        >>> def func():
        ...     raise Exception()
        >>> t = PropagatingThread(target=func, args=())
        >>> t.start()
        >>> t.join()  # raises RuntimeError chained (``from``) to the Exception raised in func
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Initialise the result slots up front so ``join`` never hits a missing attribute,
        # e.g. when ``join(timeout)`` returns before ``run`` has begun executing.
        self.exc: Optional[BaseException] = None
        self.ret: Any = None

    def run(self) -> None:
        r"""
        Overview:
            Call ``target(*args, **kwargs)``. Store its return value in ``self.ret``, or store any
            exception it raised in ``self.exc``.
        """
        self.exc = None
        self.ret = None
        try:
            # Mirror ``Thread.run``: a thread constructed without a target does nothing.
            if self._target is not None:
                self.ret = self._target(*self._args, **self._kwargs)
        except BaseException as e:  # deliberately broad: whatever was raised is re-raised from ``join``
            self.exc = e
        finally:
            # Like ``Thread.run``, drop references to the target and its arguments so they
            # are not kept alive (possibly in a reference cycle) after the thread finishes.
            del self._target, self._args, self._kwargs

    def join(self, timeout: Optional[float] = None) -> Any:
        r"""
        Overview:
            Wait for the thread to finish, then return the target's result or re-raise its exception.
        Arguments:
            - timeout (:obj:`Optional[float]`): same meaning as in ``threading.Thread.join``
        Returns:
            - ret (:obj:`Any`): the value returned by ``target``. It is ``None`` if the thread had no target,
              or if ``timeout`` expired while the thread is still alive (check ``is_alive()``).
        Raises:
            - RuntimeError: if ``target`` raised; the original exception is attached as ``__cause__``
        """
        super().join(timeout)
        if self.is_alive():
            # Timed out: there is no result yet; ``Thread.join`` likewise returns without raising.
            return None
        if self.exc is not None:
            raise RuntimeError('Exception in thread({})'.format(id(self))) from self.exc
        return self.ret


def find_free_port(host: Optional[str] = None) -> int:
    r"""
    Overview:
        Ask the operating system for a currently unused TCP port and return its number.
        A TCP socket is bound to port 0 on the wildcard address ``''``, and the port the kernel assigned is read back.
    Arguments:
        - host (:obj:`Optional[str]`): accepted for backward compatibility but currently unused;
          the probe always binds to ``''``, not to ``host``. Defaults to ``None``.
    Returns:
        - port (:obj:`int`): the port number assigned by the kernel
    Note:
        The probe socket is closed before this function returns, so another process could claim the
        port before the caller binds it. Callers may need to retry if their own bind fails.
    """
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        # Port 0 lets the kernel pick an available ephemeral port.
        s.bind(('', 0))
        return s.getsockname()[1]