提交 0adf49b1 编写于 作者: M Megvii Engine Team

fix(mge/distributed): fix deadlock by mixing thread and fork

GitOrigin-RevId: c138cb9c280aeb37d83d2ca31549df0661ebbbc9
上级 ae8b38f6
......@@ -12,7 +12,7 @@ import multiprocessing as mp
from ..core._imperative_rt.core2 import sync
from .group import group_barrier, init_process_group
from .helper import get_device_count_by_fork
from .server import Server
from .server import Client, Server
from .util import get_free_ports
......
......@@ -6,11 +6,11 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import multiprocessing as mp
import threading
import time
from collections import defaultdict
from functools import partial
from queue import Queue
from socketserver import ThreadingMixIn
from xmlrpc.client import ServerProxy
from xmlrpc.server import SimpleXMLRPCServer
......@@ -133,7 +133,7 @@ class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
pass
def _start_server(py_server_port, mm_server_port, queue):
def _start_server(py_server_port, queue):
"""
Start python distributed server and multiple machine server.
......@@ -142,10 +142,11 @@ def _start_server(py_server_port, mm_server_port, queue):
:param queue: server port will put in this queue, puts exception when process fails.
"""
try:
mm_server_port = create_mm_server("0.0.0.0", 0)
server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False)
server.register_instance(Methods(mm_server_port))
_, port = server.server_address
queue.put(port)
_, py_server_port = server.server_address
queue.put((py_server_port, mm_server_port))
server.serve_forever()
except Exception as e:
queue.put(e)
......@@ -160,17 +161,17 @@ class Server:
"""
def __init__(self, port=0):
self.mm_server_port = create_mm_server("0.0.0.0", 0)
q = Queue()
self.proc = threading.Thread(
target=_start_server, args=(port, self.mm_server_port, q), daemon=True,
)
q = mp.Queue()
self.proc = mp.Process(target=_start_server, args=(port, q), daemon=True)
self.proc.start()
ret = q.get()
if isinstance(ret, Exception):
raise ret
else:
self.py_server_port = ret
self.py_server_port, self.mm_server_port = ret
def __del__(self):
self.proc.terminate()
class Client:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册