From 72b5b5bf4b1e6e0bd70a2cb02b59a35dbdd84b0e Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Tue, 6 Sep 2022 12:00:07 +0800 Subject: [PATCH] [dygraph hybrid pp for interleave] The interleave scheduler for pipeline parallel (#45497) --- python/paddle/distributed/collective.py | 4 +- .../paddle/distributed/fleet/base/topology.py | 12 + .../fleet/meta_parallel/__init__.py | 1 + .../parallel_layers/pp_layers.py | 5 +- .../fleet/meta_parallel/pipeline_parallel.py | 392 ++++++++++++++++-- .../pp_utils/p2p_communication.py | 276 +++++++----- python/paddle/distributed/fleet/model.py | 14 +- .../fluid/tests/unittests/CMakeLists.txt | 7 - .../unittests/collective/fleet/CMakeLists.txt | 14 + ...id_parallel_pp_layer_with_virtual_stage.py | 6 +- ...allel_pp_transformer_with_virtual_stage.py | 195 +++++++++ ...ph_pipeline_parallel_with_virtual_stage.py | 6 +- .../unittests/collective/fleet/testslist.csv | 1 + 13 files changed, 783 insertions(+), 150 deletions(-) rename python/paddle/fluid/tests/unittests/{ => collective/fleet}/hybrid_parallel_pp_layer_with_virtual_stage.py (95%) create mode 100644 python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_pp_transformer_with_virtual_stage.py rename python/paddle/fluid/tests/unittests/{ => collective/fleet}/test_parallel_dygraph_pipeline_parallel_with_virtual_stage.py (86%) diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index 9900195c20..5960be4800 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -2384,7 +2384,7 @@ def isend(tensor, dst, group=None): assert group_dst_rank >= 0, ("dst rank out of group, need global rank") return group.process_group.send(tensor, group_dst_rank) else: - raise RuntimeError("Don't support static graph mode currently.") + raise RuntimeError("Only support eager dygraph mode.") def irecv(tensor, src=None, group=None): @@ -2433,7 +2433,7 @@ def irecv(tensor, src=None, group=None): assert group_src_rank >= 0, ("src rank out of group, need global rank") return group.process_group.recv(tensor, group_src_rank) else: - raise RuntimeError("Don't support static graph mode currently.") + raise RuntimeError("Only support eager dygraph mode.") class P2POp(object): diff --git a/python/paddle/distributed/fleet/base/topology.py b/python/paddle/distributed/fleet/base/topology.py index aef9c85adf..bbaca89512 100644 --- a/python/paddle/distributed/fleet/base/topology.py +++ b/python/paddle/distributed/fleet/base/topology.py @@ -240,6 +240,14 @@ class HybridCommunicateGroup(object): return parallel_group, parallel_comm_group + def _get_p2p_next_rank(self): + assert hasattr(self, 'next_rank'), "next_rank has not been inited" + return self.next_rank + + def _get_p2p_prev_rank(self): + assert hasattr(self, 'prev_rank'), "prev_rank has not been inited" + return self.prev_rank + def _set_p2p_group(self): comm_lists = self._topo.get_comm_list('pipe') @@ -255,6 +263,10 @@ class HybridCommunicateGroup(object): next_rank = comm_ranks[(idx + 1) % self._pp_degree] prev_rank = comm_ranks[(idx - 1) % self._pp_degree] + if self.global_rank == curr_rank: + self.next_rank = next_rank + self.prev_rank = prev_rank + next_group = paddle.distributed.new_group( ranks=[curr_rank, next_rank]) if self.global_rank == curr_rank: diff --git a/python/paddle/distributed/fleet/meta_parallel/__init__.py b/python/paddle/distributed/fleet/meta_parallel/__init__.py index fe7f23f3d8..f507e2f636 100644 --- a/python/paddle/distributed/fleet/meta_parallel/__init__.py +++ 
b/python/paddle/distributed/fleet/meta_parallel/__init__.py @@ -24,6 +24,7 @@ from .parallel_layers import model_parallel_random_seed # noqa: F401 from .parallel_layers import get_rng_state_tracker # noqa: F401 from .tensor_parallel import TensorParallel # noqa: F401 from .pipeline_parallel import PipelineParallel # noqa: F401 +from .pipeline_parallel import PipelineParallelWithInterleave # noqa: F401 from .sharding_parallel import ShardingParallel # noqa: F401 __all__ = [] diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py index 4d40d0e7de..3b1d313b6e 100755 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py @@ -189,7 +189,7 @@ class PipelineLayerChunk(Layer): # Users shouldn't call PipelineLayerChunk directly, since all logics relating with recompute # are in the forward function of PipelineLayer. Any directly call will bring unexpected # behavior under recompute circumstance. - raise NotImplementedError( + raise PermissionError( "The forward function of PipelineLayerChunk cannot be called directly. " "Please call forward function of PipelineLayer.") @@ -385,6 +385,9 @@ class PipelineLayer(Layer): start_idx + stage + 1]: return stage + def get_num_virtual_stages(self): + return self._num_virtual_pipeline_stages + def get_model_chunks(self): return None if self._num_virtual_pipeline_stages == 1 else self._model_chunks diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index 3135c5379e..876f9ffaed 100755 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -22,6 +22,7 @@ from ..utils.hybrid_parallel_util import broadcast_dp_parameters from ..utils.hybrid_parallel_util import broadcast_sharding_parameters from ..utils.log_util import logger from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer, HybridParallelGradScaler +import paddle.fluid.framework as framework from .pp_utils import p2p_communication as p2p import paddle.fluid.core as core @@ -53,12 +54,15 @@ class PipelineParallel(MetaParallelBase): self.stage_id = self._hcg.get_stage_id() self.pp_group = self._hcg.get_pipe_parallel_group() + self._virtual_pp_world_size = None + self._virtual_pp_rank = None + self._real_pp_world_size = self.num_stages + self._real_pp_rank = self.stage_id + p2p.initialize_p2p_groups(hcg, self._using_cache) _initialize_recompute_hcg(hcg) - self.is_first_stage = self.stage_id == 0 - self.is_last_stage = (self.stage_id == (self.num_stages - 1)) self.global_rank = self._hcg.get_global_rank() self.micro_batch_id = 0 @@ -79,6 +83,28 @@ class PipelineParallel(MetaParallelBase): logger.info("start broadcast dp parameters") broadcast_dp_parameters(self._layers, self._hcg) + def is_pipeline_first_stage(self, ignore_virtual=False): + if not ignore_virtual: + if self._virtual_pp_world_size is not None: + assert self._virtual_pp_rank is not None + if self._virtual_pp_rank != 0: + return False + assert self._real_pp_rank is not None + return self._real_pp_rank == 0 + + def is_pipeline_last_stage(self, ignore_virtual=False): + if not ignore_virtual: + if self._virtual_pp_world_size is not None: + assert self._virtual_pp_rank is not None + if self._virtual_pp_rank != 
(self._virtual_pp_world_size - 1): + return False + assert self._real_pp_rank is not None + assert self._real_pp_world_size is not None + return self._real_pp_rank == (self._real_pp_world_size - 1) + + def set_virtual_pipeline_rank(self, rank): + self._virtual_pp_rank = rank + def forward_backward_pipeline(self, data, scaler=None): # use the 1f1b scheduling strategy. # this strategy is inspired by: @@ -103,23 +129,24 @@ class PipelineParallel(MetaParallelBase): output_buffers = [] for step_id in range(startup_steps): - input_tensor = p2p.recv_forward() + input_tensor = p2p.recv_forward(self.is_pipeline_first_stage()) output_tensor = self._forward_step(input_tensor) - p2p.send_forward(output_tensor) + p2p.send_forward(output_tensor, self.is_pipeline_last_stage()) input_buffers.append(input_tensor) output_buffers.append(output_tensor) if steady_steps > 0: - input_tensor = p2p.recv_forward() + input_tensor = p2p.recv_forward(self.is_pipeline_first_stage()) for i in range(steady_steps): last_iter = (i == (steady_steps - 1)) output_tensor = self._forward_step(input_tensor) - output_tensor_grad = p2p.send_forward_recv_backward(output_tensor) + output_tensor_grad = p2p.send_forward_recv_backward( + output_tensor, self.is_pipeline_last_stage()) input_buffers.append(input_tensor) output_buffers.append(output_tensor) @@ -132,33 +159,41 @@ class PipelineParallel(MetaParallelBase): if last_iter: input_tensor = None - p2p.send_backward(input_tensor_grad) + p2p.send_backward(input_tensor_grad, + self.is_pipeline_first_stage()) else: - input_tensor = p2p.send_backward_recv_forward(input_tensor_grad) + input_tensor = p2p.send_backward_recv_forward( + input_tensor_grad, self.is_pipeline_first_stage()) for i in range(startup_steps): input_tensor = input_buffers.pop(0) output_tensor = output_buffers.pop(0) - output_tensor_grad = p2p.recv_backward() + output_tensor_grad = p2p.recv_backward( + self.is_pipeline_last_stage()) input_tensor_grad = self._backward_step(input_tensor, output_tensor, output_tensor_grad) - p2p.send_backward(input_tensor_grad) + p2p.send_backward(input_tensor_grad, self.is_pipeline_first_stage()) self._layers.allreduce_shared_weight_gradients() with paddle.amp.auto_cast(enable=False): train_loss = self._broadcast_final_loss() return train_loss - def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): + def _prepare_training(self, data, optimizer, lr_scheduler): + # reset the virtual pp rank for each run + self.set_virtual_pipeline_rank(0) + assert isinstance(optimizer, HybridParallelOptimizer), ( 'optimizer should be HybridParallelOptimizer subclass.') assert fluid.framework._dygraph_tracer()._has_grad, ( 'Please enable the generation of gradients.') - if self.is_first_stage or self.is_last_stage: + if self.is_pipeline_first_stage( + ignore_virtual=True) or self.is_pipeline_last_stage( + ignore_virtual=True): assert data is not None, ( "For the first and the last stage, the data must be set.") else: @@ -169,7 +204,11 @@ class PipelineParallel(MetaParallelBase): self._layers.train() - # 1f1b for pipeline + return data + + def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): + data = self._prepare_training(data, optimizer, lr_scheduler) + # 1f1b scheduler for pipeline parallel train_loss = self.forward_backward_pipeline(data, scaler) # optimizer @@ -179,6 +218,9 @@ class PipelineParallel(MetaParallelBase): return train_loss def eval_batch(self, data, compute_loss=False): + # reset the virtual pp rank for each run + self.set_virtual_pipeline_rank(0) + 
self._layers.eval() self._compute_loss = compute_loss @@ -198,28 +240,28 @@ class PipelineParallel(MetaParallelBase): output_buffers = [] for step_id in range(startup_steps): - input_tensor = p2p.recv_forward() + input_tensor = p2p.recv_forward(self.is_pipeline_first_stage()) output_tensor = self._forward_step(input_tensor) - p2p.send_forward(output_tensor) + p2p.send_forward(output_tensor, self.is_pipeline_last_stage()) input_buffers.append(input_tensor) output_buffers.append(output_tensor) if steady_steps > 0: - input_tensor = p2p.recv_forward() + input_tensor = p2p.recv_forward(self.is_pipeline_first_stage()) for i in range(steady_steps): last_iter = (i == (steady_steps - 1)) output_tensor = self._forward_step(input_tensor) - p2p.send_forward(output_tensor) + p2p.send_forward(output_tensor, self.is_pipeline_last_stage()) input_buffers.append(input_tensor) output_buffers.append(output_tensor) if not last_iter: - input_tensor = p2p.recv_forward() + input_tensor = p2p.recv_forward(self.is_pipeline_first_stage()) if self._compute_loss: self.train_loss = self._broadcast_final_loss() @@ -228,13 +270,15 @@ class PipelineParallel(MetaParallelBase): return self.train_loss - def _forward_step(self, input_tensor): - if self.stage_id == 0: + def _forward_step(self, input_tensor, chunk_id=None): + if self.is_pipeline_first_stage(): input_tensor = self._load_micro_batch(self.micro_batch_id) - output_tensor = self._layers.forward(input_tensor) + assert chunk_id is None or isinstance(chunk_id, int) + + output_tensor = self._layers.forward(input_tensor, chunk_id=chunk_id) - if self.is_last_stage: + if self.is_pipeline_last_stage(): # train calculate loss for train if self._compute_loss: assert self._layers._loss_fn is not None, "loss function should exist to compute loss" @@ -253,12 +297,15 @@ class PipelineParallel(MetaParallelBase): self.total_loss = paddle.zeros_like(output_tensor) self.total_loss += output_tensor.detach() - self.micro_batch_id += 1 + if self.is_pipeline_first_stage() or self.is_pipeline_last_stage(): + # Only increase micro batch id at virtual first/last pp stage. + # The micro batch id is used to load data, therefore, only increase it when load data. + self.micro_batch_id += 1 return output_tensor def _backward_step(self, input_tensor, output_tensor, output_tensor_grad): with paddle.amp.auto_cast(enable=False): - if self.is_last_stage: + if self.is_pipeline_last_stage(): assert output_tensor_grad is None if self.scaler: paddle.autograd.backward(self.scaler.scale(output_tensor)) @@ -289,7 +336,8 @@ class PipelineParallel(MetaParallelBase): begin = cache_id * self.micro_batch_size end = begin + self.micro_batch_size - if self.is_first_stage: + # The virtual first and last pipeline stage need data, all others don't need. 
+ if self.is_pipeline_first_stage(): assert len(inputs) == 2, "length of input should be 2" if isinstance(inputs[0], tuple): assert len( @@ -307,7 +355,7 @@ class PipelineParallel(MetaParallelBase): batch_size = inputs[0].shape[0] assert self.micro_batch_size * self.accumulate_steps == batch_size return inputs[0][begin:end, :].detach() - elif self.is_last_stage: + elif self.is_pipeline_last_stage(): assert len(inputs) == 2, "length of input should be 2" if isinstance(inputs[1], tuple): batch_size = inputs[1][0].shape[0] @@ -323,7 +371,9 @@ class PipelineParallel(MetaParallelBase): inputs = None def _broadcast_final_loss(self): - if self.is_last_stage: + # Since the last backward run in interleave will set the virtual rank to 0, + # here we need to check last stage ignoring virtual stage. + if self.is_pipeline_last_stage(ignore_virtual=True): assert self.total_loss is not None, "train_batch() in last stage should obtain vaild loss" loss = self.total_loss.detach() is_fp32 = paddle.to_tensor( @@ -364,3 +414,291 @@ class PipelineParallel(MetaParallelBase): self.optimizer.clear_grad() if self.lr_scheduler: self.lr_scheduler.step() + + +class PipelineParallelWithInterleave(PipelineParallel): + # pipeline parallel with interleave scheduler + + def __init__(self, layers, hcg, strategy): + super(PipelineParallelWithInterleave, self).__init__(layers=layers, + hcg=hcg, + strategy=strategy) + assert layers.get_num_virtual_stages() > 1 + assert framework.in_dygraph_mode( + ), "virtual pipeline stage with interleave only support eager dygraph mode" + # setup for interleave scheduler + self.num_model_chunks = layers.get_num_virtual_stages() + self.model_chunks = layers.get_model_chunks() + assert self.model_chunks is not None + assert len(self.model_chunks) == self.num_model_chunks + self._virtual_pp_world_size = self.num_model_chunks + self._virtual_pp_rank = 0 + + def _get_virtual_pp_rank(self, micro_step, forward): + virtual_pp_stage = micro_step % (self.num_stages * + self.num_model_chunks) + virtual_pp_stage = virtual_pp_stage // self.num_stages + if not forward: + virtual_pp_stage = (self.num_model_chunks - virtual_pp_stage - 1) + return virtual_pp_stage + + def _forward_step_helper(self, micro_step): + virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=True) + self.set_virtual_pipeline_rank(virtual_pp_rank) + + # some checkers + assert hasattr(self, 'input_tensors') + assert hasattr(self, 'output_tensors') + if not self._forward_only: + assert hasattr(self, 'output_tensor_grads') + + if self.is_pipeline_first_stage(): + if len(self.input_tensors[virtual_pp_rank]) == len( + self.output_tensors[virtual_pp_rank]): + self.input_tensors[virtual_pp_rank].append(None) + input_tensor = self.input_tensors[virtual_pp_rank][-1] + output_tensor = self._forward_step(input_tensor, virtual_pp_rank) + self.output_tensors[virtual_pp_rank].append(output_tensor) + + if self._forward_only: + # no need to store tensor for backward + self.input_tensors[virtual_pp_rank].pop() + self.output_tensors[virtual_pp_rank].pop() + + return output_tensor + + def _backward_step_helper(self, micro_step): + virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=False) + self.set_virtual_pipeline_rank(virtual_pp_rank) + + # some checkers + assert hasattr(self, 'input_tensors') + assert hasattr(self, 'output_tensors') + assert hasattr(self, 'output_tensor_grads') + + if self.is_pipeline_last_stage(): + if len(self.output_tensor_grads[virtual_pp_rank]) == 0: + self.output_tensor_grads[virtual_pp_rank].append(None) 
+ + input_tensor = self.input_tensors[virtual_pp_rank].pop(0) + output_tensor = self.output_tensors[virtual_pp_rank].pop(0) + output_tensor_grad = self.output_tensor_grads[virtual_pp_rank].pop(0) + input_tensor_grad = self._backward_step(input_tensor, output_tensor, + output_tensor_grad) + + return input_tensor_grad + + def interleave_pipeline(self, + data, + scaler, + forward_only=False, + compute_loss=True): + # use interleave scheduling strategy. + # this strategy is inspired by: + # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py + if not compute_loss: + assert not forward_only, "compute_loss can only be set to False when forward_only is set to True" + + # init some attributes for this batch run + self.scaler = scaler + self.data = data + self.total_loss = None + self.micro_batch_id = 0 + self._forward_only = forward_only + + # init some data buffers for interleave scheduler + self.input_tensors = [[] for _ in range(self.num_model_chunks)] + self.output_tensors = [[] for _ in range(self.num_model_chunks)] + self.output_tensor_grads = [[] for _ in range(self.num_model_chunks)] + + num_steps = self.accumulate_steps * self.num_model_chunks + all_startup_steps = False + if forward_only: + # If only forward, since there is no backward during running, all steps are startup steps + startup_steps = num_steps + else: + if self.accumulate_steps == self.num_stages: + startup_steps = num_steps + all_startup_steps = True + else: + startup_steps = (self.num_stages - self.stage_id - 1) * 2 + startup_steps += (self.num_model_chunks - 1) * self.num_stages + startup_steps = min(startup_steps, num_steps) + + steady_steps = num_steps - startup_steps + + self.set_virtual_pipeline_rank(0) + self.input_tensors[0].append( + p2p.recv_forward(self.is_pipeline_first_stage())) + + # run startup steps + for micro_step in range(startup_steps): + output_tensor = self._forward_step_helper(micro_step) + + # determine whether recv forward tensor or not + next_virtual_pp_rank = self._get_virtual_pp_rank(micro_step + 1, + forward=True) + recv_prev = True + if self.is_pipeline_first_stage(ignore_virtual=True): + if next_virtual_pp_rank == 0: + # next chunk is the first chunk, not need to pre recv an input tensor + recv_prev = False + # last micro step, no next run + if micro_step == (num_steps - 1): + recv_prev = False + + # last stage shouldn't send tensor to downstream + if self.is_pipeline_last_stage(): + output_tensor = None + + if micro_step == (startup_steps - + 1) and not forward_only and not all_startup_steps: + input_tensor_grad = None + recv_next = True + if self.is_pipeline_last_stage(ignore_virtual=True): + recv_next = False + + # the last startup step needs on four direction comm to set up for steady 1f1b + input_tensor, output_tensor_grad = p2p.send_forward_backward_recv_forward_backward( + output_tensor, + input_tensor_grad, + recv_prev=recv_prev, + recv_next=recv_next) + self.output_tensor_grads[self.num_model_chunks - + 1].append(output_tensor_grad) + else: + input_tensor = p2p.send_forward_recv_forward( + output_tensor, recv_prev=recv_prev) + self.input_tensors[next_virtual_pp_rank].append(input_tensor) + + # run 1f1b steady steps + for micro_step in range(steady_steps): + # forward + forward_micro_step_id = micro_step + startup_steps + output_tensor = self._forward_step_helper(forward_micro_step_id) + + # backward + backward_micro_step_id = micro_step + input_tensor_grad = self._backward_step_helper( + backward_micro_step_id) + + # four directions comm + # send output tensor to 
downstream + # send input tensor grad to upstream + # recv input tensor from upstream + # recv output tensor grad from downstream + + # last stage doesn't send rst to downstream + forward_virtual_pp_rank = self._get_virtual_pp_rank( + forward_micro_step_id, forward=True) + self.set_virtual_pipeline_rank(forward_virtual_pp_rank) + if self.is_pipeline_last_stage(): + output_tensor = None + + # first stage doesn't send grad to upstream + backward_virtual_pp_rank = self._get_virtual_pp_rank( + backward_micro_step_id, forward=False) + self.set_virtual_pipeline_rank(backward_virtual_pp_rank) + if self.is_pipeline_first_stage(): + input_tensor_grad = None + + # determine whether to recv input tensor from upstream + recv_prev = True + if self.is_pipeline_first_stage(ignore_virtual=True): + next_forward_virtual_pp_rank = self._get_virtual_pp_rank( + forward_micro_step_id - (self.num_stages - 1), forward=True) + if next_forward_virtual_pp_rank == (self.num_model_chunks - 1): + # first pp stage and first virtual stage + recv_prev = False + next_forward_virtual_pp_rank += 1 + else: + next_forward_virtual_pp_rank = self._get_virtual_pp_rank( + forward_micro_step_id + 1, forward=True) + + # last iteration doesn't need recv from upstream + if micro_step == (steady_steps - 1): + recv_prev = False + + # determine whether to recv grad from downstream + recv_next = True + if self.is_pipeline_last_stage(ignore_virtual=True): + next_backward_virtual_pp_rank = self._get_virtual_pp_rank( + backward_micro_step_id - (self.num_stages - 1), + forward=False) + if next_backward_virtual_pp_rank == 0: + # last pp stage and last virtual stage + recv_next = False + next_backward_virtual_pp_rank -= 1 + else: + next_backward_virtual_pp_rank = self._get_virtual_pp_rank( + backward_micro_step_id + 1, forward=False) + + input_tensor, output_tensor_grad = p2p.send_forward_backward_recv_forward_backward( + output_tensor, + input_tensor_grad, + recv_prev=recv_prev, + recv_next=recv_next) + + if recv_prev: + self.input_tensors[next_forward_virtual_pp_rank].append( + input_tensor) + if recv_next: + self.output_tensor_grads[next_backward_virtual_pp_rank].append( + output_tensor_grad) + + # remaining backward steps + if not forward_only: + if all_startup_steps: + self.output_tensor_grads[self.num_model_chunks - 1].append( + p2p.recv_backward(self.is_pipeline_last_stage())) + + for micro_step in range(steady_steps, num_steps): + # cooldown loop + input_tensor_grad = self._backward_step_helper(micro_step) + next_backward_virtual_pp_rank = self._get_virtual_pp_rank( + micro_step + 1, forward=False) + + recv_next = True + if self.is_pipeline_last_stage(ignore_virtual=True): + if next_backward_virtual_pp_rank == (self.num_model_chunks - + 1): + recv_next = False + + if micro_step == (num_steps - 1): + recv_next = False + + self.output_tensor_grads[next_backward_virtual_pp_rank].append( + p2p.send_backward_recv_backward(input_tensor_grad, + recv_next=recv_next)) + + self._layers.allreduce_shared_weight_gradients() + + if compute_loss: + # return loss if compute loss + with paddle.amp.auto_cast(enable=False): + train_loss = self._broadcast_final_loss() + else: + # else just return all intermediate output tensor for all micro steps + train_loss = self.output_tensors + + return train_loss + + def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): + data = self._prepare_training(data, optimizer, lr_scheduler) + # interleave scheduler for pipeline parallel + train_loss = self.interleave_pipeline(data, scaler) + + # optimizer + 
with paddle.amp.auto_cast(enable=False): + self._optimizer_step() + + return train_loss + + def eval_batch(self, data, compute_loss=False): + # reset the virtual pp rank for each run + self.set_virtual_pipeline_rank(0) + + self._layers.eval() + self._compute_loss = compute_loss + + return self.interleave_pipeline(data, None, forward_only=True) diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py index 14a2aa8448..9113603376 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py @@ -54,7 +54,7 @@ class SendRecvMeta: def _recv_shape_dtype(self, group): # recv len(shape) dims = paddle.to_tensor([0]) - src_rank = group.ranks[0] + src_rank = _hcg._get_p2p_prev_rank() paddle.distributed.recv(dims, src=src_rank, group=group) dims = dims.item() @@ -74,7 +74,7 @@ class SendRecvMeta: def recv_meta(self, group): tensor_type = paddle.to_tensor([0]) - src_rank = group.ranks[0] + src_rank = _hcg._get_p2p_prev_rank() paddle.distributed.recv(tensor_type, src=src_rank, group=group) tensor_type = tensor_type.item() @@ -105,7 +105,7 @@ class SendRecvMeta: def _send_dims_shape_dtype(self, tensor, group): # send len(shape) dims = paddle.to_tensor(len(tensor.shape)) - dst_rank = group.ranks[1] + dst_rank = _hcg._get_p2p_next_rank() paddle.distributed.send(dims, dst=dst_rank, group=group) @@ -122,7 +122,7 @@ class SendRecvMeta: paddle.distributed.send(stop_grad, dst=dst_rank, group=group) def send_meta(self, tensor, group): - dst_rank = group.ranks[1] + dst_rank = _hcg._get_p2p_next_rank() if isinstance(tensor, (paddle.Tensor, core.eager.Tensor)): tensor_type = paddle.to_tensor([0]) @@ -165,20 +165,17 @@ def _is_valid_send_recv_partial(tensor, mp_degree): def _partial_send_op(tensor, group, use_calc_stream, ring_id, dst, nranks, rank_id): + dst_rank_in_group = dst if group is None else group.get_group_rank(dst) if _in_legacy_dygraph(): return _legacy_C_ops.partial_send(tensor.detach(), 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, - 'peer', dst, 'num', nranks, 'id', - rank_id) + 'peer', dst_rank_in_group, 'num', + nranks, 'id', rank_id) elif in_dygraph_mode(): group = paddle.distributed.collective._get_default_group( ) if group is None else group - task = group.process_group.send_partial(tensor, dst, nranks, rank_id) - if use_calc_stream: - task.wait() - return None - else: - return task + return group.process_group.send_partial(tensor, dst_rank_in_group, + nranks, rank_id) def send_partial(tensor, @@ -192,33 +189,35 @@ def send_partial(tensor, return ring_id = 0 if group is None else group.id + dst_rank = _hcg._get_p2p_next_rank( + ) if dst == 1 else _hcg._get_p2p_prev_rank() + if _is_valid_send_recv_partial(tensor, nranks): - return _partial_send_op(tensor, group, use_calc_stream, ring_id, dst, - nranks, rank_id) + return _partial_send_op(tensor, group, use_calc_stream, ring_id, + dst_rank, nranks, rank_id) else: - return paddle.distributed.send(tensor.detach(), - dst=group.ranks[dst], - group=group, - use_calc_stream=use_calc_stream) + if _in_legacy_dygraph(): + send_op = paddle.distributed.send + elif in_dygraph_mode(): + send_op = paddle.distributed.isend + return send_op(tensor.detach(), dst=dst_rank, group=group) def _partial_recv_op(tensor, group, use_calc_stream, ring_id, src, nranks, rank_id): + src_rank_in_group = src if group is None else group.get_group_rank(src) 
if _in_legacy_dygraph(): return _legacy_C_ops.partial_recv(tensor.detach(), 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, - 'peer', src, 'num', nranks, 'id', - rank_id, 'dtype', tensor.dtype, - 'out_shape', tensor.shape) + 'peer', src_rank_in_group, 'num', + nranks, 'id', rank_id, 'dtype', + tensor.dtype, 'out_shape', + tensor.shape) elif in_dygraph_mode(): group = paddle.distributed.collective._get_default_group( ) if group is None else group - task = group.process_group.recv_partial(tensor, src, nranks, rank_id) - if use_calc_stream: - task.wait() - return None - else: - return task + return group.process_group.recv_partial(tensor, src_rank_in_group, + nranks, rank_id) def recv_partial(tensor, @@ -232,14 +231,18 @@ def recv_partial(tensor, return ring_id = 0 if group is None else group.id + src_rank = _hcg._get_p2p_prev_rank( + ) if src == 0 else _hcg._get_p2p_next_rank() + if _is_valid_send_recv_partial(tensor, nranks): - return _partial_recv_op(tensor, group, use_calc_stream, ring_id, src, - nranks, rank_id) + return _partial_recv_op(tensor, group, use_calc_stream, ring_id, + src_rank, nranks, rank_id) else: - return paddle.distributed.recv(tensor.detach(), - src=group.ranks[src], - group=group, - use_calc_stream=use_calc_stream) + if _in_legacy_dygraph(): + recv_op = paddle.distributed.recv + elif in_dygraph_mode(): + recv_op = paddle.distributed.irecv + return recv_op(tensor.detach(), src=src_rank, group=group) def _partial_allgather_op(tensor, group, use_calc_stream, ring_id, nranks, @@ -253,13 +256,8 @@ def _partial_allgather_op(tensor, group, use_calc_stream, ring_id, nranks, elif in_dygraph_mode(): group = paddle.distributed.collective._get_default_group( ) if group is None else group - task = group.process_group.all_gather_partial(tensor, tensor, nranks, + return group.process_group.all_gather_partial(tensor, tensor, nranks, rank_id) - if use_calc_stream: - task.wait() - return None - else: - return task def allgather_partial(tensor, @@ -268,9 +266,9 @@ def allgather_partial(tensor, group=None, use_calc_stream=True): if not _is_valid_send_recv_partial(tensor, nranks): - return tensor + return None if group is not None and not group.is_member(): - return + return None ring_id = 0 if group is None else group.id return _partial_allgather_op(tensor, group, use_calc_stream, ring_id, @@ -323,105 +321,124 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next): tensor_recv_next = paddle.empty( shape=send_shape_msg, dtype=number_2_dtype(send_dtype_msg)) + # TODO(Yuang Liu): use batch_isend_irecv replace all these comm ops + tasks = [] # start to p2p communicate if tensor_send_prev is not None: if isinstance(tensor_send_prev, tuple): for d in tensor_send_prev: paddle.distributed.wait(d, use_calc_stream=True) - send_partial(d, + tasks.append( + send_partial(d, + dst=0, + nranks=mp_degree, + rank_id=mp_rank, + group=_hcg.send_prev_group, + use_calc_stream=False)) + else: + paddle.distributed.wait(tensor_send_prev, use_calc_stream=True) + tasks.append( + send_partial(tensor_send_prev, dst=0, nranks=mp_degree, rank_id=mp_rank, group=_hcg.send_prev_group, - use_calc_stream=False) - else: - paddle.distributed.wait(tensor_send_prev, use_calc_stream=True) - send_partial(tensor_send_prev, - dst=0, - nranks=mp_degree, - rank_id=mp_rank, - group=_hcg.send_prev_group, - use_calc_stream=False) + use_calc_stream=False)) if tensor_recv_prev is not None: if isinstance(tensor_recv_prev, tuple): for d in tensor_recv_prev: - recv_partial(d, + tasks.append( + recv_partial(d, + 
src=0, + nranks=mp_degree, + rank_id=mp_rank, + group=_hcg.recv_prev_group, + use_calc_stream=True)) + tasks.append( + allgather_partial(d, + nranks=mp_degree, + rank_id=mp_rank, + group=mp_group, + use_calc_stream=True)) + else: + tasks.append( + recv_partial(tensor_recv_prev, src=0, nranks=mp_degree, rank_id=mp_rank, group=_hcg.recv_prev_group, - use_calc_stream=True) - allgather_partial(d, + use_calc_stream=True)) + tasks.append( + allgather_partial(tensor_recv_prev, nranks=mp_degree, rank_id=mp_rank, group=mp_group, - use_calc_stream=True) - else: - recv_partial(tensor_recv_prev, - src=0, - nranks=mp_degree, - rank_id=mp_rank, - group=_hcg.recv_prev_group, - use_calc_stream=True) - allgather_partial(tensor_recv_prev, - nranks=mp_degree, - rank_id=mp_rank, - group=mp_group, - use_calc_stream=True) + use_calc_stream=True)) if tensor_send_next is not None: if isinstance(tensor_send_next, tuple): for d in tensor_send_next: paddle.distributed.wait(d, use_calc_stream=True) - send_partial(d, + tasks.append( + send_partial(d, + dst=1, + nranks=mp_degree, + rank_id=mp_rank, + group=_hcg.send_next_group, + use_calc_stream=False)) + else: + paddle.distributed.wait(tensor_send_next, use_calc_stream=True) + tasks.append( + send_partial(tensor_send_next, dst=1, nranks=mp_degree, rank_id=mp_rank, group=_hcg.send_next_group, - use_calc_stream=False) - else: - paddle.distributed.wait(tensor_send_next, use_calc_stream=True) - send_partial(tensor_send_next, - dst=1, - nranks=mp_degree, - rank_id=mp_rank, - group=_hcg.send_next_group, - use_calc_stream=False) + use_calc_stream=False)) if tensor_recv_next is not None: if isinstance(tensor_recv_next, tuple): for d in tensor_recv_next: - recv_partial(d, + tasks.append( + recv_partial(d, + src=1, + nranks=mp_degree, + rank_id=mp_rank, + group=_hcg.recv_next_group, + use_calc_stream=True)) + tasks.append( + allgather_partial(d, + nranks=mp_degree, + rank_id=mp_rank, + group=mp_group, + use_calc_stream=True)) + + else: + tasks.append( + recv_partial(tensor_recv_next, src=1, nranks=mp_degree, rank_id=mp_rank, group=_hcg.recv_next_group, - use_calc_stream=True) - allgather_partial(d, + use_calc_stream=True)) + + tasks.append( + allgather_partial(tensor_recv_next, nranks=mp_degree, rank_id=mp_rank, group=mp_group, - use_calc_stream=True) - - else: - recv_partial(tensor_recv_next, - src=1, - nranks=mp_degree, - rank_id=mp_rank, - group=_hcg.recv_next_group, - use_calc_stream=True) - - allgather_partial(tensor_recv_next, - nranks=mp_degree, - rank_id=mp_rank, - group=mp_group, - use_calc_stream=True) + use_calc_stream=True)) + if in_dygraph_mode(): + # wait tasks in new dygraph mode with new comm library + for task in tasks: + if task is not None: + task.wait() return tensor_recv_prev, tensor_recv_next -def recv_forward(): - if _hcg.is_first_stage: +def recv_forward(pp_first_stage): + if pp_first_stage: input_tensor = None else: if not _send_recv_meta.has_recv_meta: @@ -435,8 +452,8 @@ def recv_forward(): return input_tensor -def recv_backward(): - if _hcg.is_last_stage: +def recv_backward(pp_last_stage): + if pp_last_stage: output_tensor_grad = None else: _, output_tensor_grad = _p2p_helper(tensor_send_next=None, @@ -446,8 +463,8 @@ def recv_backward(): return output_tensor_grad -def send_forward(output_tensor): - if not _hcg.is_last_stage: +def send_forward(output_tensor, pp_last_stage): + if not pp_last_stage: if not _send_recv_meta.has_send_meta: _send_recv_meta.set_send_message(output_tensor) _send_recv_meta.send_meta(output_tensor, _hcg.send_next_group) @@ 
-459,16 +476,16 @@ def send_forward(output_tensor): recv_next=False) -def send_backward(input_tensor_grad): - if not _hcg.is_first_stage: +def send_backward(input_tensor_grad, pp_first_stage): + if not pp_first_stage: _p2p_helper(tensor_send_next=None, tensor_send_prev=input_tensor_grad, recv_prev=False, recv_next=False) -def send_forward_recv_backward(output_tensor): - if _hcg.is_last_stage: +def send_forward_recv_backward(output_tensor, pp_last_stage): + if pp_last_stage: output_tensor_grad = None else: _, output_tensor_grad = _p2p_helper(tensor_send_next=output_tensor, @@ -478,8 +495,8 @@ def send_forward_recv_backward(output_tensor): return output_tensor_grad -def send_backward_recv_forward(input_tensor_grad): - if _hcg.is_first_stage: +def send_backward_recv_forward(input_tensor_grad, pp_first_stage): + if pp_first_stage: input_tensor = None else: input_tensor, _ = _p2p_helper(tensor_send_next=None, @@ -487,3 +504,48 @@ def send_backward_recv_forward(input_tensor_grad): recv_prev=True, recv_next=False) return input_tensor + + +def send_forward_backward_recv_forward_backward(output_tensor, + input_tensor_grad, recv_prev, + recv_next): + # always have to send dytpe info to downstream + if not _send_recv_meta.has_send_meta: + _send_recv_meta.set_send_message(output_tensor) + _send_recv_meta.send_meta(output_tensor, _hcg.send_next_group) + _send_recv_meta.has_send_meta = _use_cache + if recv_prev and not _send_recv_meta.has_recv_meta: + _send_recv_meta.recv_meta(_hcg.recv_prev_group) + _send_recv_meta.has_recv_meta = _use_cache + input_tensor, output_tensor_grad = _p2p_helper( + tensor_send_next=output_tensor, + tensor_send_prev=input_tensor_grad, + recv_prev=recv_prev, + recv_next=recv_next) + return input_tensor, output_tensor_grad + + +def send_forward_recv_forward(output_tensor, recv_prev): + # always have to send dytpe info to downstream + if not _send_recv_meta.has_send_meta: + _send_recv_meta.set_send_message(output_tensor) + _send_recv_meta.send_meta(output_tensor, _hcg.send_next_group) + _send_recv_meta.has_send_meta = _use_cache + if recv_prev and not _send_recv_meta.has_recv_meta: + _send_recv_meta.recv_meta(_hcg.recv_prev_group) + _send_recv_meta.has_recv_meta = _use_cache + + input_tensor, _ = _p2p_helper(tensor_send_next=output_tensor, + tensor_send_prev=None, + recv_prev=recv_prev, + recv_next=False) + + return input_tensor + + +def send_backward_recv_backward(input_tensor_grad, recv_next): + _, output_tensor_grad = _p2p_helper(tensor_send_next=None, + tensor_send_prev=input_tensor_grad, + recv_prev=False, + recv_next=recv_next) + return output_tensor_grad diff --git a/python/paddle/distributed/fleet/model.py b/python/paddle/distributed/fleet/model.py index 988d2d928c..fea2614fe8 100644 --- a/python/paddle/distributed/fleet/model.py +++ b/python/paddle/distributed/fleet/model.py @@ -18,7 +18,7 @@ import numpy as np from .base import topology as tp from .base.topology import ParallelMode from .meta_parallel import TensorParallel, model_parallel_random_seed -from .meta_parallel import PipelineParallel, ShardingParallel +from .meta_parallel import PipelineParallel, ShardingParallel, PipelineParallelWithInterleave, PipelineLayer from paddle.fluid import core from paddle.distributed.fleet.utils.recompute import LegacyRecomputeFunction from paddle.fluid.dygraph.varbase_patch_methods import _grad_scalar @@ -185,6 +185,16 @@ def distributed_model(model): elif fleet_env._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL: model = TensorParallel(model, fleet_env._hcg, 
strategy=strategy) elif fleet_env._hcg.get_parallel_mode() == ParallelMode.PIPELINE_PARALLEL: - model = PipelineParallel(model, fleet_env._hcg, strategy=strategy) + assert isinstance( + model, PipelineLayer + ), "For pipeline parallel, the model should an instance of PipelineLayer" + if model.get_num_virtual_stages() == 1: + # 1f1b pipeline + model = PipelineParallel(model, fleet_env._hcg, strategy=strategy) + else: + # interleave pipeline + model = PipelineParallelWithInterleave(model, + fleet_env._hcg, + strategy=strategy) return model diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 18c0b12896..a76b9d1789 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -27,8 +27,6 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_dataparallel) list(APPEND DIST_TEST_OPS test_static_model_parallel_fused_feedforward) list(APPEND DIST_TEST_OPS test_static_model_parallel_fused_attention) list(APPEND DIST_TEST_OPS test_static_model_parallel_fused_multi_transformer) -list(APPEND DIST_TEST_OPS - test_parallel_dygraph_pipeline_parallel_with_virtual_stage) list(APPEND DIST_TEST_OPS test_auto_parallel_data_unshard) list(APPEND DIST_TEST_OPS test_auto_parallel_save_load) list(APPEND DIST_TEST_OPS test_auto_parallel_autoconvert) @@ -178,8 +176,6 @@ if((NOT WITH_GPU) AND (NOT WITH_ROCM)) # TODO(shenliang03): batch_fc_op support CPU device in future # TODO(Yancey1989): parallel dygraph support CPU device in future list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_dataparallel) - list(REMOVE_ITEM TEST_OPS - test_parallel_dygraph_pipeline_parallel_with_virtual_stage) list(REMOVE_ITEM TEST_OPS test_fleet_base_single) list(REMOVE_ITEM TEST_OPS test_auto_parallel_partitioner) list(REMOVE_ITEM TEST_OPS test_auto_parallel_partitioner_gpt) @@ -1178,9 +1174,6 @@ set_tests_properties(test_graph_send_uv_op PROPERTIES TIMEOUT 60) if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL) - set_tests_properties( - test_parallel_dygraph_pipeline_parallel_with_virtual_stage - PROPERTIES TIMEOUT 500) set_tests_properties(test_auto_parallel_data_unshard PROPERTIES TIMEOUT 120) set_tests_properties(test_auto_parallel_save_load PROPERTIES TIMEOUT 120) set_tests_properties(test_auto_parallel_autoconvert PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/CMakeLists.txt b/python/paddle/fluid/tests/unittests/collective/fleet/CMakeLists.txt index 83cb99a2e7..50aaf8f12c 100644 --- a/python/paddle/fluid/tests/unittests/collective/fleet/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/collective/fleet/CMakeLists.txt @@ -204,6 +204,20 @@ if((WITH_GPU) AND LOCAL_ALL_PLAT) set_tests_properties(test_parallel_dygraph_pipeline_parallel PROPERTIES TIMEOUT "500") endif() +if((WITH_GPU) AND LOCAL_ALL_PLAT) + bash_test_modules( + test_parallel_dygraph_pipeline_parallel_with_virtual_stage + START_BASH + ../../dist_test.sh + LABELS + "RUN_TYPE=DIST" + ENVS + "PADDLE_DIST_UT_PORT=21282;http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python" + ) + set_tests_properties( + test_parallel_dygraph_pipeline_parallel_with_virtual_stage + PROPERTIES TIMEOUT "500" RUN_SERIAL 1) +endif() if((WITH_GPU OR WITH_XPU OR WITH_ASCEND diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_layer_with_virtual_stage.py b/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_pp_layer_with_virtual_stage.py similarity index 95% rename from 
python/paddle/fluid/tests/unittests/hybrid_parallel_pp_layer_with_virtual_stage.py rename to python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_pp_layer_with_virtual_stage.py index 1bd8e93480..137dde6891 100644 --- a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_layer_with_virtual_stage.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_pp_layer_with_virtual_stage.py @@ -19,7 +19,7 @@ import paddle from paddle.distributed import fleet import paddle.nn as nn from paddle.fluid.dygraph.layers import Layer -from paddle.distributed.fleet.meta_parallel import LayerDesc, PipelineLayer +from paddle.distributed.fleet.meta_parallel import LayerDesc, PipelineLayer, PipelineParallelWithInterleave import paddle.nn.functional as F @@ -87,7 +87,8 @@ class TestPipeLayerAPI(unittest.TestCase): try: model_chunks[0](paddle.to_tensor([1., 2.])) - except NotImplementedError: + raise NotImplementedError + except PermissionError: pass # fake call for the forward function of virtual pipeline layer @@ -102,6 +103,7 @@ class TestPipeLayerAPI(unittest.TestCase): # just make sure the model can be wrapped with distributed model dist_model = fleet.distributed_model(pipe_model) + assert isinstance(dist_model, PipelineParallelWithInterleave) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_pp_transformer_with_virtual_stage.py b/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_pp_transformer_with_virtual_stage.py new file mode 100644 index 0000000000..47b3f3a550 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_pp_transformer_with_virtual_stage.py @@ -0,0 +1,195 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import division +from __future__ import print_function + +import unittest +import paddle +import numpy as np +import random +import paddle.distributed as dist +import paddle.distributed.fleet as fleet +from paddle.fluid import layers +import paddle.nn.functional as F +from paddle.distributed.fleet.meta_parallel import PipelineLayer, LayerDesc +from paddle.fluid.dygraph.layers import Layer +import paddle.nn as nn + + +def set_random_seed(seed, dp_id, rank_id): + """Set random seed for reproducability.""" + random.seed(seed) + np.random.seed(seed + dp_id) + paddle.seed(seed + dp_id) + + +batch_size = 8 +length = 8 +micro_batch_size = 2 +num_virtual_pipeline_stages = 2 +vocab_size = 128 +hidden_size = 16 +d_model = hidden_size +dim_feedforward = 4 * d_model + + +class EmbeddingNet(Layer): + + def __init__(self): + super(EmbeddingNet, self).__init__() + self.word_embeddings = nn.Embedding(vocab_size, hidden_size) + self.position_embeddings = nn.Embedding(vocab_size, hidden_size) + + def forward(self, x): + attention_mask = paddle.tensor.triu((paddle.ones( + (length, length), dtype="float32") * -1e9), 1) + + no_used = paddle.ones((3, 3), dtype="int32") + + w_emb = self.word_embeddings(x) + p_emb = self.position_embeddings(x) + w_emb = w_emb + p_emb + + attention_mask.stop_gradient = True + no_used.stop_gradient = True + # need to fix bug of backward() + return w_emb, attention_mask, no_used, p_emb + + +class TransformerNet(Layer): + + def __init__(self): + super(TransformerNet, self).__init__() + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.q_proj = nn.Linear(d_model, d_model) + self.k_proj = nn.Linear(d_model, d_model) + self.v_proj = nn.Linear(d_model, d_model) + + self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5) + + def forward(self, x, mask): + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + product = layers.matmul(x=q, y=k, transpose_y=True, alpha=d_model**-0.5) + + weights = F.softmax(product + mask) + tgt = layers.matmul(weights, v) + residual = tgt + tgt = self.norm1(tgt) + tgt = residual + tgt + + out = self.linear2(F.gelu(self.linear1(tgt), approximate=True)) + return out + + +class EmbeddingPipe(EmbeddingNet): + + def forward(self, x): + return super().forward(x) + + +class TransformerNetPipe(TransformerNet): + + def forward(self, args): + x, mask, no_used, p_emb = args[0], args[1], args[2], args[3] + + output = super().forward(x, mask) + output = output + p_emb + mask.stop_gradient = True + return output, mask, no_used, p_emb + + +class CriterionPipe(Layer): + + def __init__(self): + super(CriterionPipe, self).__init__() + + def forward(self, out, label): + loss = out.mean() + return loss + + +class ModelPipe(PipelineLayer): + + def __init__(self, topology): + self.descs = [] + self.descs.append(LayerDesc(EmbeddingPipe)) + + for x in range(8): + self.descs.append(LayerDesc(TransformerNetPipe)) + + self.descs.append(lambda x: x[0]) + + super().__init__( + layers=self.descs, + loss_fn=CriterionPipe(), + topology=topology, + num_virtual_pipeline_stages=num_virtual_pipeline_stages, + seg_method="layer:TransformerNetPipe") + + +class TestDistPPTraning(unittest.TestCase): + + def setUp(self): + strategy = fleet.DistributedStrategy() + self.model_parallel_size = 1 + self.data_parallel_size = 1 + self.pipeline_parallel_size = 2 + strategy.hybrid_configs = { + "dp_degree": self.data_parallel_size, + "mp_degree": self.model_parallel_size, + "pp_degree": self.pipeline_parallel_size, + } + 
strategy.pipeline_configs = { + "accumulate_steps": batch_size // micro_batch_size, + "micro_batch_size": micro_batch_size + } + fleet.init(is_collective=True, strategy=strategy) + + def test_pp_model(self): + hcg = fleet.get_hybrid_communicate_group() + word_size = hcg.get_model_parallel_world_size() + dp_id = hcg.get_data_parallel_rank() + pp_id = hcg.get_stage_id() + rank_id = dist.get_rank() + topology = hcg.topology() + set_random_seed(1024, dp_id, rank_id) + + model = ModelPipe(topology) + scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries=[2], + values=[0.001, 0.002], + verbose=True) + optimizer = paddle.optimizer.SGD(learning_rate=scheduler, + parameters=model.parameters()) + + model = fleet.distributed_model(model) + optimizer = fleet.distributed_optimizer(optimizer) + + for step_id in range(5): + x_data = np.random.randint(0, vocab_size, size=[batch_size, length]) + x = paddle.to_tensor(x_data) + x.stop_gradient = True + + e_loss = model.eval_batch([x, x], True) + loss = model.train_batch([x, x], optimizer, scheduler) + + np.testing.assert_allclose(loss.numpy(), e_loss.numpy()) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel_with_virtual_stage.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_parallel_dygraph_pipeline_parallel_with_virtual_stage.py similarity index 86% rename from python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel_with_virtual_stage.py rename to python/paddle/fluid/tests/unittests/collective/fleet/test_parallel_dygraph_pipeline_parallel_with_virtual_stage.py index 7011b4507e..2330e0dac6 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel_with_virtual_stage.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_parallel_dygraph_pipeline_parallel_with_virtual_stage.py @@ -25,8 +25,10 @@ class TestHybridPipeParallelWithVirtualStage(TestMultipleGpus): def test_hybrid_parallel_pp_layer_with_virtual_stage(self): self.run_mnist_2gpu('hybrid_parallel_pp_layer_with_virtual_stage.py') - self.run_mnist_2gpu('hybrid_parallel_pp_layer_with_virtual_stage.py', - eager_mode=False) + + def test_hybrid_parallel_pp_transformer_with_virtual_stage(self): + self.run_mnist_2gpu( + 'hybrid_parallel_pp_transformer_with_virtual_stage.py') if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/testslist.csv b/python/paddle/fluid/tests/unittests/collective/fleet/testslist.csv index cb5607325a..286b7dc911 100644 --- a/python/paddle/fluid/tests/unittests/collective/fleet/testslist.csv +++ b/python/paddle/fluid/tests/unittests/collective/fleet/testslist.csv @@ -16,6 +16,7 @@ test_fleet_graph_execution_meta_optimizer,,GPU;XPU;ASCEND;ASCEND_CL,,DIST,../../ test_communicator_half_async,,,120,DIST,test_runner.py,2,,FLAGS_communicator_send_queue_size=1;FLAGS_communicator_max_merge_var_num=1;http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL test_fleet_graph_executor,,GPU;XPU;ASCEND;ASCEND_CL,,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_parallel_dygraph_pipeline_parallel,,GPU,500,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_parallel_dygraph_pipeline_parallel_with_virtual_stage,,GPU,500,DIST,../../dist_test.sh,2,1,http_proxy=;https_proxy=;PYTHONPATH=../.., test_fleet_localsgd_meta_optimizer,LINUX,GPU;XPU;ASCEND;ASCEND_CL,,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., 
test_parallel_class_center_sample,,GPU,120,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL test_pipeline,,,120,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., -- GitLab
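
Usage note: the new test hybrid_parallel_pp_transformer_with_virtual_stage.py added by this patch exercises the interleave scheduler end to end. The condensed Python sketch below illustrates the same flow, building a PipelineLayer with num_virtual_pipeline_stages > 1 so that fleet.distributed_model dispatches to the new PipelineParallelWithInterleave. LinearPipe, CriterionPipe and the tensor sizes are simplified, illustrative stand-ins for the transformer blocks used in the test, not code from this patch, and the script is assumed to be launched on 2 GPUs (for example with paddle.distributed.launch).

# Condensed usage sketch of the interleave pipeline scheduler in this patch,
# adapted from hybrid_parallel_pp_transformer_with_virtual_stage.py.
# LinearPipe/CriterionPipe and the sizes below are illustrative stand-ins.
import numpy as np
import paddle
import paddle.distributed.fleet as fleet
import paddle.nn as nn
from paddle.distributed.fleet.meta_parallel import (
    LayerDesc, PipelineLayer, PipelineParallelWithInterleave)

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 1, "pp_degree": 2}
strategy.pipeline_configs = {"accumulate_steps": 4, "micro_batch_size": 2}
fleet.init(is_collective=True, strategy=strategy)


class LinearPipe(nn.Linear):
    # Named subclass so seg_method="layer:LinearPipe" can segment on it,
    # mirroring seg_method="layer:TransformerNetPipe" in the test.
    pass


class CriterionPipe(nn.Layer):

    def forward(self, out, label):
        return out.mean()


class ModelPipe(PipelineLayer):

    def __init__(self, topology):
        # 8 layers split across pp_degree(2) * num_virtual_pipeline_stages(2)
        # chunks, i.e. 2 layers per model chunk on each rank.
        descs = [LayerDesc(LinearPipe, 16, 16) for _ in range(8)]
        super().__init__(layers=descs,
                         loss_fn=CriterionPipe(),
                         topology=topology,
                         num_virtual_pipeline_stages=2,
                         seg_method="layer:LinearPipe")


hcg = fleet.get_hybrid_communicate_group()
model = ModelPipe(hcg.topology())
optimizer = paddle.optimizer.SGD(learning_rate=0.001,
                                 parameters=model.parameters())

# With more than one virtual stage per rank, distributed_model now wraps the
# layers in PipelineParallelWithInterleave instead of PipelineParallel.
model = fleet.distributed_model(model)
assert isinstance(model, PipelineParallelWithInterleave)
optimizer = fleet.distributed_optimizer(optimizer)

# batch of 8 samples = accumulate_steps (4) * micro_batch_size (2)
x = paddle.to_tensor(np.random.rand(8, 16).astype("float32"))
label = paddle.to_tensor(np.zeros([8, 1], dtype="float32"))
loss = model.train_batch([x, label], optimizer)  # interleave 1F1B schedule

As in the test, train_batch drives interleave_pipeline() rather than the plain 1F1B forward_backward_pipeline(), and eval_batch(data, compute_loss=True) runs the same schedule in forward-only mode.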