diff --git a/imperative/python/megengine/distributed/server.py b/imperative/python/megengine/distributed/server.py index 27f23dfd1d1af37f09f34d215f808e88b46835bc..e83efb926c3bdaefa375eac4f264e58f57b808a8 100644 --- a/imperative/python/megengine/distributed/server.py +++ b/imperative/python/megengine/distributed/server.py @@ -36,6 +36,7 @@ class Methods: self.dict_barrier_counter = defaultdict(int) self.dict_barrier_event = defaultdict(threading.Event) self.user_dict = defaultdict(partial(Future, False)) + self.bcast_dict = {} def connect(self): """Method for checking connection success.""" @@ -127,6 +128,23 @@ class Methods: future = self.user_dict[key] return future.get() + def bcast_val(self, val, key, size): + with self.lock: + if key not in self.bcast_dict: + self.bcast_dict[key] = [Future(False), size] + arr = self.bcast_dict[key] + if val is not None: + arr[0].set(val) + val = None + else: + val = arr[0].get() + with self.lock: + cnt = arr[1] - 1 + arr[1] = cnt + if cnt == 0: + del self.bcast_dict[key] + return val + class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer): pass @@ -142,7 +160,9 @@ def _start_server(py_server_port, queue): """ try: mm_server_port = create_mm_server("0.0.0.0", 0) - server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False) + server = ThreadXMLRPCServer( + ("0.0.0.0", py_server_port), logRequests=False, allow_none=True + ) server.register_instance(Methods(mm_server_port)) _, py_server_port = server.server_address queue.put((py_server_port, mm_server_port)) @@ -185,13 +205,14 @@ class Client: self.master_ip = master_ip self.port = port self.connect() + self.bcast_dict = defaultdict(lambda: 0) def connect(self): """Check connection success.""" while True: try: self.proxy = ServerProxy( - "http://{}:{}".format(self.master_ip, self.port) + "http://{}:{}".format(self.master_ip, self.port), allow_none=True ) if self.proxy.connect(): break @@ -247,22 +268,17 @@ class Client: def user_set(self, key, val): """Set user defined key-value pairs across processes.""" - self.proxy.user_set(key, val) + return self.proxy.user_set(key, val) def user_get(self, key): """Get user defined key-value pairs across processes.""" return self.proxy.user_get(key) def bcast_val(self, val, key, size): - if val is not None: - self.user_set(key + "_sync", val) - self.group_barrier(key, size) - self.group_barrier(key, size) - else: - self.group_barrier(key, size) - val = self.user_get(key + "_sync") - self.group_barrier(key, size) - return val + idx = self.bcast_dict[key] + 1 + self.bcast_dict[key] = idx + key = key + "_bcast_" + str(idx) + return self.proxy.bcast_val(val, key, size) def main(port=0, verbose=True):