diff --git a/imperative/python/megengine/distributed/launcher.py b/imperative/python/megengine/distributed/launcher.py index a6c7c05ae95347afc2615257cc431b19abdbee4a..6cf39fa00fb5039ff735c3a609be6c68b5abb569 100644 --- a/imperative/python/megengine/distributed/launcher.py +++ b/imperative/python/megengine/distributed/launcher.py @@ -29,8 +29,8 @@ def launcher(func): def wrapper(*args, **kwargs): master_ip = "localhost" - port = get_free_ports(1)[0] - server = Server(port) + server = Server() + port = server.py_server_port procs = [] for rank in range(n_gpus): @@ -41,9 +41,18 @@ def launcher(func): p.start() procs.append(p) - for rank in range(n_gpus): - procs[rank].join() - code = procs[rank].exitcode - assert code == 0, "subprocess {} exit with code {}".format(rank, code) + ranks = [rank for rank in range(n_gpus)] + + while len(ranks) > 0: + left = [] + for rank in ranks: + procs[rank].join(1) + code = procs[rank].exitcode + assert ( + code == 0 or code == None + ), "subprocess {} exit with code {}".format(rank, code) + if code == None: + left.append(rank) + ranks = left return wrapper diff --git a/imperative/python/megengine/distributed/server.py b/imperative/python/megengine/distributed/server.py index c9ab3177431654817789cca4cc8c90a080e90b30..6ab00a8b18aa65f18b97f729c22d1840ff2400ea 100644 --- a/imperative/python/megengine/distributed/server.py +++ b/imperative/python/megengine/distributed/server.py @@ -10,6 +10,7 @@ import threading import time from collections import defaultdict from functools import partial +from queue import Queue from socketserver import ThreadingMixIn from xmlrpc.client import ServerProxy from xmlrpc.server import SimpleXMLRPCServer @@ -132,7 +133,7 @@ class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer): pass -def start_server(py_server_port, mm_server_port): +def start_server(py_server_port, mm_server_port, queue): """ Start python distributed server and multiple machine server. @@ -141,6 +142,8 @@ def start_server(py_server_port, mm_server_port): """ server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False) server.register_instance(Methods(mm_server_port)) + _, port = server.server_address + queue.put(port) server.serve_forever() @@ -152,15 +155,14 @@ class Server: :param port: python server port. """ - def __init__(self, port): - self.py_server_port = get_free_ports(1)[0] if port == 0 else port + def __init__(self, port=0): self.mm_server_port = create_mm_server("0.0.0.0", 0) + q = Queue() self.proc = threading.Thread( - target=start_server, - args=(self.py_server_port, self.mm_server_port), - daemon=True, + target=start_server, args=(port, self.mm_server_port, q), daemon=True, ) self.proc.start() + self.py_server_port = q.get() class Client: