提交 3ecded74 编写于 作者: M Megvii Engine Team

refactor(distributed/server): bind to port 0 so the OS assigns an available port

GitOrigin-RevId: e367846b9216ea6d5ef7ada698ffba790ba8e1e6
上级 88e918e2
......@@ -29,8 +29,8 @@ def launcher(func):
def wrapper(*args, **kwargs):
master_ip = "localhost"
port = get_free_ports(1)[0]
server = Server(port)
server = Server()
port = server.py_server_port
procs = []
for rank in range(n_gpus):
......@@ -41,9 +41,18 @@ def launcher(func):
p.start()
procs.append(p)
for rank in range(n_gpus):
procs[rank].join()
code = procs[rank].exitcode
assert code == 0, "subprocess {} exit with code {}".format(rank, code)
ranks = [rank for rank in range(n_gpus)]
while len(ranks) > 0:
left = []
for rank in ranks:
procs[rank].join(1)
code = procs[rank].exitcode
assert (
code == 0 or code == None
), "subprocess {} exit with code {}".format(rank, code)
if code == None:
left.append(rank)
ranks = left
return wrapper
......@@ -10,6 +10,7 @@ import threading
import time
from collections import defaultdict
from functools import partial
from queue import Queue
from socketserver import ThreadingMixIn
from xmlrpc.client import ServerProxy
from xmlrpc.server import SimpleXMLRPCServer
......@@ -132,7 +133,7 @@ class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
pass
def start_server(py_server_port, mm_server_port):
def start_server(py_server_port, mm_server_port, queue):
"""
Start python distributed server and multiple machine server.
......@@ -141,6 +142,8 @@ def start_server(py_server_port, mm_server_port):
"""
server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False)
server.register_instance(Methods(mm_server_port))
_, port = server.server_address
queue.put(port)
server.serve_forever()
......@@ -152,15 +155,14 @@ class Server:
:param port: python server port.
"""
def __init__(self, port):
self.py_server_port = get_free_ports(1)[0] if port == 0 else port
def __init__(self, port=0):
self.mm_server_port = create_mm_server("0.0.0.0", 0)
q = Queue()
self.proc = threading.Thread(
target=start_server,
args=(self.py_server_port, self.mm_server_port),
daemon=True,
target=start_server, args=(port, self.mm_server_port, q), daemon=True,
)
self.proc.start()
self.py_server_port = q.get()
class Client:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册