提交 0adf49b1 编写于 作者: M Megvii Engine Team

fix(mge/distributed): fix deadlock caused by mixing thread and fork

GitOrigin-RevId: c138cb9c280aeb37d83d2ca31549df0661ebbbc9
上级 ae8b38f6
...@@ -12,7 +12,7 @@ import multiprocessing as mp ...@@ -12,7 +12,7 @@ import multiprocessing as mp
from ..core._imperative_rt.core2 import sync from ..core._imperative_rt.core2 import sync
from .group import group_barrier, init_process_group from .group import group_barrier, init_process_group
from .helper import get_device_count_by_fork from .helper import get_device_count_by_fork
from .server import Server from .server import Client, Server
from .util import get_free_ports from .util import get_free_ports
......
...@@ -6,11 +6,11 @@ ...@@ -6,11 +6,11 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import multiprocessing as mp
import threading import threading
import time import time
from collections import defaultdict from collections import defaultdict
from functools import partial from functools import partial
from queue import Queue
from socketserver import ThreadingMixIn from socketserver import ThreadingMixIn
from xmlrpc.client import ServerProxy from xmlrpc.client import ServerProxy
from xmlrpc.server import SimpleXMLRPCServer from xmlrpc.server import SimpleXMLRPCServer
...@@ -133,7 +133,7 @@ class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer): ...@@ -133,7 +133,7 @@ class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
pass pass
def _start_server(py_server_port, mm_server_port, queue): def _start_server(py_server_port, queue):
""" """
Start python distributed server and multiple machine server. Start python distributed server and multiple machine server.
...@@ -142,10 +142,11 @@ def _start_server(py_server_port, mm_server_port, queue): ...@@ -142,10 +142,11 @@ def _start_server(py_server_port, mm_server_port, queue):
:param queue: the server port will be put in this queue; an exception is put instead when the process fails. :param queue: the server port will be put in this queue; an exception is put instead when the process fails.
""" """
try: try:
mm_server_port = create_mm_server("0.0.0.0", 0)
server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False) server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False)
server.register_instance(Methods(mm_server_port)) server.register_instance(Methods(mm_server_port))
_, port = server.server_address _, py_server_port = server.server_address
queue.put(port) queue.put((py_server_port, mm_server_port))
server.serve_forever() server.serve_forever()
except Exception as e: except Exception as e:
queue.put(e) queue.put(e)
...@@ -160,17 +161,17 @@ class Server: ...@@ -160,17 +161,17 @@ class Server:
""" """
def __init__(self, port=0): def __init__(self, port=0):
self.mm_server_port = create_mm_server("0.0.0.0", 0) q = mp.Queue()
q = Queue() self.proc = mp.Process(target=_start_server, args=(port, q), daemon=True)
self.proc = threading.Thread(
target=_start_server, args=(port, self.mm_server_port, q), daemon=True,
)
self.proc.start() self.proc.start()
ret = q.get() ret = q.get()
if isinstance(ret, Exception): if isinstance(ret, Exception):
raise ret raise ret
else: else:
self.py_server_port = ret self.py_server_port, self.mm_server_port = ret
def __del__(self):
self.proc.terminate()
class Client: class Client:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册