From 7b7ec08eb285c91880864603a46b05e79f51841e Mon Sep 17 00:00:00 2001 From: ShenLiang <1422485404@qq.com> Date: Wed, 2 Aug 2023 17:05:09 +0800 Subject: [PATCH] [BugFix]Fix bug in vpp+ sharding/dp overlap (#55890) * fix bug * fix bug * fix bug * fix bug * fix bug --- .../fleet/meta_parallel/pipeline_parallel.py | 28 +++++++++++++------ 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index a7ce608c57a..e3e3e33bb49 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -777,9 +777,9 @@ class PipelineParallelWithInterleave(PipelineParallel): if self._comm_overlap: self._backward_step_count += 1 sync_step = self._backward_step_count - self.stage_id - if sync_step > 0 and sync_step % self.accumulate_steps == 0: + if sync_step > 0 and sync_step % self.num_stages == 0: chunk_idx = self._virtual_pp_world_size - ( - sync_step // self.accumulate_steps + sync_step // self.num_stages ) for buffer in self._chunk_2_comm_buffers[chunk_idx]: buffer.comm_grads() @@ -787,7 +787,7 @@ class PipelineParallelWithInterleave(PipelineParallel): if self.stage_id != 0: if ( self._backward_step_count - == self.accumulate_steps * self._virtual_pp_world_size + == self.num_stages * self.num_model_chunks ): for buffer in self._chunk_2_comm_buffers[0]: buffer.comm_grads() @@ -796,11 +796,10 @@ class PipelineParallelWithInterleave(PipelineParallel): if self._comm_overlap: assert ( self._backward_step_count - == self.accumulate_steps * self._virtual_pp_world_size - ), "backward step count should be equal to accumulate steps * " - "virtual pp world size, but get {}, excepted result is {}".format( - self._backward_step_count, - self.accumulate_steps * self._virtual_pp_world_size, + == self.num_stages * self.num_model_chunks + ), ( + "backward step count should be equal to accumulate steps * virtual pp world size," + f" but get {self._backward_step_count}, excepted result is {self.num_stages * self.num_model_chunks}" ) for _, buffers in self._chunk_2_comm_buffers.items(): @@ -863,7 +862,18 @@ class PipelineParallelWithInterleave(PipelineParallel): self._forward_only = forward_only # store the number of backward steps - self._backward_step_count = 0 + + assert ( + self.accumulate_steps % self.num_stages == 0 + ), "accumulate_steps({}) should be evenly divisible by num_stages({}) for pipeline with interleave".format( + self.accumulate_steps, self.num_stages + ) + per_stage_accumulate_steps = self.accumulate_steps // self.num_stages + self._backward_step_count = ( + -(per_stage_accumulate_steps - 1) + * self.num_stages + * self.num_model_chunks + ) # init some data buffers for interleave scheduler self.input_tensors = [[] for _ in range(self.num_model_chunks)] -- GitLab