From 264ff9efb905afbc9bea08ac771839a7183b6a94 Mon Sep 17 00:00:00 2001
From: ShenLiang <1422485404@qq.com>
Date: Wed, 1 Sep 2021 19:46:14 +0800
Subject: [PATCH] [HybridParallel] Support finetune model for PipelineParallel
 (#35287)

* add cache for send_recv

* add eval_batch for pipeline

* add eval batch for pipelineparallel

* add style code
---
 .../framework/distributed_strategy.proto      |  1 +
 .../fleet/meta_parallel/pipeline_parallel.py  | 83 +++++++++++++++----
 .../pp_utils/p2p_communication.py             | 10 ++-
 .../hybrid_parallel_pp_transformer.py         |  7 +-
 4 files changed, 80 insertions(+), 21 deletions(-)

diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto
index 58ae35f2689..3627a8cf71c 100644
--- a/paddle/fluid/framework/distributed_strategy.proto
+++ b/paddle/fluid/framework/distributed_strategy.proto
@@ -158,6 +158,7 @@ message PipelineConfig {
   optional int32 micro_batch_size = 1 [ default = 1 ];
   optional int32 accumulate_steps = 2 [ default = 1 ];
   optional string schedule_mode = 3 [ default = '1F1B' ];
+  optional bool p2p_cache_shape = 4 [ default = true ];
 }
 
 message TensorParallelConfig {
diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
index fc7b39ede24..706d64d8d35 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -42,11 +42,13 @@ class PipelineParallel(MetaParallelBase):
         self.accumulate_steps = self._strategy.pipeline_configs[
             'accumulate_steps']
 
+        self._using_cache = self._strategy.pipeline_configs['p2p_cache_shape']
+
         self.num_stages = self._hcg.get_pipe_parallel_world_size()
         self.stage_id = self._hcg.get_stage_id()
         self.pp_group = self._hcg.get_pipe_parallel_group()
 
-        p2p.initialize_p2p_groups(hcg)
+        p2p.initialize_p2p_groups(hcg, self._using_cache)
 
         _initialize_recompute_hcg(hcg)
 
@@ -55,6 +57,8 @@ class PipelineParallel(MetaParallelBase):
         self.global_rank = self._hcg.get_global_rank()
         self.micro_batch_id = 0
 
+        self._compute_loss = True
+
         logger.info("Pipeline Info -- num_stages: {}, stage_id: {}".format(
             self.num_stages, self.stage_id))
 
@@ -85,6 +89,7 @@ class PipelineParallel(MetaParallelBase):
         self.lr_scheduler = lr_scheduler
         self.scaler = scaler
         self.data = data
+        self._compute_loss = True
 
         self._layers.train()
 
@@ -151,12 +156,57 @@ class PipelineParallel(MetaParallelBase):
         self._layers.allreduce_shared_weight_gradients()
 
-        self.train_loss = self._reduce_final_loss()
+        self.train_loss = self._broadcast_final_loss()
 
         # optimizer
         self._optimizer_step()
 
         return self.train_loss
 
+    def eval_batch(self, data, compute_loss=False):
+        self._layers.eval()
+        self._compute_loss = compute_loss
+
+        # save data for eval
+        self.data = data
+        # store data id for micro_batch
+        self.micro_batch_id = 0
+
+        # store total loss of entire batch
+        self.total_loss = None
+
+        startup_steps = (self.num_stages - self.stage_id - 1)
+        startup_steps = min(startup_steps, self.accumulate_steps)
+        steady_steps = self.accumulate_steps - startup_steps
+
+        input_buffers = []
+        output_buffers = []
+
+        for step_id in range(startup_steps):
+            input_tensor = p2p.recv_forward()
+
+            output_tensor = self._forward_step(input_tensor)
+            p2p.send_forward(output_tensor)
+
+            input_buffers.append(input_tensor)
+            output_buffers.append(output_tensor)
+
+        if steady_steps > 0:
+            input_tensor = p2p.recv_forward()
+
+        for i in range(steady_steps):
+            last_iter = (i == (steady_steps - 1))
+
+            output_tensor = self._forward_step(input_tensor)
+            p2p.send_forward(output_tensor)
+
+            input_buffers.append(input_tensor)
+            output_buffers.append(output_tensor)
+
+            if not last_iter:
+                input_tensor = p2p.recv_forward()
+
+        return self.total_loss if self._compute_loss else output_buffers
+
     def _forward_step(self, input_tensor):
         if self.stage_id == 0:
             input_tensor = self._load_micro_batch(self.micro_batch_id)
@@ -164,18 +214,21 @@ class PipelineParallel(MetaParallelBase):
         output_tensor = self._layers.forward(input_tensor)
 
         if self.is_last_stage:
-            labels = self._load_micro_batch(self.micro_batch_id)
-            output_tensor = self._layers._loss_fn(output_tensor, labels)
-            assert isinstance(
-                output_tensor, paddle.
-                Tensor), "Currently, loss_fn should obtain Paddle.Tensor dtype"
-
-            if self.accumulate_steps > 1:
-                output_tensor = output_tensor / self.accumulate_steps
-
-            if self.total_loss is None:
-                self.total_loss = paddle.zeros_like(output_tensor)
-            self.total_loss += output_tensor.detach()
+            # compute loss when training, or when eval_batch() requests it
+            if self._compute_loss:
+                assert self._layers._loss_fn is not None, "loss function should exist to compute loss"
+                labels = self._load_micro_batch(self.micro_batch_id)
+                output_tensor = self._layers._loss_fn(output_tensor, labels)
+                assert isinstance(
+                    output_tensor, paddle.Tensor
+                ), "Currently, loss_fn should obtain Paddle.Tensor dtype"
+
+                if self.accumulate_steps > 1:
+                    output_tensor = output_tensor / self.accumulate_steps
+
+                if self.total_loss is None:
+                    self.total_loss = paddle.zeros_like(output_tensor)
+                self.total_loss += output_tensor.detach()
 
         self.micro_batch_id += 1
         return output_tensor
@@ -245,7 +298,7 @@ class PipelineParallel(MetaParallelBase):
             # No data input is required for other stages
            inputs = None
 
-    def _reduce_final_loss(self):
+    def _broadcast_final_loss(self):
         if self.is_last_stage:
             assert self.total_loss is not None, "train_batch() in last stage should obtain vaild loss"
             loss = self.total_loss.detach()
diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
index c508c88015c..e2c99edac12 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
@@ -19,11 +19,13 @@ import numpy as np
 from paddle import _C_ops
 
 _hcg = None
+_use_cache = False
 
 
-def initialize_p2p_groups(hcg):
-    global _hcg
+def initialize_p2p_groups(hcg, use_cache=True):
+    global _hcg, _use_cache
     _hcg = hcg
+    _use_cache = use_cache
     send_next_group, send_prev_group, recv_next_group, recv_prev_group = _hcg.get_p2p_groups(
     )
@@ -372,7 +374,7 @@ def recv_forward():
     else:
         if not _send_recv_meta.has_recv_meta:
             _send_recv_meta.recv_meta(_hcg.recv_prev_group)
-            _send_recv_meta.has_recv_meta = True
+            _send_recv_meta.has_recv_meta = _use_cache
 
         input_tensor, _ = _p2p_helper(
             tensor_send_next=None,
@@ -399,7 +401,7 @@ def send_forward(output_tensor):
         if not _send_recv_meta.has_send_meta:
             _send_recv_meta.set_send_message(output_tensor)
             _send_recv_meta.send_meta(output_tensor, _hcg.send_next_group)
-            _send_recv_meta.has_send_meta = True
+            _send_recv_meta.has_send_meta = _use_cache
 
         _p2p_helper(
             tensor_send_next=output_tensor,
diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py
index 524099c6ab0..c4c1e565068 100644
--- a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py
+++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py
@@ -177,10 +177,13 @@ class TestDistPPTraning(unittest.TestCase):
             x_data = np.random.randint(0, vocab_size, size=[batch_size, length])
             x = paddle.to_tensor(x_data)
             x.stop_gradient = True
+
+            e_loss = model.eval_batch([x, x], True)
             loss = model.train_batch([x, x], optimizer, scheduler)
 
-            # TODO(shenliang03) add utest for loss
-            print("loss: ", loss)
+            # TODO(shenliang03) add utest for loss
+            if pp_id != 0:
+                np.testing.assert_allclose(loss.numpy(), e_loss.numpy())
 
 
 if __name__ == "__main__":
-- 
GitLab
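Not part of the patch itself: a minimal usage sketch of the eval_batch() API this change adds, for anyone who wants to try it outside the unit test. The toy model MLPPipe, its layer sizes, and the batch shapes below are hypothetical; the fleet calls follow Paddle's public hybrid-parallel API (the same flow as the transformer test above), and the script assumes it is launched on two ranks, e.g. with python -m paddle.distributed.launch.

    import paddle
    import paddle.nn as nn
    import paddle.distributed.fleet as fleet
    from paddle.distributed.fleet.meta_parallel import LayerDesc, PipelineLayer

    # two pipeline stages; the full batch must equal micro_batch_size * accumulate_steps
    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 1, "pp_degree": 2}
    strategy.pipeline_configs = {"accumulate_steps": 2, "micro_batch_size": 2}
    fleet.init(is_collective=True, strategy=strategy)
    hcg = fleet.get_hybrid_communicate_group()


    class MLPPipe(PipelineLayer):
        # hypothetical toy model: two Linear layers split uniformly across the stages
        def __init__(self, **kwargs):
            descs = [LayerDesc(nn.Linear, 8, 8), LayerDesc(nn.Linear, 8, 8)]
            super().__init__(layers=descs, loss_fn=nn.MSELoss(), **kwargs)


    model = MLPPipe(topology=hcg.topology())
    optimizer = paddle.optimizer.SGD(learning_rate=1e-3, parameters=model.parameters())
    model = fleet.distributed_model(model)             # wraps the layer in PipelineParallel
    optimizer = fleet.distributed_optimizer(optimizer)

    x = paddle.randn([4, 8])                           # inputs, consumed by the first stage
    y = paddle.randn([4, 8])                           # labels, consumed by the last stage
    train_loss = model.train_batch([x, y], optimizer)          # 1F1B schedule + optimizer step
    eval_loss = model.eval_batch([x, y], compute_loss=True)    # forward-only pass from this patch

eval_batch() only runs the warm-up and steady forward steps (no backward, no optimizer step); with compute_loss=False it returns the collected stage outputs instead of the accumulated loss, and with p2p_cache_shape enabled the send/recv shape metadata is exchanged once and then cached. Unlike train_batch(), the eval loss is not broadcast, so it is only meaningful on the last pipeline stage, which is why the unit test compares loss and e_loss only when pp_id != 0.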