Unverified commit 56fd25b8, authored by Leo Chen, committed by GitHub

eager call all2all to avoid p2p hang in lazy init (#54431)

* eager call all2all to avoid p2p hang in lazy init

* update
Parent 2a2af7d7
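
In short: every pairwise send/recv process group is now created with group_type='p2p', and the group's instantiation uses that tag to run an eager all-to-all so both ranks build the NCCL communicator before any one-sided send/recv is issued. A minimal usage sketch based on the hunks below (the ranks are made up, and instantiate() is an assumption about the enclosing method, which this diff does not name):

    # Sketch only -- ranks are illustrative; new_process_group/group_type come from the hunks below.
    pg = new_process_group([0, 1], group_type='p2p')  # pair group used by send_v2/recv_v2
    assert pg.group_type == 'p2p'
    pg.instantiate()  # assumed method name; for 'p2p' groups it now also issues a
                      # dummy alltoall plus a CUDA synchronize before the final barrier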
@@ -55,7 +55,9 @@ def remove_process_group(ring_id):
         _g_process_group_map.pop(ring_id)
-def new_process_group(ranks, group_id=None, force_new_group=False):
+def new_process_group(
+    ranks, group_id=None, force_new_group=False, group_type=None
+):
     global _g_process_group_map
     if not force_new_group:
@@ -72,8 +74,9 @@ def new_process_group(ranks, group_id=None, force_new_group=False):
     if group_id is None:
         group_id = _new_ring_id() + num_groups + 1
-    new_pg = ProcessGroup(group_id, ranks)
+    new_pg = ProcessGroup(group_id, ranks, group_type)
     _g_process_group_map[group_id] = new_pg
     return new_pg
@@ -84,7 +87,7 @@ def new_process_group(ranks, group_id=None, force_new_group=False):
 # the instantiation process in a more general way. In the future, the process group may
 # handle the communication implementation choice.
 class ProcessGroup:
-    def __init__(self, group_id, ranks):
+    def __init__(self, group_id, ranks, group_type=None):
         if group_id == 0 and get_process_group(0) is not None:
             assert (
                 group_id != 0
@@ -96,6 +99,7 @@ class ProcessGroup:
             global _g_process_group_map
             _g_process_group_map[0].add_ranks(ranks)
         self._is_instantiate = False
+        self._group_type = group_type
     @property
     def id(self):
@@ -109,6 +113,10 @@ class ProcessGroup:
     def nranks(self):
         return len(self._ranks)
+    @property
+    def group_type(self):
+        return self._group_type
     def add_ranks(self, new_ranks):
         if set(new_ranks) <= set(self.ranks):
             return
@@ -192,6 +200,16 @@ class ProcessGroup:
                 barrier_tensor, barrier_tensor, 'ring_id', ring_id
             )
+            # NOTE(zhiqiu): to avoid send/recv hang in lazy init
+            if self._group_type == 'p2p':
+                alltoall_tmp = paddle.empty(
+                    shape=[self.nranks, self.nranks], dtype="int32"
+                )
+                paddle._legacy_C_ops.alltoall(
+                    alltoall_tmp, 'use_calc_stream', True, 'ring_id', ring_id
+                )
+                paddle.device.cuda.synchronize()
         if self.nranks > 1:
             barrier_tensor = paddle.full([1], 1, dtype="int32")
             paddle._legacy_C_ops.barrier(
......
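
The NOTE(zhiqiu) hunk above is the core of the fix: unlike send/recv, alltoall is a symmetric collective that both ranks of the pair enter at instantiation time, so the communicator is built eagerly rather than on the first one-sided send_v2/recv_v2, which is where the lazy-init hang occurred. A consolidated sketch of the patched instantiation path for a 'p2p' group, stitched together from the hunks above; the method name and the elided NCCL setup are assumptions, not part of this diff:

    # Sketch reconstructed from the hunks above; not a verbatim copy of the file.
    def instantiate(self):
        ring_id = self.id
        # ... create the NCCL communicator for self.ranks (elided in this diff) ...

        # Existing warm-up: a barrier on the new ring.
        barrier_tensor = paddle.full([1], 1, dtype="int32")
        paddle._legacy_C_ops.barrier(
            barrier_tensor, barrier_tensor, 'ring_id', ring_id
        )

        # New warm-up for send/recv pair groups: a dummy [nranks, nranks] buffer is
        # exchanged via alltoall on the calc stream, then the device is synchronized,
        # so the communicator exists before any real send/recv uses it.
        if self._group_type == 'p2p':
            alltoall_tmp = paddle.empty(
                shape=[self.nranks, self.nranks], dtype="int32"
            )
            paddle._legacy_C_ops.alltoall(
                alltoall_tmp, 'use_calc_stream', True, 'ring_id', ring_id
            )
            paddle.device.cuda.synchronize()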
@@ -337,7 +337,7 @@ class Inserter:
         """Insert send op into block at the given index."""
         op_type = 'send_v2'
         # use pair comm group
-        process_group = new_process_group([src, dst])
+        process_group = new_process_group([src, dst], group_type='p2p')
         send_op = block._insert_op(
             idx,
             type=op_type,
@@ -357,7 +357,7 @@ class Inserter:
         """Insert recv op into block at the given index."""
         op_type = 'recv_v2'
         # use pair group
-        process_group = new_process_group([src, dst])
+        process_group = new_process_group([src, dst], group_type='p2p')
         recv_op = block._insert_op(
             idx,
             type=op_type,
@@ -1794,9 +1794,13 @@ class Resharder:
                 elif isinstance(op_desc, AllGatherConcatOpDesc):
                     new_process_group(op_desc.group)
                 elif isinstance(op_desc, SendOpDesc):
-                    new_process_group([op_desc.src, op_desc.dst])
+                    new_process_group(
+                        [op_desc.src, op_desc.dst], group_type='p2p'
+                    )
                 elif isinstance(op_desc, RecvOpDesc):
-                    new_process_group([op_desc.src, op_desc.dst])
+                    new_process_group(
+                        [op_desc.src, op_desc.dst], group_type='p2p'
+                    )
         tensor_list = []
         partition_tensor_list = []
@@ -2721,7 +2725,10 @@ class Resharder:
                             # Ensure every rank has a global view of communicator groups for entire cluters.
                             # When initialize communicators for pipeline parallel, every rank could
                             # conduct a correct global synchronization.
-                            new_process_group([item, recv_rank])
+                            new_process_group(
+                                [item, recv_rank],
+                                group_type='p2p',
+                            )
                         else:
                             for index, tensor_process in enumerate(
                                 tensor_processes
@@ -2748,7 +2755,9 @@ class Resharder:
                             # Ensure every rank has a global view of communicator groups for entire cluters.
                             # When initialize communicators for pipeline parallel, every rank could
                             # conduct a correct global synchronization.
-                            new_process_group([item, recv_rank])
+                            new_process_group(
+                                [item, recv_rank], group_type='p2p'
+                            )
                     cur_op_count = len(block.ops)
                     idx_offset = (
......