未验证 提交 7dc7fc4b 编写于 作者: C caozhou 提交者: GitHub

[Auto Parallel] Add comm init control by socket (#44148)

* add comm init control by socket

* avoid single card instance failure
上级 42468de1
......@@ -15,6 +15,7 @@
import copy
import logging
from collections import defaultdict
import socket
import paddle
import paddle.utils as utils
......@@ -36,7 +37,8 @@ from paddle.distributed import fleet
from paddle.distributed.utils import get_logger
from paddle.distributed.passes import new_pass, PassContext
# from .cluster import Cluster, get_default_cluster
from ..collective import _get_global_env
from .cluster import Cluster, get_default_cluster
from .planner_v2 import Planner
from .parallelizer_v2 import Parallelizer
from .dist_op import DistributedOperator
......@@ -60,8 +62,8 @@ class Engine:
self.inputs_spec = self._validate_spec(inputs_spec)
self.labels_spec = self._validate_spec(labels_spec)
self.cluster = cluster
# if self.cluster is None:
# self.cluster = get_default_cluster()
if self.cluster is None:
self.cluster = get_default_cluster()
self.strategy = strategy
if self.strategy is None:
self.strategy = fleet.DistributedStrategy()
......@@ -314,10 +316,66 @@ class Engine:
# Traverse different rank programs and traverse each op of them,
# instantiate communication by process_mapping.
all_process_groups = get_all_process_groups()
has_recv_by_socket = []
# This is a magic number and the rank number for training is usually less than 5000
magic_num = 5000
genv = _get_global_env()
cur_rank_ip, cur_rank_port = genv.current_endpoint.split(":")
cur_rank_recv_port = int(cur_rank_port) + magic_num
server_socket = None
# Large enough for recv rank
buff_size = 1024
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_socket.bind((cur_rank_ip, cur_rank_recv_port))
# The 10 is an empirical value
server_socket.listen(10)
client_sockets = {}
for process_group in all_process_groups:
if self._cur_rank not in process_group.ranks:
continue
if len(process_group.ranks) == 2:
index = process_group.ranks.index(self._cur_rank)
is_send = True if index == 0 else False
if is_send:
recv_rank = process_group.ranks[1]
recv_rank_ip, recv_rank_port = genv.trainer_endpoints[
recv_rank].split(":")
connect_port = int(recv_rank_port) + magic_num
client_socket = socket.socket(socket.AF_INET,
socket.SOCK_STREAM)
client_socket.connect((recv_rank_ip, connect_port))
client_socket.send(str(self._cur_rank).encode('utf-8'))
rank = client_socket.recv(buff_size).decode('utf-8')
rank = int(rank)
if rank != recv_rank:
raise ValueError(
"Please check comm pair, the recv rank should be {} but got {}."
.format(recv_rank, rank))
else:
print("It is able to instantiate {} as sender now.".
format(process_group.ranks))
client_socket.close()
else:
send_rank = process_group.ranks[0]
while True:
if send_rank not in has_recv_by_socket:
client_socket, recv_addr = server_socket.accept(
)
rank = int(
client_socket.recv(buff_size).decode())
client_sockets[rank] = client_socket
has_recv_by_socket.append(rank)
else:
client_sockets[send_rank].send(
str(self._cur_rank).encode("utf-8"))
client_sockets[send_rank].close()
print(
"It is able to instantiate {} as recver now."
.format(process_group.ranks))
break
process_group.instantiate()
server_socket.close()
self._place = _get_device()
if isinstance(self._place, fluid.CUDAPlace):
......
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License
from collections import OrderedDict
import paddle
import paddle.fluid.core as core
from ..collective import _get_global_env
......@@ -132,14 +134,21 @@ class ProcessGroup:
# TODO(shenliang03): This is a temporary solution to solve the problem of
# hang caused by cross-creation of new_group
paddle.framework._in_legacy_dygraph()
paddle.set_device('gpu:%d' %
paddle.distributed.ParallelEnv().dev_id)
tmp = paddle.to_tensor(
[1], dtype="int32") if _non_static_mode() else fill_constant(
[0], dtype="int32", value="1")
paddle.distributed.all_reduce(tmp, use_calc_stream=True)
paddle.distributed.wait(tmp)
paddle.distributed.all_reduce(tmp, use_calc_stream=True, group=self)
paddle.distributed.wait(tmp, group=self)
paddle.enable_static()
self._is_instantiate = True
def is_member(self):
return True
# def __eq__(self, other):
# if not isinstance(other, ProcessGroup):
# return False
......@@ -158,5 +167,5 @@ class ProcessGroup:
# Note that Process group 0 is reserved for representing all ranks.
# At the beginning, group 0 is empty and new ranks will be added automatically.
_g_process_group_map = {}
_g_process_group_map = OrderedDict()
_g_process_group_map[0] = ProcessGroup(0, [])
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册