From 876e2ff1f62fde4d3dc56f7dd1403c659fbfb0b9 Mon Sep 17 00:00:00 2001 From: caozhou <48191911+Caozhou1995@users.noreply.github.com> Date: Mon, 18 Jul 2022 10:06:23 +0800 Subject: [PATCH] [auto parallel] remove comm init control (#44385) --- .../distributed/auto_parallel/engine.py | 56 +------------------ 1 file changed, 1 insertion(+), 55 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index 1e1e37b443..72a377603e 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -324,65 +324,11 @@ class Engine: # instantiate communication by process_mapping. all_process_groups = get_all_process_groups() - has_recv_by_socket = [] - # This is a magic number and the rank number for training is usually less than 5000 - magic_num = 5000 - genv = _get_global_env() - cur_rank_ip, cur_rank_port = genv.current_endpoint.split(":") - cur_rank_recv_port = int(cur_rank_port) + magic_num - server_socket = None - # Large enough for recv rank - buff_size = 1024 - server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - server_socket.bind((cur_rank_ip, cur_rank_recv_port)) - # The 10 is an empirical value - server_socket.listen(10) - client_sockets = {} + # NOTE: add the comm init control in the future for auto search for process_group in all_process_groups: if self._cur_rank not in process_group.ranks: continue - if len(process_group.ranks) == 2: - index = process_group.ranks.index(self._cur_rank) - is_send = True if index == 0 else False - if is_send: - recv_rank = process_group.ranks[1] - recv_rank_ip, recv_rank_port = genv.trainer_endpoints[ - recv_rank].split(":") - connect_port = int(recv_rank_port) + magic_num - client_socket = socket.socket(socket.AF_INET, - socket.SOCK_STREAM) - client_socket.connect((recv_rank_ip, connect_port)) - client_socket.send(str(self._cur_rank).encode('utf-8')) - rank = client_socket.recv(buff_size).decode('utf-8') - rank = int(rank) - if rank != recv_rank: - raise ValueError( - "Please check comm pair, the recv rank should be {} but got {}." - .format(recv_rank, rank)) - else: - print("It is able to instantiate {} as sender now.". - format(process_group.ranks)) - client_socket.close() - else: - send_rank = process_group.ranks[0] - while True: - if send_rank not in has_recv_by_socket: - client_socket, recv_addr = server_socket.accept( - ) - rank = int( - client_socket.recv(buff_size).decode()) - client_sockets[rank] = client_socket - has_recv_by_socket.append(rank) - else: - client_sockets[send_rank].send( - str(self._cur_rank).encode("utf-8")) - client_sockets[send_rank].close() - print( - "It is able to instantiate {} as recver now." - .format(process_group.ranks)) - break process_group.instantiate() - server_socket.close() self._place = _get_device() if isinstance(self._place, fluid.CUDAPlace): -- GitLab