提交 4c5141d6 编写于 作者: M Megvii Engine Team

perf(dist): speed up bcast_val

GitOrigin-RevId: 21c4123b09480b425676681a16a50962141b1eda
上级 0c6ee228
...@@ -36,6 +36,7 @@ class Methods: ...@@ -36,6 +36,7 @@ class Methods:
self.dict_barrier_counter = defaultdict(int) self.dict_barrier_counter = defaultdict(int)
self.dict_barrier_event = defaultdict(threading.Event) self.dict_barrier_event = defaultdict(threading.Event)
self.user_dict = defaultdict(partial(Future, False)) self.user_dict = defaultdict(partial(Future, False))
self.bcast_dict = {}
def connect(self): def connect(self):
"""Method for checking connection success.""" """Method for checking connection success."""
...@@ -127,6 +128,23 @@ class Methods: ...@@ -127,6 +128,23 @@ class Methods:
future = self.user_dict[key] future = self.user_dict[key]
return future.get() return future.get()
def bcast_val(self, val, key, size):
with self.lock:
if key not in self.bcast_dict:
self.bcast_dict[key] = [Future(False), size]
arr = self.bcast_dict[key]
if val is not None:
arr[0].set(val)
val = None
else:
val = arr[0].get()
with self.lock:
cnt = arr[1] - 1
arr[1] = cnt
if cnt == 0:
del self.bcast_dict[key]
return val
class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer): class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
pass pass
...@@ -142,7 +160,9 @@ def _start_server(py_server_port, queue): ...@@ -142,7 +160,9 @@ def _start_server(py_server_port, queue):
""" """
try: try:
mm_server_port = create_mm_server("0.0.0.0", 0) mm_server_port = create_mm_server("0.0.0.0", 0)
server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False) server = ThreadXMLRPCServer(
("0.0.0.0", py_server_port), logRequests=False, allow_none=True
)
server.register_instance(Methods(mm_server_port)) server.register_instance(Methods(mm_server_port))
_, py_server_port = server.server_address _, py_server_port = server.server_address
queue.put((py_server_port, mm_server_port)) queue.put((py_server_port, mm_server_port))
...@@ -185,13 +205,14 @@ class Client: ...@@ -185,13 +205,14 @@ class Client:
self.master_ip = master_ip self.master_ip = master_ip
self.port = port self.port = port
self.connect() self.connect()
self.bcast_dict = defaultdict(lambda: 0)
def connect(self): def connect(self):
"""Check connection success.""" """Check connection success."""
while True: while True:
try: try:
self.proxy = ServerProxy( self.proxy = ServerProxy(
"http://{}:{}".format(self.master_ip, self.port) "http://{}:{}".format(self.master_ip, self.port), allow_none=True
) )
if self.proxy.connect(): if self.proxy.connect():
break break
...@@ -247,22 +268,17 @@ class Client: ...@@ -247,22 +268,17 @@ class Client:
def user_set(self, key, val): def user_set(self, key, val):
"""Set user defined key-value pairs across processes.""" """Set user defined key-value pairs across processes."""
self.proxy.user_set(key, val) return self.proxy.user_set(key, val)
def user_get(self, key): def user_get(self, key):
"""Get user defined key-value pairs across processes.""" """Get user defined key-value pairs across processes."""
return self.proxy.user_get(key) return self.proxy.user_get(key)
def bcast_val(self, val, key, size): def bcast_val(self, val, key, size):
if val is not None: idx = self.bcast_dict[key] + 1
self.user_set(key + "_sync", val) self.bcast_dict[key] = idx
self.group_barrier(key, size) key = key + "_bcast_" + str(idx)
self.group_barrier(key, size) return self.proxy.bcast_val(val, key, size)
else:
self.group_barrier(key, size)
val = self.user_get(key + "_sync")
self.group_barrier(key, size)
return val
def main(port=0, verbose=True): def main(port=0, verbose=True):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册