Unverified commit 7dc7fc4b, authored by caozhou, committed by GitHub

[Auto Parallel] Add comm init control by socket (#44148)

* add comm init control by socket

* avoid single card instance failure
Parent 42468de1
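Before the diff itself, a rough sketch of the rendezvous idea this change introduces may help: for every two-rank process group, the rank at index 0 acts as sender and the other as receiver; the sender connects to the receiver's handshake socket (the trainer endpoint port shifted by a fixed offset) and the two sides exchange rank ids, so neither side starts communicator instantiation before its peer has reached the same group. The helper names, ports, and error handling below are made up for illustration and are not Paddle APIs; they only mirror the pattern used in the diff.

import socket

PORT_OFFSET = 5000  # plays the same role as the commit's magic_num


def sender_handshake(my_rank, peer_rank, peer_ip, peer_port):
    # Connect to the peer's handshake listener and exchange rank ids.
    with socket.create_connection((peer_ip, peer_port + PORT_OFFSET)) as conn:
        conn.send(str(my_rank).encode("utf-8"))
        got = int(conn.recv(1024).decode("utf-8"))
        if got != peer_rank:
            raise ValueError("expected rank {} but got {}".format(peer_rank, got))


def receiver_handshake(my_rank, peer_rank, my_ip, my_port):
    # Accept the peer's connection, read its rank, and reply with our own.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
        server.bind((my_ip, my_port + PORT_OFFSET))
        server.listen(1)
        conn, _ = server.accept()
        got = int(conn.recv(1024).decode("utf-8"))
        if got != peer_rank:
            raise ValueError("expected rank {} but got {}".format(peer_rank, got))
        conn.send(str(my_rank).encode("utf-8"))
        conn.close()

In the actual change below, each rank keeps a single listener shared across all of its two-rank groups and caches accepted connections until the matching group comes up in its own loop.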
@@ -15,6 +15,7 @@
 import copy
 import logging
 from collections import defaultdict
+import socket
 import paddle
 import paddle.utils as utils
@@ -36,7 +37,8 @@ from paddle.distributed import fleet
 from paddle.distributed.utils import get_logger
 from paddle.distributed.passes import new_pass, PassContext
-# from .cluster import Cluster, get_default_cluster
+from ..collective import _get_global_env
+from .cluster import Cluster, get_default_cluster
 from .planner_v2 import Planner
 from .parallelizer_v2 import Parallelizer
 from .dist_op import DistributedOperator
@@ -60,8 +62,8 @@ class Engine:
         self.inputs_spec = self._validate_spec(inputs_spec)
         self.labels_spec = self._validate_spec(labels_spec)
         self.cluster = cluster
-        # if self.cluster is None:
-        #     self.cluster = get_default_cluster()
+        if self.cluster is None:
+            self.cluster = get_default_cluster()
         self.strategy = strategy
         if self.strategy is None:
             self.strategy = fleet.DistributedStrategy()
@@ -314,10 +316,66 @@ class Engine:
         # Traverse different rank programs and traverse each op of them,
         # instantiate communication by process_mapping.
         all_process_groups = get_all_process_groups()
+        has_recv_by_socket = []
+        # This is a magic number and the rank number for training is usually less than 5000
+        magic_num = 5000
+        genv = _get_global_env()
+        cur_rank_ip, cur_rank_port = genv.current_endpoint.split(":")
+        cur_rank_recv_port = int(cur_rank_port) + magic_num
+        server_socket = None
+        # Large enough for recv rank
+        buff_size = 1024
+        server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        server_socket.bind((cur_rank_ip, cur_rank_recv_port))
+        # The 10 is an empirical value
+        server_socket.listen(10)
+        client_sockets = {}
         for process_group in all_process_groups:
             if self._cur_rank not in process_group.ranks:
                 continue
+            if len(process_group.ranks) == 2:
+                index = process_group.ranks.index(self._cur_rank)
+                is_send = True if index == 0 else False
+                if is_send:
+                    recv_rank = process_group.ranks[1]
+                    recv_rank_ip, recv_rank_port = genv.trainer_endpoints[
+                        recv_rank].split(":")
+                    connect_port = int(recv_rank_port) + magic_num
+                    client_socket = socket.socket(socket.AF_INET,
+                                                  socket.SOCK_STREAM)
+                    client_socket.connect((recv_rank_ip, connect_port))
+                    client_socket.send(str(self._cur_rank).encode('utf-8'))
+                    rank = client_socket.recv(buff_size).decode('utf-8')
+                    rank = int(rank)
+                    if rank != recv_rank:
+                        raise ValueError(
+                            "Please check comm pair, the recv rank should be {} but got {}."
+                            .format(recv_rank, rank))
+                    else:
+                        print("It is able to instantiate {} as sender now.".
+                              format(process_group.ranks))
+                    client_socket.close()
+                else:
+                    send_rank = process_group.ranks[0]
+                    while True:
+                        if send_rank not in has_recv_by_socket:
+                            client_socket, recv_addr = server_socket.accept()
+                            rank = int(
+                                client_socket.recv(buff_size).decode())
+                            client_sockets[rank] = client_socket
+                            has_recv_by_socket.append(rank)
+                        else:
+                            client_sockets[send_rank].send(
+                                str(self._cur_rank).encode("utf-8"))
+                            client_sockets[send_rank].close()
+                            print(
+                                "It is able to instantiate {} as recver now."
+                                .format(process_group.ranks))
+                            break
             process_group.instantiate()
+        server_socket.close()
         self._place = _get_device()
         if isinstance(self._place, fluid.CUDAPlace):
...
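For reference, each rank's handshake listener is bound to its own trainer endpoint address with the port shifted by magic_num, and a sender derives its peer's connect port the same way. A tiny, hypothetical helper (not part of the commit) showing that derivation with a made-up endpoint:

def handshake_address(endpoint, magic_num=5000):
    # Split an "ip:port" trainer endpoint and shift the port by the same
    # offset used in the hunk above.
    ip, port = endpoint.split(":")
    return ip, int(port) + magic_num

assert handshake_address("127.0.0.1:6070") == ("127.0.0.1", 11070)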
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 
+from collections import OrderedDict
+
 import paddle
 import paddle.fluid.core as core
 from ..collective import _get_global_env
@@ -130,16 +132,23 @@ class ProcessGroup:
             else:
                 assert False, ("No CUDA device found")
 
         # TODO(shenliang03): This is a temporary solution to solve the problem of
         # hang caused by cross-creation of new_group
-        tmp = paddle.to_tensor(
-            [1], dtype="int32") if _non_static_mode() else fill_constant(
-                [0], dtype="int32", value="1")
-        paddle.distributed.all_reduce(tmp, use_calc_stream=True)
-        paddle.distributed.wait(tmp)
+        paddle.framework._in_legacy_dygraph()
+        paddle.set_device('gpu:%d' %
+                          paddle.distributed.ParallelEnv().dev_id)
+        tmp = paddle.to_tensor(
+            [1], dtype="int32") if _non_static_mode() else fill_constant(
+                [0], dtype="int32", value="1")
+        paddle.distributed.all_reduce(tmp, use_calc_stream=True, group=self)
+        paddle.distributed.wait(tmp, group=self)
+        paddle.enable_static()
 
         self._is_instantiate = True
 
+    def is_member(self):
+        return True
+
     # def __eq__(self, other):
     #     if not isinstance(other, ProcessGroup):
     #         return False
@@ -158,5 +167,5 @@ class ProcessGroup:
 
 # Note that Process group 0 is reserved for representing all ranks.
 # At the beginning, group 0 is empty and new ranks will be added automatically.
-_g_process_group_map = {}
+_g_process_group_map = OrderedDict()
 _g_process_group_map[0] = ProcessGroup(0, [])
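A plausible reason for switching _g_process_group_map from a plain dict to an OrderedDict (the commit does not state one) is to make the iteration order over registered process groups explicit, so every rank instantiates groups in the same insertion order. A trivial illustration with made-up group ids:

from collections import OrderedDict

example_map = OrderedDict()
for group_id in (0, 3, 1, 7):
    # Insertion order is preserved, so ranks that register groups in the
    # same order also iterate them in the same order.
    example_map[group_id] = "group-{}".format(group_id)

assert list(example_map) == [0, 3, 1, 7]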