From b51c3ff8cccec868362a866abc6993695831a8ba Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Fri, 9 Sep 2022 19:14:07 +0800 Subject: [PATCH] bug fix for virtual pipeline parallel (#45922) --- .../fleet/meta_parallel/parallel_layers/pp_layers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py index a6e8661f7a6..21770fff656 100755 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py @@ -377,7 +377,7 @@ class PipelineLayer(Layer): for virtual_pp_rank in range(self._num_virtual_pipeline_stages): # Mapping the virtual pipeline stage to the real pipeline stage. # start_idx marks the start of a new virtual pp stage. - start_idx = virtual_pp_rank * self._num_virtual_pipeline_stages + start_idx = virtual_pp_rank * self._num_stages for stage in range(self._num_stages): # stage mark the real pp stage if self.segment_parts[start_idx + @@ -483,7 +483,7 @@ class PipelineLayer(Layer): ", ".join(str(arg) for arg in self.segment_parts)) for i in range(self._stage_id, self._total_stages_with_virtual_stages, - self._num_virtual_pipeline_stages): + self._num_stages): # If there are 2 real pp stages and 2 virtual pp stages, and the model has 8 layers. # Layers [0, 1], [4, 5] will be assigned to the first real pp stage. # Layers [2, 3], [6, 7] will be assigned to the second real pp stage. @@ -528,7 +528,7 @@ class PipelineLayer(Layer): stage_to_virtual_stage_info = "stage {} contains virtual stages: ".format( stage) for i in range(stage, self._total_stages_with_virtual_stages, - self._num_virtual_pipeline_stages): + self._num_stages): stage_to_virtual_stage_info += " {},".format(i) logger.info(stage_to_virtual_stage_info) -- GitLab