提交 3ecded74 编写于 作者: M Megvii Engine Team

refactor(distributed/server): use port 0 to get available port

GitOrigin-RevId: e367846b9216ea6d5ef7ada698ffba790ba8e1e6
上级 88e918e2
...@@ -29,8 +29,8 @@ def launcher(func): ...@@ -29,8 +29,8 @@ def launcher(func):
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
master_ip = "localhost" master_ip = "localhost"
port = get_free_ports(1)[0] server = Server()
server = Server(port) port = server.py_server_port
procs = [] procs = []
for rank in range(n_gpus): for rank in range(n_gpus):
...@@ -41,9 +41,18 @@ def launcher(func): ...@@ -41,9 +41,18 @@ def launcher(func):
p.start() p.start()
procs.append(p) procs.append(p)
for rank in range(n_gpus): ranks = [rank for rank in range(n_gpus)]
procs[rank].join()
code = procs[rank].exitcode while len(ranks) > 0:
assert code == 0, "subprocess {} exit with code {}".format(rank, code) left = []
for rank in ranks:
procs[rank].join(1)
code = procs[rank].exitcode
assert (
code == 0 or code == None
), "subprocess {} exit with code {}".format(rank, code)
if code == None:
left.append(rank)
ranks = left
return wrapper return wrapper
...@@ -10,6 +10,7 @@ import threading ...@@ -10,6 +10,7 @@ import threading
import time import time
from collections import defaultdict from collections import defaultdict
from functools import partial from functools import partial
from queue import Queue
from socketserver import ThreadingMixIn from socketserver import ThreadingMixIn
from xmlrpc.client import ServerProxy from xmlrpc.client import ServerProxy
from xmlrpc.server import SimpleXMLRPCServer from xmlrpc.server import SimpleXMLRPCServer
...@@ -132,7 +133,7 @@ class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer): ...@@ -132,7 +133,7 @@ class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
pass pass
def start_server(py_server_port, mm_server_port): def start_server(py_server_port, mm_server_port, queue):
""" """
Start python distributed server and multiple machine server. Start python distributed server and multiple machine server.
...@@ -141,6 +142,8 @@ def start_server(py_server_port, mm_server_port): ...@@ -141,6 +142,8 @@ def start_server(py_server_port, mm_server_port):
""" """
server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False) server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False)
server.register_instance(Methods(mm_server_port)) server.register_instance(Methods(mm_server_port))
_, port = server.server_address
queue.put(port)
server.serve_forever() server.serve_forever()
...@@ -152,15 +155,14 @@ class Server: ...@@ -152,15 +155,14 @@ class Server:
:param port: python server port. :param port: python server port.
""" """
def __init__(self, port): def __init__(self, port=0):
self.py_server_port = get_free_ports(1)[0] if port == 0 else port
self.mm_server_port = create_mm_server("0.0.0.0", 0) self.mm_server_port = create_mm_server("0.0.0.0", 0)
q = Queue()
self.proc = threading.Thread( self.proc = threading.Thread(
target=start_server, target=start_server, args=(port, self.mm_server_port, q), daemon=True,
args=(self.py_server_port, self.mm_server_port),
daemon=True,
) )
self.proc.start() self.proc.start()
self.py_server_port = q.get()
class Client: class Client:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册