Unverified · Commit 98bf4e95, authored by ShenLiang, committed by GitHub

[Distributed] Support p2p on calc stream for pipelineparallel (#54706)

* add p2p calc stream

* rm code

* rm code

* rm assert

* rm code
Parent ba75fbec
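In short, the PR moves pipeline-parallel p2p traffic from asynchronous isend/irecv batched on a separate communication stream to send/recv calls issued directly on the calculation stream, which removes the per-request waits and the device synchronize. A minimal before/after sketch, not the PR's literal code; names are taken from the diff below, and `p2p_ops` stands for the per-direction operations collected in `_p2p_helper` (the old list holds `paddle.distributed.P2POp` objects, the new one holds `P2PonCalcStream` objects):

import paddle

def exchange_before(p2p_ops):
    # Old path: async ops on a separate comm stream; the CPU must wait on
    # every returned task and then synchronize the device.
    reqs = paddle.distributed.batch_isend_irecv(p2p_ops)
    for req in reqs:
        req.wait()
    paddle.device.synchronize()

def exchange_after(p2p_ops):
    # New path: send/recv run on the calculation stream itself, so later
    # kernels on that stream are already ordered after them and no explicit
    # wait is needed.
    batch_send_recv_on_calc_stream(p2p_ops)  # helper defined in this diff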
@@ -163,10 +163,6 @@ class PipelineParallel(MetaParallelBase):
             "pp_configs"
         ].enable_timer
-        assert (
-            not self._dp_comm_overlap and not self._sharding_comm_overlap
-        ), "Comm overlap is not supported now."
         if self._dp_comm_overlap:
             assert self.use_data_parallel and self.num_stages > 1
@@ -269,6 +265,7 @@ class PipelineParallel(MetaParallelBase):
         for model in models:
             # For virtual pipeline. Will separate parameters in different chunk into
             # different groups to get the best performance.
             parameter_list = [
                 p for p in model.parameters() if not p.stop_gradient
             ]
...
@@ -176,30 +176,63 @@ def _is_valid_send_recv_partial(tensor, mp_degree):
     if not _enable_partial_send_recv:
         return False
     tensor_numel = np.prod(tensor.shape)
-    assert tensor_numel != 0, "can't send/recv zero element"
+    assert tensor_numel > 0, "can't send/recv zero element"
     return mp_degree > 1 and tensor_numel % mp_degree == 0


-def _partial_send_op(tensor, group, dst, nranks, rank_id):
+def _send_on_calc_stream(tensor, group, dst, nranks=1, rank_id=0):
     assert (
         group is not None
-    ), "Group should be an instance for _partial_send_op."
+    ), "Group should be an instance for _send_on_calc_stream."
     dst_rank_in_group = group.get_group_rank(dst)
-    if framework.in_dygraph_mode():
-        return group.process_group.send_partial(
+    if _is_valid_send_recv_partial(tensor, nranks):
+        return group.process_group.send_partial_on_calc_stream(
             tensor, dst_rank_in_group, nranks, rank_id
         )
+    else:
+        return group.process_group.send_on_calc_stream(
+            tensor, dst_rank_in_group
+        )


-def _partial_recv_op(tensor, group, src, nranks, rank_id):
+def _recv_on_calc_stream(tensor, group, src, nranks=1, rank_id=0):
     assert (
         group is not None
-    ), "Group should be an instance for _partial_recv_op."
+    ), "Group should be an instance for _recv_on_calc_stream."
     src_rank_in_group = group.get_group_rank(src)
-    if framework.in_dygraph_mode():
-        return group.process_group.recv_partial(
+    if _is_valid_send_recv_partial(tensor, nranks):
+        return group.process_group.recv_partial_on_calc_stream(
             tensor, src_rank_in_group, nranks, rank_id
         )
+    else:
+        return group.process_group.recv_on_calc_stream(
+            tensor, src_rank_in_group
+        )
+class P2PonCalcStream:
+    def __init__(self, op, tensor, peer, group, nranks=1, rank_id=0):
+        """
+        Args:
+            op (function): The function to be executed on the calc stream.
+            tensor (Tensor): The tensor to be sent or received.
+            peer (int): The peer rank.
+            group (Group): The process group to p2p.
+            nranks (int): The number of ranks in model parallel group.
+            rank_id (int): The rank id in the model parallel group.
+        """
+        if op not in [_send_on_calc_stream, _recv_on_calc_stream]:
+            raise RuntimeError(
+                "Invalid ``op`` function. Expected ``op`` "
+                "to be of type ``_send_on_calc_stream`` or "
+                "``_recv_on_calc_stream``."
+            )
+        self.op = op
+        self.tensor = tensor
+        self.peer = peer
+        self.group = group
+        self.nranks = nranks
+        self.rank_id = rank_id
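A hypothetical construction example for the class above; `act`, `next_rank`, `pipe_group`, `mp_degree`, and `mp_rank` are placeholders for values that `_p2p_helper` has at hand, and the `_process_p2p_tuple_or_tensor` helper added later in this diff builds these objects automatically:

op = P2PonCalcStream(
    _send_on_calc_stream,  # or _recv_on_calc_stream on the receiving side
    act,                   # tensor to transfer
    next_rank,             # peer pipeline rank
    pipe_group,            # pipeline process group
    nranks=mp_degree,      # model-parallel degree; >1 may enable the partial path
    rank_id=mp_rank,       # rank inside the model-parallel group
)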
 def _partial_allgather_op(
...
@@ -232,46 +265,39 @@ def allgather_partial(
     )


-def partial_batch_isend_irecv(p2p_op_list):
+def batch_send_recv_on_calc_stream(p2p_op_list):
     group = p2p_op_list[0].group
     if _warn_cur_rank_not_in_group(group):
         return
-    if framework.in_dygraph_mode():
-        group = _get_global_group() if group is None else group
-        backend = group.backend
-        tasks = []
-        with _with_batch_p2p_guard(backend):
-            for p2p_op in p2p_op_list:
-                op = p2p_op.op
-                tensor = p2p_op.tensor
-                peer = p2p_op.peer
-                comm_group = p2p_op.group
-                nranks = p2p_op.nranks
-                rank_id = p2p_op.rank_id
-                task = op(tensor, comm_group, peer, nranks, rank_id)
-                if task is not None:
-                    tasks.append(task)
-        return tasks
-    else:
-        raise RuntimeError("Don't support static graph mode currently.")
-
-
-class PartialP2POp:
-    def __init__(self, op, nranks, rank_id, tensor, peer, group):
-        if op not in [_partial_recv_op, _partial_send_op]:
-            raise RuntimeError(
-                "Invalid ``op`` function. Expected ``op`` "
-                "to be of type ``_partial_send_op`` or "
-                "``_partial_recv_op``."
-            )
-        self.op = op
-        self.nranks = nranks
-        self.rank_id = rank_id
-        self.tensor = tensor
-        self.peer = peer
-        self.group = group
+    group = _get_global_group() if group is None else group
+    backend = group.backend
+    with _with_batch_p2p_guard(backend):
+        for p2p_op in p2p_op_list:
+            op = p2p_op.op
+            tensor = p2p_op.tensor
+            peer = p2p_op.peer
+            comm_group = p2p_op.group
+            nranks = p2p_op.nranks
+            rank_id = p2p_op.rank_id
+            op(tensor, comm_group, peer, nranks, rank_id)
+
+
+def _process_p2p_tuple_or_tensor(
+    tensors, p2p_func, pp_rank, pp_group, mp_degree=1, mp_rank=0
+):
+    ops = []
+    if isinstance(tensors, tuple):
+        for tensor in tensors:
+            op = P2PonCalcStream(
+                p2p_func, tensor, pp_rank, pp_group, mp_degree, mp_rank
+            )
+            ops.append(op)
+    else:
+        op = P2PonCalcStream(
+            p2p_func, tensors, pp_rank, pp_group, mp_degree, mp_rank
+        )
+        ops.append(op)
+    return ops
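A hedged usage sketch of the two new helpers, mirroring what `_p2p_helper` does below for each of its four directions; the wrapper function, its name, and the `activation`/`hcg` arguments are hypothetical, while the accessors and helpers come from the diff:

def send_forward_on_calc_stream(activation, hcg, mp_degree=1, mp_rank=0):
    # Build one P2PonCalcStream op per tensor (activation may be a Tensor or
    # a tuple of Tensors), then run them all on the calculation stream.
    ops = _process_p2p_tuple_or_tensor(
        activation,
        _send_on_calc_stream,           # a receiver would pass _recv_on_calc_stream
        hcg._get_p2p_next_rank(),       # peer pipeline rank
        hcg.get_pipe_parallel_group(),  # pipeline process group
        mp_degree,
        mp_rank,
    )
    batch_send_recv_on_calc_stream(ops)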
 def _p2p_helper(
...
@@ -306,7 +332,6 @@ def _p2p_helper(
                 tensor_recv_prev.append(tmp)
             tensor_recv_prev = tuple(tensor_recv_prev)
         else:
             tensor_recv_prev = paddle.empty(
                 shape=recv_shape_msg, dtype=number_2_dtype(recv_dtype_msg)
             )
...
@@ -328,189 +353,60 @@ def _p2p_helper(
     )

     ops = []
-    partial_ops = []
     pipe_group = _hcg.get_pipe_parallel_group()

     # start to p2p communicate
     if tensor_send_prev is not None:
         src_rank = _hcg._get_p2p_prev_rank()
-        if isinstance(tensor_send_prev, tuple):
-            for d in tensor_send_prev:
-                if _is_valid_send_recv_partial(d, mp_degree):
-                    op = PartialP2POp(
-                        _partial_send_op,
-                        mp_degree,
-                        mp_rank,
-                        d,
-                        src_rank,
-                        pipe_group,
-                    )
-                    partial_ops.append(op)
-                else:
-                    op = paddle.distributed.P2POp(
-                        paddle.distributed.isend,
-                        d,
-                        src_rank,
-                        pipe_group,
-                    )
-                    ops.append(op)
-        else:
-            if _is_valid_send_recv_partial(tensor_send_prev, mp_degree):
-                op = PartialP2POp(
-                    _partial_send_op,
-                    mp_degree,
-                    mp_rank,
-                    tensor_send_prev,
-                    src_rank,
-                    pipe_group,
-                )
-                partial_ops.append(op)
-            else:
-                op = paddle.distributed.P2POp(
-                    paddle.distributed.isend,
-                    tensor_send_prev,
-                    src_rank,
-                    pipe_group,
-                )
-                ops.append(op)
+        ops.extend(
+            _process_p2p_tuple_or_tensor(
+                tensor_send_prev,
+                _send_on_calc_stream,
+                src_rank,
+                pipe_group,
+                mp_degree,
+                mp_rank,
+            )
+        )

     if tensor_recv_prev is not None:
         dst_rank = _hcg._get_p2p_prev_rank()
-        if isinstance(tensor_recv_prev, tuple):
-            for d in tensor_recv_prev:
-                if _is_valid_send_recv_partial(d, mp_degree):
-                    op = PartialP2POp(
-                        _partial_recv_op,
-                        mp_degree,
-                        mp_rank,
-                        d,
-                        dst_rank,
-                        pipe_group,
-                    )
-                    partial_ops.append(op)
-                else:
-                    op = paddle.distributed.P2POp(
-                        paddle.distributed.irecv,
-                        d,
-                        dst_rank,
-                        pipe_group,
-                    )
-                    ops.append(op)
-        else:
-            if _is_valid_send_recv_partial(tensor_recv_prev, mp_degree):
-                op = PartialP2POp(
-                    _partial_recv_op,
-                    mp_degree,
-                    mp_rank,
-                    tensor_recv_prev,
-                    dst_rank,
-                    pipe_group,
-                )
-                partial_ops.append(op)
-            else:
-                op = paddle.distributed.P2POp(
-                    paddle.distributed.irecv,
-                    tensor_recv_prev,
-                    dst_rank,
-                    pipe_group,
-                )
-                ops.append(op)
+        ops.extend(
+            _process_p2p_tuple_or_tensor(
+                tensor_recv_prev,
+                _recv_on_calc_stream,
+                dst_rank,
+                pipe_group,
+                mp_degree,
+                mp_rank,
+            )
+        )

     if tensor_send_next is not None:
         src_rank = _hcg._get_p2p_next_rank()
-        if isinstance(tensor_send_next, tuple):
-            for d in tensor_send_next:
-                if _is_valid_send_recv_partial(d, mp_degree):
-                    op = PartialP2POp(
-                        _partial_send_op,
-                        mp_degree,
-                        mp_rank,
-                        d,
-                        src_rank,
-                        pipe_group,
-                    )
-                    partial_ops.append(op)
-                else:
-                    op = paddle.distributed.P2POp(
-                        paddle.distributed.isend,
-                        d,
-                        src_rank,
-                        pipe_group,
-                    )
-                    ops.append(op)
-        else:
-            if _is_valid_send_recv_partial(tensor_send_next, mp_degree):
-                op = PartialP2POp(
-                    _partial_send_op,
-                    mp_degree,
-                    mp_rank,
-                    tensor_send_next,
-                    src_rank,
-                    pipe_group,
-                )
-                partial_ops.append(op)
-            else:
-                op = paddle.distributed.P2POp(
-                    paddle.distributed.isend,
-                    tensor_send_next,
-                    src_rank,
-                    pipe_group,
-                )
-                ops.append(op)
+        ops.extend(
+            _process_p2p_tuple_or_tensor(
+                tensor_send_next,
+                _send_on_calc_stream,
+                src_rank,
+                pipe_group,
+                mp_degree,
+                mp_rank,
+            )
+        )

     if tensor_recv_next is not None:
         dst_rank = _hcg._get_p2p_next_rank()
-        if isinstance(tensor_recv_next, tuple):
-            for d in tensor_recv_next:
-                if _is_valid_send_recv_partial(d, mp_degree):
-                    op = PartialP2POp(
-                        _partial_recv_op,
-                        mp_degree,
-                        mp_rank,
-                        d,
-                        dst_rank,
-                        pipe_group,
-                    )
-                    partial_ops.append(op)
-                else:
-                    op = paddle.distributed.P2POp(
-                        paddle.distributed.irecv,
-                        d,
-                        dst_rank,
-                        pipe_group,
-                    )
-                    ops.append(op)
-        else:
-            if _is_valid_send_recv_partial(tensor_recv_next, mp_degree):
-                op = PartialP2POp(
-                    _partial_recv_op,
-                    mp_degree,
-                    mp_rank,
-                    tensor_recv_next,
-                    dst_rank,
-                    pipe_group,
-                )
-                partial_ops.append(op)
-            else:
-                op = paddle.distributed.P2POp(
-                    paddle.distributed.irecv,
-                    tensor_recv_next,
-                    dst_rank,
-                    pipe_group,
-                )
-                ops.append(op)
+        ops.extend(
+            _process_p2p_tuple_or_tensor(
+                tensor_recv_next,
+                _recv_on_calc_stream,
+                dst_rank,
+                pipe_group,
+                mp_degree,
+                mp_rank,
+            )
+        )

     if len(ops) > 0:
-        reqs = paddle.distributed.batch_isend_irecv(ops)
-        for req in reqs:
-            req.wait()
-
-    if len(partial_ops) > 0:
-        reqs = partial_batch_isend_irecv(partial_ops)
-        for req in reqs:
-            req.wait()
-
-    # block cpu to wait the result
-    paddle.device.synchronize()
+        batch_send_recv_on_calc_stream(ops)

     tensors_for_all_gather = []
     if tensor_recv_prev is not None:
...