Unverified · Commit 72b5b5bf, authored by Yuang Liu, committed by GitHub

[dygraph hybrid pp for interleave] The interleave scheduler for pipeline parallel (#45497)

Parent commit: fd86a938
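In short, this change adds an interleaved (virtual pipeline stage) scheduler: a PipelineLayer built with num_virtual_pipeline_stages > 1 is now wrapped by fleet.distributed_model into the new PipelineParallelWithInterleave runner, while a single virtual stage still gets the classic 1F1B PipelineParallel. A minimal usage sketch based on the new unit test further down (MyPipeModel is a hypothetical PipelineLayer subclass, not part of this commit):

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 1, "pp_degree": 2}
strategy.pipeline_configs = {"accumulate_steps": 4, "micro_batch_size": 2}
fleet.init(is_collective=True, strategy=strategy)

# MyPipeModel is assumed to pass num_virtual_pipeline_stages=2 to
# PipelineLayer.__init__, as ModelPipe does in the new test file below.
model = MyPipeModel(fleet.get_hybrid_communicate_group().topology())

# With more than one virtual stage, distributed_model now returns a
# PipelineParallelWithInterleave instance instead of PipelineParallel.
model = fleet.distributed_model(model)

The diff below wires this dispatch into fleet, the p2p communication helpers, and the tests.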
@@ -2384,7 +2384,7 @@ def isend(tensor, dst, group=None):
         assert group_dst_rank >= 0, ("dst rank out of group, need global rank")
         return group.process_group.send(tensor, group_dst_rank)
     else:
-        raise RuntimeError("Don't support static graph mode currently.")
+        raise RuntimeError("Only support eager dygraph mode.")


 def irecv(tensor, src=None, group=None):
@@ -2433,7 +2433,7 @@ def irecv(tensor, src=None, group=None):
         assert group_src_rank >= 0, ("src rank out of group, need global rank")
         return group.process_group.recv(tensor, group_src_rank)
     else:
-        raise RuntimeError("Don't support static graph mode currently.")
+        raise RuntimeError("Only support eager dygraph mode.")


 class P2POp(object):
...
@@ -240,6 +240,14 @@ class HybridCommunicateGroup(object):
         return parallel_group, parallel_comm_group

+    def _get_p2p_next_rank(self):
+        assert hasattr(self, 'next_rank'), "next_rank has not been inited"
+        return self.next_rank
+
+    def _get_p2p_prev_rank(self):
+        assert hasattr(self, 'prev_rank'), "prev_rank has not been inited"
+        return self.prev_rank
+
     def _set_p2p_group(self):
         comm_lists = self._topo.get_comm_list('pipe')
@@ -255,6 +263,10 @@ class HybridCommunicateGroup(object):
             next_rank = comm_ranks[(idx + 1) % self._pp_degree]
             prev_rank = comm_ranks[(idx - 1) % self._pp_degree]

+            if self.global_rank == curr_rank:
+                self.next_rank = next_rank
+                self.prev_rank = prev_rank
+
             next_group = paddle.distributed.new_group(
                 ranks=[curr_rank, next_rank])
             if self.global_rank == curr_rank:
...
@@ -24,6 +24,7 @@ from .parallel_layers import model_parallel_random_seed  # noqa: F401
 from .parallel_layers import get_rng_state_tracker  # noqa: F401
 from .tensor_parallel import TensorParallel  # noqa: F401
 from .pipeline_parallel import PipelineParallel  # noqa: F401
+from .pipeline_parallel import PipelineParallelWithInterleave  # noqa: F401
 from .sharding_parallel import ShardingParallel  # noqa: F401

 __all__ = []
@@ -189,7 +189,7 @@ class PipelineLayerChunk(Layer):
         # Users shouldn't call PipelineLayerChunk directly, since all logics relating with recompute
         # are in the forward function of PipelineLayer. Any directly call will bring unexpected
         # behavior under recompute circumstance.
-        raise NotImplementedError(
+        raise PermissionError(
             "The forward function of PipelineLayerChunk cannot be called directly. "
             "Please call forward function of PipelineLayer.")

@@ -385,6 +385,9 @@ class PipelineLayer(Layer):
                         start_idx + stage + 1]:
                 return stage

+    def get_num_virtual_stages(self):
+        return self._num_virtual_pipeline_stages
+
     def get_model_chunks(self):
         return None if self._num_virtual_pipeline_stages == 1 else self._model_chunks
...
@@ -54,7 +54,7 @@ class SendRecvMeta:
     def _recv_shape_dtype(self, group):
         # recv len(shape)
         dims = paddle.to_tensor([0])
-        src_rank = group.ranks[0]
+        src_rank = _hcg._get_p2p_prev_rank()

         paddle.distributed.recv(dims, src=src_rank, group=group)
         dims = dims.item()
@@ -74,7 +74,7 @@ class SendRecvMeta:
     def recv_meta(self, group):
         tensor_type = paddle.to_tensor([0])
-        src_rank = group.ranks[0]
+        src_rank = _hcg._get_p2p_prev_rank()

         paddle.distributed.recv(tensor_type, src=src_rank, group=group)
         tensor_type = tensor_type.item()
@@ -105,7 +105,7 @@ class SendRecvMeta:
     def _send_dims_shape_dtype(self, tensor, group):
         # send len(shape)
         dims = paddle.to_tensor(len(tensor.shape))
-        dst_rank = group.ranks[1]
+        dst_rank = _hcg._get_p2p_next_rank()

         paddle.distributed.send(dims, dst=dst_rank, group=group)
@@ -122,7 +122,7 @@ class SendRecvMeta:
         paddle.distributed.send(stop_grad, dst=dst_rank, group=group)

     def send_meta(self, tensor, group):
-        dst_rank = group.ranks[1]
+        dst_rank = _hcg._get_p2p_next_rank()

         if isinstance(tensor, (paddle.Tensor, core.eager.Tensor)):
             tensor_type = paddle.to_tensor([0])
@@ -165,20 +165,17 @@ def _is_valid_send_recv_partial(tensor, mp_degree):

 def _partial_send_op(tensor, group, use_calc_stream, ring_id, dst, nranks,
                      rank_id):
+    dst_rank_in_group = dst if group is None else group.get_group_rank(dst)
     if _in_legacy_dygraph():
         return _legacy_C_ops.partial_send(tensor.detach(), 'use_calc_stream',
                                           use_calc_stream, 'ring_id', ring_id,
-                                          'peer', dst, 'num', nranks, 'id',
-                                          rank_id)
+                                          'peer', dst_rank_in_group, 'num',
+                                          nranks, 'id', rank_id)
     elif in_dygraph_mode():
         group = paddle.distributed.collective._get_default_group(
         ) if group is None else group
-        task = group.process_group.send_partial(tensor, dst, nranks, rank_id)
-        if use_calc_stream:
-            task.wait()
-            return None
-        else:
-            return task
+        return group.process_group.send_partial(tensor, dst_rank_in_group,
+                                                nranks, rank_id)


 def send_partial(tensor,
@@ -192,33 +189,35 @@ def send_partial(tensor,
         return
     ring_id = 0 if group is None else group.id

+    dst_rank = _hcg._get_p2p_next_rank(
+    ) if dst == 1 else _hcg._get_p2p_prev_rank()
+
     if _is_valid_send_recv_partial(tensor, nranks):
-        return _partial_send_op(tensor, group, use_calc_stream, ring_id, dst,
-                                nranks, rank_id)
+        return _partial_send_op(tensor, group, use_calc_stream, ring_id,
+                                dst_rank, nranks, rank_id)
     else:
-        return paddle.distributed.send(tensor.detach(),
-                                       dst=group.ranks[dst],
-                                       group=group,
-                                       use_calc_stream=use_calc_stream)
+        if _in_legacy_dygraph():
+            send_op = paddle.distributed.send
+        elif in_dygraph_mode():
+            send_op = paddle.distributed.isend
+        return send_op(tensor.detach(), dst=dst_rank, group=group)


 def _partial_recv_op(tensor, group, use_calc_stream, ring_id, src, nranks,
                      rank_id):
+    src_rank_in_group = src if group is None else group.get_group_rank(src)
     if _in_legacy_dygraph():
         return _legacy_C_ops.partial_recv(tensor.detach(), 'use_calc_stream',
                                           use_calc_stream, 'ring_id', ring_id,
-                                          'peer', src, 'num', nranks, 'id',
-                                          rank_id, 'dtype', tensor.dtype,
-                                          'out_shape', tensor.shape)
+                                          'peer', src_rank_in_group, 'num',
+                                          nranks, 'id', rank_id, 'dtype',
+                                          tensor.dtype, 'out_shape',
+                                          tensor.shape)
     elif in_dygraph_mode():
         group = paddle.distributed.collective._get_default_group(
         ) if group is None else group
-        task = group.process_group.recv_partial(tensor, src, nranks, rank_id)
-        if use_calc_stream:
-            task.wait()
-            return None
-        else:
-            return task
+        return group.process_group.recv_partial(tensor, src_rank_in_group,
+                                                nranks, rank_id)


 def recv_partial(tensor,
@@ -232,14 +231,18 @@ def recv_partial(tensor,
         return
     ring_id = 0 if group is None else group.id

+    src_rank = _hcg._get_p2p_prev_rank(
+    ) if src == 0 else _hcg._get_p2p_next_rank()
+
     if _is_valid_send_recv_partial(tensor, nranks):
-        return _partial_recv_op(tensor, group, use_calc_stream, ring_id, src,
-                                nranks, rank_id)
+        return _partial_recv_op(tensor, group, use_calc_stream, ring_id,
+                                src_rank, nranks, rank_id)
     else:
-        return paddle.distributed.recv(tensor.detach(),
-                                       src=group.ranks[src],
-                                       group=group,
-                                       use_calc_stream=use_calc_stream)
+        if _in_legacy_dygraph():
+            recv_op = paddle.distributed.recv
+        elif in_dygraph_mode():
+            recv_op = paddle.distributed.irecv
+        return recv_op(tensor.detach(), src=src_rank, group=group)


 def _partial_allgather_op(tensor, group, use_calc_stream, ring_id, nranks,
@@ -253,13 +256,8 @@ def _partial_allgather_op(tensor, group, use_calc_stream, ring_id, nranks,
     elif in_dygraph_mode():
         group = paddle.distributed.collective._get_default_group(
         ) if group is None else group
-        task = group.process_group.all_gather_partial(tensor, tensor, nranks,
-                                                      rank_id)
-        if use_calc_stream:
-            task.wait()
-            return None
-        else:
-            return task
+        return group.process_group.all_gather_partial(tensor, tensor, nranks,
+                                                      rank_id)


 def allgather_partial(tensor,
@@ -268,9 +266,9 @@ def allgather_partial(tensor,
                       group=None,
                       use_calc_stream=True):
     if not _is_valid_send_recv_partial(tensor, nranks):
-        return tensor
+        return None
     if group is not None and not group.is_member():
-        return
+        return None
     ring_id = 0 if group is None else group.id

     return _partial_allgather_op(tensor, group, use_calc_stream, ring_id,
@@ -323,105 +321,124 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
             tensor_recv_next = paddle.empty(
                 shape=send_shape_msg, dtype=number_2_dtype(send_dtype_msg))

+    # TODO(Yuang Liu): use batch_isend_irecv replace all these comm ops
+    tasks = []
     # start to p2p communicate
     if tensor_send_prev is not None:
         if isinstance(tensor_send_prev, tuple):
             for d in tensor_send_prev:
                 paddle.distributed.wait(d, use_calc_stream=True)
-                send_partial(d,
-                             dst=0,
-                             nranks=mp_degree,
-                             rank_id=mp_rank,
-                             group=_hcg.send_prev_group,
-                             use_calc_stream=False)
+                tasks.append(
+                    send_partial(d,
+                                 dst=0,
+                                 nranks=mp_degree,
+                                 rank_id=mp_rank,
+                                 group=_hcg.send_prev_group,
+                                 use_calc_stream=False))
         else:
             paddle.distributed.wait(tensor_send_prev, use_calc_stream=True)
-            send_partial(tensor_send_prev,
-                         dst=0,
-                         nranks=mp_degree,
-                         rank_id=mp_rank,
-                         group=_hcg.send_prev_group,
-                         use_calc_stream=False)
+            tasks.append(
+                send_partial(tensor_send_prev,
+                             dst=0,
+                             nranks=mp_degree,
+                             rank_id=mp_rank,
+                             group=_hcg.send_prev_group,
+                             use_calc_stream=False))

     if tensor_recv_prev is not None:
         if isinstance(tensor_recv_prev, tuple):
             for d in tensor_recv_prev:
-                recv_partial(d,
-                             src=0,
-                             nranks=mp_degree,
-                             rank_id=mp_rank,
-                             group=_hcg.recv_prev_group,
-                             use_calc_stream=True)
-                allgather_partial(d,
-                                  nranks=mp_degree,
-                                  rank_id=mp_rank,
-                                  group=mp_group,
-                                  use_calc_stream=True)
+                tasks.append(
+                    recv_partial(d,
+                                 src=0,
+                                 nranks=mp_degree,
+                                 rank_id=mp_rank,
+                                 group=_hcg.recv_prev_group,
+                                 use_calc_stream=True))
+                tasks.append(
+                    allgather_partial(d,
+                                      nranks=mp_degree,
+                                      rank_id=mp_rank,
+                                      group=mp_group,
+                                      use_calc_stream=True))
         else:
-            recv_partial(tensor_recv_prev,
-                         src=0,
-                         nranks=mp_degree,
-                         rank_id=mp_rank,
-                         group=_hcg.recv_prev_group,
-                         use_calc_stream=True)
-            allgather_partial(tensor_recv_prev,
-                              nranks=mp_degree,
-                              rank_id=mp_rank,
-                              group=mp_group,
-                              use_calc_stream=True)
+            tasks.append(
+                recv_partial(tensor_recv_prev,
+                             src=0,
+                             nranks=mp_degree,
+                             rank_id=mp_rank,
+                             group=_hcg.recv_prev_group,
+                             use_calc_stream=True))
+            tasks.append(
+                allgather_partial(tensor_recv_prev,
+                                  nranks=mp_degree,
+                                  rank_id=mp_rank,
+                                  group=mp_group,
+                                  use_calc_stream=True))

     if tensor_send_next is not None:
         if isinstance(tensor_send_next, tuple):
             for d in tensor_send_next:
                 paddle.distributed.wait(d, use_calc_stream=True)
-                send_partial(d,
-                             dst=1,
-                             nranks=mp_degree,
-                             rank_id=mp_rank,
-                             group=_hcg.send_next_group,
-                             use_calc_stream=False)
+                tasks.append(
+                    send_partial(d,
+                                 dst=1,
+                                 nranks=mp_degree,
+                                 rank_id=mp_rank,
+                                 group=_hcg.send_next_group,
+                                 use_calc_stream=False))
         else:
             paddle.distributed.wait(tensor_send_next, use_calc_stream=True)
-            send_partial(tensor_send_next,
-                         dst=1,
-                         nranks=mp_degree,
-                         rank_id=mp_rank,
-                         group=_hcg.send_next_group,
-                         use_calc_stream=False)
+            tasks.append(
+                send_partial(tensor_send_next,
+                             dst=1,
+                             nranks=mp_degree,
+                             rank_id=mp_rank,
+                             group=_hcg.send_next_group,
+                             use_calc_stream=False))

     if tensor_recv_next is not None:
         if isinstance(tensor_recv_next, tuple):
             for d in tensor_recv_next:
-                recv_partial(d,
-                             src=1,
-                             nranks=mp_degree,
-                             rank_id=mp_rank,
-                             group=_hcg.recv_next_group,
-                             use_calc_stream=True)
-                allgather_partial(d,
-                                  nranks=mp_degree,
-                                  rank_id=mp_rank,
-                                  group=mp_group,
-                                  use_calc_stream=True)
+                tasks.append(
+                    recv_partial(d,
+                                 src=1,
+                                 nranks=mp_degree,
+                                 rank_id=mp_rank,
+                                 group=_hcg.recv_next_group,
+                                 use_calc_stream=True))
+                tasks.append(
+                    allgather_partial(d,
+                                      nranks=mp_degree,
+                                      rank_id=mp_rank,
+                                      group=mp_group,
+                                      use_calc_stream=True))
         else:
-            recv_partial(tensor_recv_next,
-                         src=1,
-                         nranks=mp_degree,
-                         rank_id=mp_rank,
-                         group=_hcg.recv_next_group,
-                         use_calc_stream=True)
-            allgather_partial(tensor_recv_next,
-                              nranks=mp_degree,
-                              rank_id=mp_rank,
-                              group=mp_group,
-                              use_calc_stream=True)
+            tasks.append(
+                recv_partial(tensor_recv_next,
+                             src=1,
+                             nranks=mp_degree,
+                             rank_id=mp_rank,
+                             group=_hcg.recv_next_group,
+                             use_calc_stream=True))
+            tasks.append(
+                allgather_partial(tensor_recv_next,
+                                  nranks=mp_degree,
+                                  rank_id=mp_rank,
+                                  group=mp_group,
+                                  use_calc_stream=True))
+
+    if in_dygraph_mode():
+        # wait tasks in new dygraph mode with new comm library
+        for task in tasks:
+            if task is not None:
+                task.wait()

     return tensor_recv_prev, tensor_recv_next


-def recv_forward():
-    if _hcg.is_first_stage:
+def recv_forward(pp_first_stage):
+    if pp_first_stage:
         input_tensor = None
     else:
         if not _send_recv_meta.has_recv_meta:
@@ -435,8 +452,8 @@ def recv_forward():
     return input_tensor


-def recv_backward():
-    if _hcg.is_last_stage:
+def recv_backward(pp_last_stage):
+    if pp_last_stage:
         output_tensor_grad = None
     else:
         _, output_tensor_grad = _p2p_helper(tensor_send_next=None,
@@ -446,8 +463,8 @@ def recv_backward():
     return output_tensor_grad


-def send_forward(output_tensor):
-    if not _hcg.is_last_stage:
+def send_forward(output_tensor, pp_last_stage):
+    if not pp_last_stage:
         if not _send_recv_meta.has_send_meta:
             _send_recv_meta.set_send_message(output_tensor)
             _send_recv_meta.send_meta(output_tensor, _hcg.send_next_group)
@@ -459,16 +476,16 @@ def send_forward(output_tensor):
                     recv_next=False)


-def send_backward(input_tensor_grad):
-    if not _hcg.is_first_stage:
+def send_backward(input_tensor_grad, pp_first_stage):
+    if not pp_first_stage:
         _p2p_helper(tensor_send_next=None,
                     tensor_send_prev=input_tensor_grad,
                     recv_prev=False,
                     recv_next=False)


-def send_forward_recv_backward(output_tensor):
-    if _hcg.is_last_stage:
+def send_forward_recv_backward(output_tensor, pp_last_stage):
+    if pp_last_stage:
         output_tensor_grad = None
     else:
         _, output_tensor_grad = _p2p_helper(tensor_send_next=output_tensor,
@@ -478,8 +495,8 @@ def send_forward_recv_backward(output_tensor):
     return output_tensor_grad


-def send_backward_recv_forward(input_tensor_grad):
-    if _hcg.is_first_stage:
+def send_backward_recv_forward(input_tensor_grad, pp_first_stage):
+    if pp_first_stage:
         input_tensor = None
     else:
         input_tensor, _ = _p2p_helper(tensor_send_next=None,
@@ -487,3 +504,48 @@ def send_backward_recv_forward(input_tensor_grad):
                                       recv_prev=True,
                                       recv_next=False)
     return input_tensor
+
+
+def send_forward_backward_recv_forward_backward(output_tensor,
+                                                input_tensor_grad, recv_prev,
+                                                recv_next):
+    # always have to send dytpe info to downstream
+    if not _send_recv_meta.has_send_meta:
+        _send_recv_meta.set_send_message(output_tensor)
+        _send_recv_meta.send_meta(output_tensor, _hcg.send_next_group)
+        _send_recv_meta.has_send_meta = _use_cache
+    if recv_prev and not _send_recv_meta.has_recv_meta:
+        _send_recv_meta.recv_meta(_hcg.recv_prev_group)
+        _send_recv_meta.has_recv_meta = _use_cache
+    input_tensor, output_tensor_grad = _p2p_helper(
+        tensor_send_next=output_tensor,
+        tensor_send_prev=input_tensor_grad,
+        recv_prev=recv_prev,
+        recv_next=recv_next)
+    return input_tensor, output_tensor_grad
+
+
+def send_forward_recv_forward(output_tensor, recv_prev):
+    # always have to send dytpe info to downstream
+    if not _send_recv_meta.has_send_meta:
+        _send_recv_meta.set_send_message(output_tensor)
+        _send_recv_meta.send_meta(output_tensor, _hcg.send_next_group)
+        _send_recv_meta.has_send_meta = _use_cache
+    if recv_prev and not _send_recv_meta.has_recv_meta:
+        _send_recv_meta.recv_meta(_hcg.recv_prev_group)
+        _send_recv_meta.has_recv_meta = _use_cache
+    input_tensor, _ = _p2p_helper(tensor_send_next=output_tensor,
+                                  tensor_send_prev=None,
+                                  recv_prev=recv_prev,
+                                  recv_next=False)
+    return input_tensor
+
+
+def send_backward_recv_backward(input_tensor_grad, recv_next):
+    _, output_tensor_grad = _p2p_helper(tensor_send_next=None,
+                                        tensor_send_prev=input_tensor_grad,
+                                        recv_prev=False,
+                                        recv_next=recv_next)
+    return output_tensor_grad
@@ -18,7 +18,7 @@ import numpy as np
 from .base import topology as tp
 from .base.topology import ParallelMode
 from .meta_parallel import TensorParallel, model_parallel_random_seed
-from .meta_parallel import PipelineParallel, ShardingParallel
+from .meta_parallel import PipelineParallel, ShardingParallel, PipelineParallelWithInterleave, PipelineLayer
 from paddle.fluid import core
 from paddle.distributed.fleet.utils.recompute import LegacyRecomputeFunction
 from paddle.fluid.dygraph.varbase_patch_methods import _grad_scalar
@@ -185,6 +185,16 @@ def distributed_model(model):
     elif fleet_env._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL:
         model = TensorParallel(model, fleet_env._hcg, strategy=strategy)
     elif fleet_env._hcg.get_parallel_mode() == ParallelMode.PIPELINE_PARALLEL:
-        model = PipelineParallel(model, fleet_env._hcg, strategy=strategy)
+        assert isinstance(
+            model, PipelineLayer
+        ), "For pipeline parallel, the model should an instance of PipelineLayer"
+        if model.get_num_virtual_stages() == 1:
+            # 1f1b pipeline
+            model = PipelineParallel(model, fleet_env._hcg, strategy=strategy)
+        else:
+            # interleave pipeline
+            model = PipelineParallelWithInterleave(model,
+                                                   fleet_env._hcg,
+                                                   strategy=strategy)

     return model
@@ -27,8 +27,6 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_dataparallel)
 list(APPEND DIST_TEST_OPS test_static_model_parallel_fused_feedforward)
 list(APPEND DIST_TEST_OPS test_static_model_parallel_fused_attention)
 list(APPEND DIST_TEST_OPS test_static_model_parallel_fused_multi_transformer)
-list(APPEND DIST_TEST_OPS
-     test_parallel_dygraph_pipeline_parallel_with_virtual_stage)
 list(APPEND DIST_TEST_OPS test_auto_parallel_data_unshard)
 list(APPEND DIST_TEST_OPS test_auto_parallel_save_load)
 list(APPEND DIST_TEST_OPS test_auto_parallel_autoconvert)
@@ -178,8 +176,6 @@ if((NOT WITH_GPU) AND (NOT WITH_ROCM))
   # TODO(shenliang03): batch_fc_op support CPU device in future
   # TODO(Yancey1989): parallel dygraph support CPU device in future
   list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_dataparallel)
-  list(REMOVE_ITEM TEST_OPS
-       test_parallel_dygraph_pipeline_parallel_with_virtual_stage)
   list(REMOVE_ITEM TEST_OPS test_fleet_base_single)
   list(REMOVE_ITEM TEST_OPS test_auto_parallel_partitioner)
   list(REMOVE_ITEM TEST_OPS test_auto_parallel_partitioner_gpt)
@@ -1178,9 +1174,6 @@ set_tests_properties(test_graph_send_uv_op PROPERTIES TIMEOUT 60)
 if(WITH_DISTRIBUTE
    AND WITH_GPU
    AND WITH_NCCL)
-  set_tests_properties(
-    test_parallel_dygraph_pipeline_parallel_with_virtual_stage
-    PROPERTIES TIMEOUT 500)
   set_tests_properties(test_auto_parallel_data_unshard PROPERTIES TIMEOUT 120)
   set_tests_properties(test_auto_parallel_save_load PROPERTIES TIMEOUT 120)
   set_tests_properties(test_auto_parallel_autoconvert PROPERTIES TIMEOUT 120)
...
@@ -204,6 +204,20 @@ if((WITH_GPU) AND LOCAL_ALL_PLAT)
   set_tests_properties(test_parallel_dygraph_pipeline_parallel
                        PROPERTIES TIMEOUT "500")
 endif()
+if((WITH_GPU) AND LOCAL_ALL_PLAT)
+  bash_test_modules(
+    test_parallel_dygraph_pipeline_parallel_with_virtual_stage
+    START_BASH
+    ../../dist_test.sh
+    LABELS
+    "RUN_TYPE=DIST"
+    ENVS
+    "PADDLE_DIST_UT_PORT=21282;http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python"
+  )
+  set_tests_properties(
+    test_parallel_dygraph_pipeline_parallel_with_virtual_stage
+    PROPERTIES TIMEOUT "500" RUN_SERIAL 1)
+endif()
 if((WITH_GPU
     OR WITH_XPU
     OR WITH_ASCEND
...
@@ -19,7 +19,7 @@ import paddle
 from paddle.distributed import fleet
 import paddle.nn as nn
 from paddle.fluid.dygraph.layers import Layer
-from paddle.distributed.fleet.meta_parallel import LayerDesc, PipelineLayer
+from paddle.distributed.fleet.meta_parallel import LayerDesc, PipelineLayer, PipelineParallelWithInterleave
 import paddle.nn.functional as F
@@ -87,7 +87,8 @@ class TestPipeLayerAPI(unittest.TestCase):
         try:
             model_chunks[0](paddle.to_tensor([1., 2.]))
-        except NotImplementedError:
+            raise NotImplementedError
+        except PermissionError:
             pass

         # fake call for the forward function of virtual pipeline layer
@@ -102,6 +103,7 @@ class TestPipeLayerAPI(unittest.TestCase):
         # just make sure the model can be wrapped with distributed model
         dist_model = fleet.distributed_model(pipe_model)
+        assert isinstance(dist_model, PipelineParallelWithInterleave)


 if __name__ == '__main__':
...
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division
from __future__ import print_function

import unittest
import paddle
import numpy as np
import random
import paddle.distributed as dist
import paddle.distributed.fleet as fleet
from paddle.fluid import layers
import paddle.nn.functional as F
from paddle.distributed.fleet.meta_parallel import PipelineLayer, LayerDesc
from paddle.fluid.dygraph.layers import Layer
import paddle.nn as nn


def set_random_seed(seed, dp_id, rank_id):
    """Set random seed for reproducability."""
    random.seed(seed)
    np.random.seed(seed + dp_id)
    paddle.seed(seed + dp_id)


batch_size = 8
length = 8
micro_batch_size = 2
num_virtual_pipeline_stages = 2
vocab_size = 128
hidden_size = 16
d_model = hidden_size
dim_feedforward = 4 * d_model


class EmbeddingNet(Layer):

    def __init__(self):
        super(EmbeddingNet, self).__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(vocab_size, hidden_size)

    def forward(self, x):
        attention_mask = paddle.tensor.triu((paddle.ones(
            (length, length), dtype="float32") * -1e9), 1)

        no_used = paddle.ones((3, 3), dtype="int32")

        w_emb = self.word_embeddings(x)
        p_emb = self.position_embeddings(x)
        w_emb = w_emb + p_emb

        attention_mask.stop_gradient = True
        no_used.stop_gradient = True
        # need to fix bug of backward()
        return w_emb, attention_mask, no_used, p_emb


class TransformerNet(Layer):

    def __init__(self):
        super(TransformerNet, self).__init__()
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)

        self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5)

    def forward(self, x, mask):
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        product = layers.matmul(x=q, y=k, transpose_y=True, alpha=d_model**-0.5)

        weights = F.softmax(product + mask)
        tgt = layers.matmul(weights, v)
        residual = tgt
        tgt = self.norm1(tgt)
        tgt = residual + tgt

        out = self.linear2(F.gelu(self.linear1(tgt), approximate=True))
        return out


class EmbeddingPipe(EmbeddingNet):

    def forward(self, x):
        return super().forward(x)


class TransformerNetPipe(TransformerNet):

    def forward(self, args):
        x, mask, no_used, p_emb = args[0], args[1], args[2], args[3]

        output = super().forward(x, mask)
        output = output + p_emb
        mask.stop_gradient = True
        return output, mask, no_used, p_emb


class CriterionPipe(Layer):

    def __init__(self):
        super(CriterionPipe, self).__init__()

    def forward(self, out, label):
        loss = out.mean()
        return loss


class ModelPipe(PipelineLayer):

    def __init__(self, topology):
        self.descs = []
        self.descs.append(LayerDesc(EmbeddingPipe))

        for x in range(8):
            self.descs.append(LayerDesc(TransformerNetPipe))

        self.descs.append(lambda x: x[0])

        super().__init__(
            layers=self.descs,
            loss_fn=CriterionPipe(),
            topology=topology,
            num_virtual_pipeline_stages=num_virtual_pipeline_stages,
            seg_method="layer:TransformerNetPipe")


class TestDistPPTraning(unittest.TestCase):

    def setUp(self):
        strategy = fleet.DistributedStrategy()
        self.model_parallel_size = 1
        self.data_parallel_size = 1
        self.pipeline_parallel_size = 2
        strategy.hybrid_configs = {
            "dp_degree": self.data_parallel_size,
            "mp_degree": self.model_parallel_size,
            "pp_degree": self.pipeline_parallel_size,
        }
        strategy.pipeline_configs = {
            "accumulate_steps": batch_size // micro_batch_size,
            "micro_batch_size": micro_batch_size
        }
        fleet.init(is_collective=True, strategy=strategy)

    def test_pp_model(self):
        hcg = fleet.get_hybrid_communicate_group()
        word_size = hcg.get_model_parallel_world_size()
        dp_id = hcg.get_data_parallel_rank()
        pp_id = hcg.get_stage_id()
        rank_id = dist.get_rank()
        topology = hcg.topology()
        set_random_seed(1024, dp_id, rank_id)

        model = ModelPipe(topology)
        scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries=[2],
                                                       values=[0.001, 0.002],
                                                       verbose=True)
        optimizer = paddle.optimizer.SGD(learning_rate=scheduler,
                                         parameters=model.parameters())

        model = fleet.distributed_model(model)
        optimizer = fleet.distributed_optimizer(optimizer)

        for step_id in range(5):
            x_data = np.random.randint(0, vocab_size, size=[batch_size, length])
            x = paddle.to_tensor(x_data)
            x.stop_gradient = True

            e_loss = model.eval_batch([x, x], True)
            loss = model.train_batch([x, x], optimizer, scheduler)

            np.testing.assert_allclose(loss.numpy(), e_loss.numpy())


if __name__ == "__main__":
    unittest.main()
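For intuition (a hedged illustration, not code from this commit): with pp_degree=2 and num_virtual_pipeline_stages=2 as configured above, the layer list is cut into 2 * 2 = 4 segments and, following the usual Megatron-style interleaving, each pipeline rank holds every pp_degree-th segment as a separate model chunk.

# Assumed segment-to-rank mapping under Megatron-style interleaving:
# rank 0 -> chunks [0, 2], rank 1 -> chunks [1, 3].
pp_degree = 2
num_virtual_pipeline_stages = 2
num_segments = pp_degree * num_virtual_pipeline_stages
for rank in range(pp_degree):
    chunks = list(range(rank, num_segments, pp_degree))
    print(f"pipeline rank {rank} holds model chunks {chunks}")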
@@ -25,8 +25,10 @@ class TestHybridPipeParallelWithVirtualStage(TestMultipleGpus):

     def test_hybrid_parallel_pp_layer_with_virtual_stage(self):
         self.run_mnist_2gpu('hybrid_parallel_pp_layer_with_virtual_stage.py')
-        self.run_mnist_2gpu('hybrid_parallel_pp_layer_with_virtual_stage.py',
-                            eager_mode=False)
+
+    def test_hybrid_parallel_pp_transformer_with_virtual_stage(self):
+        self.run_mnist_2gpu(
+            'hybrid_parallel_pp_transformer_with_virtual_stage.py')


 if __name__ == "__main__":
...
@@ -16,6 +16,7 @@ test_fleet_graph_execution_meta_optimizer,,GPU;XPU;ASCEND;ASCEND_CL,,DIST,../../
 test_communicator_half_async,,,120,DIST,test_runner.py,2,,FLAGS_communicator_send_queue_size=1;FLAGS_communicator_max_merge_var_num=1;http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL
 test_fleet_graph_executor,,GPU;XPU;ASCEND;ASCEND_CL,,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
 test_parallel_dygraph_pipeline_parallel,,GPU,500,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
+test_parallel_dygraph_pipeline_parallel_with_virtual_stage,,GPU,500,DIST,../../dist_test.sh,2,1,http_proxy=;https_proxy=;PYTHONPATH=../..,
 test_fleet_localsgd_meta_optimizer,LINUX,GPU;XPU;ASCEND;ASCEND_CL,,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
 test_parallel_class_center_sample,,GPU,120,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL
 test_pipeline,,,120,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
...