Unverified commit 3cbf0e93, authored by Yuang Liu, committed by GitHub

[dygraph pp] all sync for allgather partial (#46483)

Parent: cee2b12d
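The diff below touches the partial point-to-point helpers used by dygraph pipeline parallelism. For orientation only: when a tensor's size is divisible by the model-parallel degree, each stage sends and receives just its own 1/nranks slice, and a partial all-gather over the model-parallel group rebuilds the full tensor afterwards. The sketch below illustrates that slicing and reassembly with plain numpy; split_chunk and reassemble are hypothetical names, not Paddle APIs.

import numpy as np


def split_chunk(tensor, nranks, rank_id):
    # The 1/nranks slice that rank `rank_id` would send or receive in a
    # partial send/recv.
    return tensor.reshape(nranks, -1)[rank_id]


def reassemble(chunks, shape):
    # What a partial all-gather achieves: every rank contributes its slice
    # and ends up holding the full tensor again.
    return np.concatenate(chunks).reshape(shape)


full = np.arange(8.0).reshape(2, 4)   # a pretend activation, mp_degree == 2
slices = [split_chunk(full, 2, r) for r in range(2)]
assert np.array_equal(reassemble(slices, full.shape), full)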
@@ -165,17 +165,15 @@ def _is_valid_send_recv_partial(tensor, mp_degree):
 
 def _partial_send_op(tensor, group, use_calc_stream, ring_id, dst, nranks,
                      rank_id):
-    dst_rank_in_group = dst if group is None else group.get_group_rank(dst)
     if _in_legacy_dygraph():
         return _legacy_C_ops.partial_send(tensor.detach(), 'use_calc_stream',
                                           use_calc_stream, 'ring_id', ring_id,
-                                          'peer', dst_rank_in_group, 'num',
-                                          nranks, 'id', rank_id)
+                                          'peer', dst, 'num', nranks, 'id',
+                                          rank_id)
     elif in_dygraph_mode():
         group = paddle.distributed.collective._get_default_group(
         ) if group is None else group
-        return group.process_group.send_partial(tensor, dst_rank_in_group,
-                                                nranks, rank_id)
+        return group.process_group.send_partial(tensor, dst, nranks, rank_id)
 
 
 def send_partial(tensor,
@@ -189,13 +187,12 @@ def send_partial(tensor,
         return
     ring_id = 0 if group is None else group.id
 
-    dst_rank = _hcg._get_p2p_next_rank(
-    ) if dst == 1 else _hcg._get_p2p_prev_rank()
-
     if _is_valid_send_recv_partial(tensor, nranks):
-        return _partial_send_op(tensor, group, use_calc_stream, ring_id,
-                                dst_rank, nranks, rank_id)
+        return _partial_send_op(tensor, group, use_calc_stream, ring_id, dst,
+                                nranks, rank_id)
     else:
+        dst_rank = _hcg._get_p2p_next_rank(
+        ) if dst == 1 else _hcg._get_p2p_prev_rank()
         if _in_legacy_dygraph():
             send_op = paddle.distributed.send
         elif in_dygraph_mode():
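The two hunks above move rank resolution out of the partial path: _partial_send_op now forwards dst unchanged instead of mapping it with group.get_group_rank, and send_partial only resolves the _hcg p2p neighbour rank inside the dense-send fallback. For intuition, dst/src is a direction flag (1 = next stage, 0 = previous stage); the partial path hands that flag to the p2p process group as-is, while the fallback needs a global rank. The helper below is a hypothetical stand-in for what _hcg._get_p2p_next_rank() / _get_p2p_prev_rank() provide, shown for a made-up 4-stage pipeline:

def peer_global_rank(direction, my_stage, num_stages):
    # Map the direction flag to a neighbouring stage's global rank
    # (hypothetical illustration, not the Paddle implementation).
    step = 1 if direction == 1 else -1
    return (my_stage + step) % num_stages


assert peer_global_rank(1, my_stage=2, num_stages=4) == 3  # send to next stage
assert peer_global_rank(0, my_stage=2, num_stages=4) == 1  # send to previous stage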
@@ -205,23 +202,22 @@ def send_partial(tensor,
 
 def _partial_recv_op(tensor, group, use_calc_stream, ring_id, src, nranks,
                      rank_id):
-    src_rank_in_group = src if group is None else group.get_group_rank(src)
     if _in_legacy_dygraph():
         assert use_calc_stream
         return _legacy_C_ops.partial_recv(tensor.detach(), 'use_calc_stream',
                                           use_calc_stream, 'ring_id', ring_id,
-                                          'peer', src_rank_in_group, 'num',
-                                          nranks, 'id', rank_id, 'dtype',
-                                          tensor.dtype, 'out_shape',
-                                          tensor.shape)
+                                          'peer', src, 'num', nranks, 'id',
+                                          rank_id, 'dtype', tensor.dtype,
+                                          'out_shape', tensor.shape)
     elif in_dygraph_mode():
         group = paddle.distributed.collective._get_default_group(
         ) if group is None else group
-        task = group.process_group.recv_partial(tensor, src_rank_in_group,
-                                                nranks, rank_id)
+        task = group.process_group.recv_partial(tensor, src, nranks, rank_id)
         if use_calc_stream:
             task.wait()
-        return task
+            return None
+        else:
+            return task
 
 
 def recv_partial(tensor,
@@ -235,13 +231,12 @@ def recv_partial(tensor,
         return
     ring_id = 0 if group is None else group.id
 
-    src_rank = _hcg._get_p2p_prev_rank(
-    ) if src == 0 else _hcg._get_p2p_next_rank()
-
     if _is_valid_send_recv_partial(tensor, nranks):
-        return _partial_recv_op(tensor, group, use_calc_stream, ring_id,
-                                src_rank, nranks, rank_id)
+        return _partial_recv_op(tensor, group, use_calc_stream, ring_id, src,
+                                nranks, rank_id)
     else:
+        src_rank = _hcg._get_p2p_prev_rank(
+        ) if src == 0 else _hcg._get_p2p_next_rank()
         if _in_legacy_dygraph() or use_calc_stream:
             recv_op = paddle.distributed.recv
         elif in_dygraph_mode():
@@ -260,8 +255,13 @@ def _partial_allgather_op(tensor, group, use_calc_stream, ring_id, nranks,
     elif in_dygraph_mode():
         group = paddle.distributed.collective._get_default_group(
         ) if group is None else group
-        return group.process_group.all_gather_partial(tensor, tensor, nranks,
-                                                      rank_id)
+        task = group.process_group.all_gather_partial(tensor, tensor, nranks,
+                                                      rank_id)
+        if use_calc_stream:
+            task.wait()
+            return None
+        else:
+            return task
 
 
 def allgather_partial(tensor,
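With this hunk, _partial_recv_op and _partial_allgather_op follow the same convention in eager mode: launch the collective, and if it runs on the calculation stream (use_calc_stream), wait for it and return None; otherwise return the task so the caller can wait later. A minimal stand-alone sketch of that pattern, with a thread-pool future standing in for the communication task (partial_op and _pool are illustrative, not Paddle APIs):

from concurrent.futures import ThreadPoolExecutor

_pool = ThreadPoolExecutor(max_workers=1)


def partial_op(payload, use_calc_stream):
    task = _pool.submit(sum, payload)  # pretend asynchronous collective
    if use_calc_stream:
        task.result()                  # analogous to task.wait()
        return None
    else:
        return task                    # caller is responsible for waiting


assert partial_op([1, 2, 3], use_calc_stream=True) is None
deferred = partial_op([1, 2, 3], use_calc_stream=False)
assert deferred.result() == 6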
@@ -270,9 +270,9 @@ def allgather_partial(tensor,
                       group=None,
                       use_calc_stream=True):
     if not _is_valid_send_recv_partial(tensor, nranks):
-        return None
+        return tensor
     if group is not None and not group.is_member():
-        return None
+        return
     ring_id = 0 if group is None else group.id
 
     return _partial_allgather_op(tensor, group, use_calc_stream, ring_id,
@@ -335,8 +335,7 @@ def _p2p_helper(tensor_send_next,
     if tensor_send_prev is not None:
         if isinstance(tensor_send_prev, tuple):
             for d in tensor_send_prev:
-                if _in_legacy_dygraph():
-                    paddle.distributed.wait(d, use_calc_stream=True)
+                paddle.distributed.wait(d, use_calc_stream=True)
                 send_partial(d,
                              dst=0,
                              nranks=mp_degree,
@@ -344,8 +343,7 @@ def _p2p_helper(tensor_send_next,
                              group=_hcg.send_prev_group,
                              use_calc_stream=False)
         else:
-            if _in_legacy_dygraph():
-                paddle.distributed.wait(tensor_send_prev, use_calc_stream=True)
+            paddle.distributed.wait(tensor_send_prev, use_calc_stream=True)
             send_partial(tensor_send_prev,
                          dst=0,
                          nranks=mp_degree,
@@ -356,27 +354,40 @@ def _p2p_helper(tensor_send_next,
     if tensor_recv_prev is not None:
         if isinstance(tensor_recv_prev, tuple):
             for d in tensor_recv_prev:
-                tasks.append(
-                    recv_partial(d,
-                                 src=0,
-                                 nranks=mp_degree,
-                                 rank_id=mp_rank,
-                                 group=_hcg.recv_prev_group,
-                                 use_calc_stream=sync_recv))
+                task = recv_partial(d,
+                                    src=0,
+                                    nranks=mp_degree,
+                                    rank_id=mp_rank,
+                                    group=_hcg.recv_prev_group,
+                                    use_calc_stream=sync_recv)
+                if sync_recv:
+                    allgather_partial(d,
+                                      nranks=mp_degree,
+                                      rank_id=mp_rank,
+                                      group=mp_group,
+                                      use_calc_stream=True)
+                else:
+                    tasks.append(task)
         else:
-            tasks.append(
-                recv_partial(tensor_recv_prev,
-                             src=0,
-                             nranks=mp_degree,
-                             rank_id=mp_rank,
-                             group=_hcg.recv_prev_group,
-                             use_calc_stream=sync_recv))
+            task = recv_partial(tensor_recv_prev,
+                                src=0,
+                                nranks=mp_degree,
+                                rank_id=mp_rank,
+                                group=_hcg.recv_prev_group,
+                                use_calc_stream=sync_recv)
+            if sync_recv:
+                allgather_partial(tensor_recv_prev,
+                                  nranks=mp_degree,
+                                  rank_id=mp_rank,
+                                  group=mp_group,
+                                  use_calc_stream=True)
+            else:
+                tasks.append(task)
 
     if tensor_send_next is not None:
         if isinstance(tensor_send_next, tuple):
             for d in tensor_send_next:
-                if _in_legacy_dygraph():
-                    paddle.distributed.wait(d, use_calc_stream=True)
+                paddle.distributed.wait(d, use_calc_stream=True)
                 send_partial(d,
                              dst=1,
                              nranks=mp_degree,
@@ -384,8 +395,7 @@ def _p2p_helper(tensor_send_next,
                              group=_hcg.send_next_group,
                              use_calc_stream=False)
         else:
-            if _in_legacy_dygraph():
-                paddle.distributed.wait(tensor_send_next, use_calc_stream=True)
+            paddle.distributed.wait(tensor_send_next, use_calc_stream=True)
             send_partial(tensor_send_next,
                          dst=1,
                          nranks=mp_degree,
@@ -396,57 +406,64 @@ def _p2p_helper(tensor_send_next,
     if tensor_recv_next is not None:
         if isinstance(tensor_recv_next, tuple):
             for d in tensor_recv_next:
-                tasks.append(
-                    recv_partial(d,
-                                 src=1,
-                                 nranks=mp_degree,
-                                 rank_id=mp_rank,
-                                 group=_hcg.recv_next_group,
-                                 use_calc_stream=sync_recv))
+                task = recv_partial(d,
+                                    src=1,
+                                    nranks=mp_degree,
+                                    rank_id=mp_rank,
+                                    group=_hcg.recv_next_group,
+                                    use_calc_stream=sync_recv)
+                if sync_recv:
+                    allgather_partial(d,
+                                      nranks=mp_degree,
+                                      rank_id=mp_rank,
+                                      group=mp_group,
+                                      use_calc_stream=True)
+                else:
+                    tasks.append(task)
         else:
-            tasks.append(
-                recv_partial(tensor_recv_next,
-                             src=1,
-                             nranks=mp_degree,
-                             rank_id=mp_rank,
-                             group=_hcg.recv_next_group,
-                             use_calc_stream=sync_recv))
+            task = recv_partial(tensor_recv_next,
+                                src=1,
+                                nranks=mp_degree,
+                                rank_id=mp_rank,
+                                group=_hcg.recv_next_group,
+                                use_calc_stream=sync_recv)
+            if sync_recv:
+                allgather_partial(tensor_recv_next,
+                                  nranks=mp_degree,
+                                  rank_id=mp_rank,
+                                  group=mp_group,
+                                  use_calc_stream=True)
+            else:
+                tasks.append(task)
 
-    if not sync_recv and in_dygraph_mode():
-        # wait irecv tasks in eager dygraph mode with new comm library
-        for task in tasks:
-            assert task is not None
-            task.wait()
-
-    tensors_for_all_gather = []
-    if tensor_recv_prev is not None:
-        if isinstance(tensor_recv_prev, tuple):
-            for d in tensor_recv_prev:
-                tensors_for_all_gather.append(d)
-        else:
-            tensors_for_all_gather.append(tensor_recv_prev)
-    if tensor_recv_next is not None:
-        if isinstance(tensor_recv_next, tuple):
-            for d in tensor_recv_next:
-                tensors_for_all_gather.append(d)
-        else:
-            tensors_for_all_gather.append(tensor_recv_next)
-
-    tasks = []
-    for tensor in tensors_for_all_gather:
-        tasks.append(
-            allgather_partial(tensor,
-                              nranks=mp_degree,
-                              rank_id=mp_rank,
-                              group=mp_group,
-                              use_calc_stream=True))
-
-    if in_dygraph_mode():
-        for task in tasks:
-            # wait partial all gather tasks
-            if task is not None:
-                task.wait()
+    if not sync_recv:
+        if in_dygraph_mode():
+            # wait irecv tasks in eager dygraph mode with new comm library
+            for task in tasks:
+                assert task is not None
+                task.wait()
+
+        tensors_for_all_gather = []
+        if tensor_recv_prev is not None:
+            if isinstance(tensor_recv_prev, tuple):
+                for d in tensor_recv_prev:
+                    tensors_for_all_gather.append(d)
+            else:
+                tensors_for_all_gather.append(tensor_recv_prev)
+        if tensor_recv_next is not None:
+            if isinstance(tensor_recv_next, tuple):
+                for d in tensor_recv_next:
+                    tensors_for_all_gather.append(d)
+            else:
+                tensors_for_all_gather.append(tensor_recv_next)
+
+        for tensor in tensors_for_all_gather:
+            allgather_partial(tensor,
+                              nranks=mp_degree,
+                              rank_id=mp_rank,
+                              group=mp_group,
+                              use_calc_stream=True)
 
     return tensor_recv_prev, tensor_recv_next
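Taken together, the _p2p_helper hunks give the receive path two clearly separated modes: with sync_recv the partial recv blocks and the partial all-gather is issued right away on the calculation stream, while the asynchronous mode still collects the recv tasks, waits on them once, and only then runs the all-gathers. A condensed, framework-free sketch of that control flow (receive_all, recv_fn and allgather_fn are hypothetical stand-ins for the helpers above):

class _DoneTask:
    # Stand-in for a completed communication task.
    def wait(self):
        pass


def receive_all(tensors, sync_recv, recv_fn, allgather_fn):
    tasks = []
    for t in tensors:
        task = recv_fn(t, blocking=sync_recv)
        if sync_recv:
            allgather_fn(t)      # all sync: gather right after each recv
        else:
            tasks.append(task)   # async recv: remember the task

    if not sync_recv:
        for task in tasks:       # wait for every irecv first
            task.wait()
        for t in tensors:        # then rebuild the full tensors
            allgather_fn(t)
    return tensors


received = receive_all(["h_prev", "h_next"], sync_recv=False,
                       recv_fn=lambda t, blocking: _DoneTask(),
                       allgather_fn=lambda t: None)
assert received == ["h_prev", "h_next"]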