perf(dist): speed up bcast_val

GitOrigin-RevId: 21c4123b09480b425676681a16a50962141b1eda
......@@ -36,6 +36,7 @@ class Methods:
self.dict_barrier_counter = defaultdict(int)
self.dict_barrier_event = defaultdict(threading.Event)
self.user_dict = defaultdict(partial(Future, False))
self.bcast_dict = {}
def connect(self):
"""Method for checking connection success."""
......@@ -127,6 +128,23 @@ class Methods:
future = self.user_dict[key]
return future.get()
def bcast_val(self, val, key, size):
with self.lock:
if key not in self.bcast_dict:
self.bcast_dict[key] = [Future(False), size]
arr = self.bcast_dict[key]
if val is not None:
arr[0].set(val)
val = None
else:
val = arr[0].get()
with self.lock:
cnt = arr[1] - 1
arr[1] = cnt
if cnt == 0:
del self.bcast_dict[key]
return val
class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
pass
......@@ -142,7 +160,9 @@ def _start_server(py_server_port, queue):
"""
try:
mm_server_port = create_mm_server("0.0.0.0", 0)
server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False)
server = ThreadXMLRPCServer(
("0.0.0.0", py_server_port), logRequests=False, allow_none=True
)
server.register_instance(Methods(mm_server_port))
_, py_server_port = server.server_address
queue.put((py_server_port, mm_server_port))
......@@ -185,13 +205,14 @@ class Client:
self.master_ip = master_ip
self.port = port
self.connect()
self.bcast_dict = defaultdict(lambda: 0)
def connect(self):
"""Check connection success."""
while True:
try:
self.proxy = ServerProxy(
"http://{}:{}".format(self.master_ip, self.port)
"http://{}:{}".format(self.master_ip, self.port), allow_none=True
)
if self.proxy.connect():
break
......@@ -247,22 +268,17 @@ class Client:
def user_set(self, key, val):
"""Set user defined key-value pairs across processes."""
self.proxy.user_set(key, val)
return self.proxy.user_set(key, val)
def user_get(self, key):
"""Get user defined key-value pairs across processes."""
return self.proxy.user_get(key)
def bcast_val(self, val, key, size):
if val is not None:
self.user_set(key + "_sync", val)
self.group_barrier(key, size)
self.group_barrier(key, size)
else:
self.group_barrier(key, size)
val = self.user_get(key + "_sync")
self.group_barrier(key, size)
return val
idx = self.bcast_dict[key] + 1
self.bcast_dict[key] = idx
key = key + "_bcast_" + str(idx)
return self.proxy.bcast_val(val, key, size)
def main(port=0, verbose=True):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
反馈
建议
客服 返回
顶部