diff --git a/python_module/megengine/distributed/__init__.py b/python_module/megengine/distributed/__init__.py
index fb6e8033a72b107447336c816304180781446756..63974cd37467c2f8f62cfd5e83c1cedf2c25ac0a 100644
--- a/python_module/megengine/distributed/__init__.py
+++ b/python_module/megengine/distributed/__init__.py
@@ -18,6 +18,7 @@ from .functional import (
 )
 from .util import (
     get_backend,
+    get_free_ports,
     get_master_ip,
     get_master_port,
     get_rank,
diff --git a/python_module/megengine/distributed/util.py b/python_module/megengine/distributed/util.py
index 52248d303fd77fee5df203a656ff44409d53be98..115ae326fd7b33963021bb8082297a9aa7ff30f2 100644
--- a/python_module/megengine/distributed/util.py
+++ b/python_module/megengine/distributed/util.py
@@ -7,6 +7,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import functools
-from typing import Callable, Optional
+import socket
+from typing import Callable, List, Optional, Union
 
 import megengine._internal as mgb
@@ -110,7 +111,7 @@ def synchronized(func: Callable):
     Specifically, we use this to prevent data race during hub.load"""
 
     @functools.wraps(func)
-    def _(*args, **kwargs):
+    def wrapper(*args, **kwargs):
         if not is_distributed():
             return func(*args, **kwargs)
 
@@ -118,4 +119,30 @@ def synchronized(func: Callable):
         group_barrier()
         return ret
 
-    return _
+    return wrapper
+
+
+def get_free_ports(num: int = 1) -> Union[int, List[int]]:
+    """Get one or more free ports picked by the operating system.
+
+    Every probe socket stays open until all ports have been collected, so a
+    single call never returns the same port twice. The sockets are closed
+    before returning, so another process can still claim a port before the
+    caller binds it.
+
+    :param num: number of free ports to find.
+    :return: an integer if ``num`` is 1, otherwise a list of integers.
+    """
+    socks, ports = [], []
+    try:
+        for _ in range(num):
+            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+            # Register the socket first so it is closed even if bind() fails.
+            socks.append(sock)
+            # Port 0 lets the OS choose any currently unused port.
+            sock.bind(("", 0))
+            ports.append(sock.getsockname()[1])
+    finally:
+        for sock in socks:
+            sock.close()
+    return ports[0] if num == 1 else ports