未验证 提交 876e2ff1 编写于 作者: C caozhou 提交者: GitHub

[auto parallel] remove comm init control (#44385)

上级 c0a7830f
...@@ -324,65 +324,11 @@ class Engine: ...@@ -324,65 +324,11 @@ class Engine:
# instantiate communication by process_mapping. # instantiate communication by process_mapping.
all_process_groups = get_all_process_groups() all_process_groups = get_all_process_groups()
has_recv_by_socket = [] # NOTE: add the comm init control in the future for auto search
# This port offset is a magic number; the number of ranks used for training is usually less than 5000
magic_num = 5000
genv = _get_global_env()
cur_rank_ip, cur_rank_port = genv.current_endpoint.split(":")
cur_rank_recv_port = int(cur_rank_port) + magic_num
server_socket = None
# Large enough for recv rank
buff_size = 1024
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_socket.bind((cur_rank_ip, cur_rank_recv_port))
# The listen backlog of 10 is an empirical value
server_socket.listen(10)
client_sockets = {}
for process_group in all_process_groups: for process_group in all_process_groups:
if self._cur_rank not in process_group.ranks: if self._cur_rank not in process_group.ranks:
continue continue
if len(process_group.ranks) == 2:
index = process_group.ranks.index(self._cur_rank)
is_send = True if index == 0 else False
if is_send:
recv_rank = process_group.ranks[1]
recv_rank_ip, recv_rank_port = genv.trainer_endpoints[
recv_rank].split(":")
connect_port = int(recv_rank_port) + magic_num
client_socket = socket.socket(socket.AF_INET,
socket.SOCK_STREAM)
client_socket.connect((recv_rank_ip, connect_port))
client_socket.send(str(self._cur_rank).encode('utf-8'))
rank = client_socket.recv(buff_size).decode('utf-8')
rank = int(rank)
if rank != recv_rank:
raise ValueError(
"Please check comm pair, the recv rank should be {} but got {}."
.format(recv_rank, rank))
else:
print("It is able to instantiate {} as sender now.".
format(process_group.ranks))
client_socket.close()
else:
send_rank = process_group.ranks[0]
while True:
if send_rank not in has_recv_by_socket:
client_socket, recv_addr = server_socket.accept(
)
rank = int(
client_socket.recv(buff_size).decode())
client_sockets[rank] = client_socket
has_recv_by_socket.append(rank)
else:
client_sockets[send_rank].send(
str(self._cur_rank).encode("utf-8"))
client_sockets[send_rank].close()
print(
"It is able to instantiate {} as recver now."
.format(process_group.ranks))
break
process_group.instantiate() process_group.instantiate()
server_socket.close()
self._place = _get_device() self._place = _get_device()
if isinstance(self._place, fluid.CUDAPlace): if isinstance(self._place, fluid.CUDAPlace):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册