未验证 提交 56fd25b8 编写于 作者: L Leo Chen 提交者: GitHub

eager call all2all to avoid p2p hang in lazy init (#54431)

* eager call all2all to avoid p2p hang in lazy init

* update
上级 2a2af7d7
...@@ -55,7 +55,9 @@ def remove_process_group(ring_id): ...@@ -55,7 +55,9 @@ def remove_process_group(ring_id):
_g_process_group_map.pop(ring_id) _g_process_group_map.pop(ring_id)
def new_process_group(ranks, group_id=None, force_new_group=False): def new_process_group(
ranks, group_id=None, force_new_group=False, group_type=None
):
global _g_process_group_map global _g_process_group_map
if not force_new_group: if not force_new_group:
...@@ -72,8 +74,9 @@ def new_process_group(ranks, group_id=None, force_new_group=False): ...@@ -72,8 +74,9 @@ def new_process_group(ranks, group_id=None, force_new_group=False):
if group_id is None: if group_id is None:
group_id = _new_ring_id() + num_groups + 1 group_id = _new_ring_id() + num_groups + 1
new_pg = ProcessGroup(group_id, ranks) new_pg = ProcessGroup(group_id, ranks, group_type)
_g_process_group_map[group_id] = new_pg _g_process_group_map[group_id] = new_pg
return new_pg return new_pg
...@@ -84,7 +87,7 @@ def new_process_group(ranks, group_id=None, force_new_group=False): ...@@ -84,7 +87,7 @@ def new_process_group(ranks, group_id=None, force_new_group=False):
# the instantiation process in a more general way. In the future, the process group may # the instantiation process in a more general way. In the future, the process group may
# handle the communication implementation choice. # handle the communication implementation choice.
class ProcessGroup: class ProcessGroup:
def __init__(self, group_id, ranks): def __init__(self, group_id, ranks, group_type=None):
if group_id == 0 and get_process_group(0) is not None: if group_id == 0 and get_process_group(0) is not None:
assert ( assert (
group_id != 0 group_id != 0
...@@ -96,6 +99,7 @@ class ProcessGroup: ...@@ -96,6 +99,7 @@ class ProcessGroup:
global _g_process_group_map global _g_process_group_map
_g_process_group_map[0].add_ranks(ranks) _g_process_group_map[0].add_ranks(ranks)
self._is_instantiate = False self._is_instantiate = False
self._group_type = group_type
@property @property
def id(self): def id(self):
...@@ -109,6 +113,10 @@ class ProcessGroup: ...@@ -109,6 +113,10 @@ class ProcessGroup:
def nranks(self): def nranks(self):
return len(self._ranks) return len(self._ranks)
@property
def group_type(self):
return self._group_type
def add_ranks(self, new_ranks): def add_ranks(self, new_ranks):
if set(new_ranks) <= set(self.ranks): if set(new_ranks) <= set(self.ranks):
return return
...@@ -192,6 +200,16 @@ class ProcessGroup: ...@@ -192,6 +200,16 @@ class ProcessGroup:
barrier_tensor, barrier_tensor, 'ring_id', ring_id barrier_tensor, barrier_tensor, 'ring_id', ring_id
) )
# NOTE(zhiqiu): to avoid send/recv hang in lazy init
if self._group_type == 'p2p':
alltoall_tmp = paddle.empty(
shape=[self.nranks, self.nranks], dtype="int32"
)
paddle._legacy_C_ops.alltoall(
alltoall_tmp, 'use_calc_stream', True, 'ring_id', ring_id
)
paddle.device.cuda.synchronize()
if self.nranks > 1: if self.nranks > 1:
barrier_tensor = paddle.full([1], 1, dtype="int32") barrier_tensor = paddle.full([1], 1, dtype="int32")
paddle._legacy_C_ops.barrier( paddle._legacy_C_ops.barrier(
......
...@@ -337,7 +337,7 @@ class Inserter: ...@@ -337,7 +337,7 @@ class Inserter:
"""Insert send op into block at the given index.""" """Insert send op into block at the given index."""
op_type = 'send_v2' op_type = 'send_v2'
# use pair comm group # use pair comm group
process_group = new_process_group([src, dst]) process_group = new_process_group([src, dst], group_type='p2p')
send_op = block._insert_op( send_op = block._insert_op(
idx, idx,
type=op_type, type=op_type,
...@@ -357,7 +357,7 @@ class Inserter: ...@@ -357,7 +357,7 @@ class Inserter:
"""Insert recv op into block at the given index.""" """Insert recv op into block at the given index."""
op_type = 'recv_v2' op_type = 'recv_v2'
# use pair group # use pair group
process_group = new_process_group([src, dst]) process_group = new_process_group([src, dst], group_type='p2p')
recv_op = block._insert_op( recv_op = block._insert_op(
idx, idx,
type=op_type, type=op_type,
...@@ -1794,9 +1794,13 @@ class Resharder: ...@@ -1794,9 +1794,13 @@ class Resharder:
elif isinstance(op_desc, AllGatherConcatOpDesc): elif isinstance(op_desc, AllGatherConcatOpDesc):
new_process_group(op_desc.group) new_process_group(op_desc.group)
elif isinstance(op_desc, SendOpDesc): elif isinstance(op_desc, SendOpDesc):
new_process_group([op_desc.src, op_desc.dst]) new_process_group(
[op_desc.src, op_desc.dst], group_type='p2p'
)
elif isinstance(op_desc, RecvOpDesc): elif isinstance(op_desc, RecvOpDesc):
new_process_group([op_desc.src, op_desc.dst]) new_process_group(
[op_desc.src, op_desc.dst], group_type='p2p'
)
tensor_list = [] tensor_list = []
partition_tensor_list = [] partition_tensor_list = []
...@@ -2721,7 +2725,10 @@ class Resharder: ...@@ -2721,7 +2725,10 @@ class Resharder:
# Ensure every rank has a global view of communicator groups for entire cluters. # Ensure every rank has a global view of communicator groups for entire cluters.
# When initialize communicators for pipeline parallel, every rank could # When initialize communicators for pipeline parallel, every rank could
# conduct a correct global synchronization. # conduct a correct global synchronization.
new_process_group([item, recv_rank]) new_process_group(
[item, recv_rank],
group_type='p2p',
)
else: else:
for index, tensor_process in enumerate( for index, tensor_process in enumerate(
tensor_processes tensor_processes
...@@ -2748,7 +2755,9 @@ class Resharder: ...@@ -2748,7 +2755,9 @@ class Resharder:
# Ensure every rank has a global view of communicator groups for entire cluters. # Ensure every rank has a global view of communicator groups for entire cluters.
# When initialize communicators for pipeline parallel, every rank could # When initialize communicators for pipeline parallel, every rank could
# conduct a correct global synchronization. # conduct a correct global synchronization.
new_process_group([item, recv_rank]) new_process_group(
[item, recv_rank], group_type='p2p'
)
cur_op_count = len(block.ops) cur_op_count = len(block.ops)
idx_offset = ( idx_offset = (
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册