diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index a7ce608c57abc95c0e8d0d8652e00ab18b411783..e3e3e33bb49100613ee1e52228a27cbaf6724d8f 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -777,9 +777,9 @@ class PipelineParallelWithInterleave(PipelineParallel): if self._comm_overlap: self._backward_step_count += 1 sync_step = self._backward_step_count - self.stage_id - if sync_step > 0 and sync_step % self.accumulate_steps == 0: + if sync_step > 0 and sync_step % self.num_stages == 0: chunk_idx = self._virtual_pp_world_size - ( - sync_step // self.accumulate_steps + sync_step // self.num_stages ) for buffer in self._chunk_2_comm_buffers[chunk_idx]: buffer.comm_grads() @@ -787,7 +787,7 @@ class PipelineParallelWithInterleave(PipelineParallel): if self.stage_id != 0: if ( self._backward_step_count - == self.accumulate_steps * self._virtual_pp_world_size + == self.num_stages * self.num_model_chunks ): for buffer in self._chunk_2_comm_buffers[0]: buffer.comm_grads() @@ -796,11 +796,10 @@ class PipelineParallelWithInterleave(PipelineParallel): if self._comm_overlap: assert ( self._backward_step_count - == self.accumulate_steps * self._virtual_pp_world_size - ), "backward step count should be equal to accumulate steps * " - "virtual pp world size, but get {}, excepted result is {}".format( - self._backward_step_count, - self.accumulate_steps * self._virtual_pp_world_size, + == self.num_stages * self.num_model_chunks + ), ( + "backward step count should be equal to accumulate steps * virtual pp world size," + f" but get {self._backward_step_count}, excepted result is {self.num_stages * self.num_model_chunks}" ) for _, buffers in self._chunk_2_comm_buffers.items(): @@ -863,7 +862,18 @@ class PipelineParallelWithInterleave(PipelineParallel): self._forward_only = forward_only # store the number of backward steps - self._backward_step_count = 0 + + assert ( + self.accumulate_steps % self.num_stages == 0 + ), "accumulate_steps({}) should be evenly divisible by num_stages({}) for pipeline with interleave".format( + self.accumulate_steps, self.num_stages + ) + per_stage_accumulate_steps = self.accumulate_steps // self.num_stages + self._backward_step_count = ( + -(per_stage_accumulate_steps - 1) + * self.num_stages + * self.num_model_chunks + ) # init some data buffers for interleave scheduler self.input_tensors = [[] for _ in range(self.num_model_chunks)]