diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py index a6e8661f7a6eae312220a886fa6e0990acb164e8..21770fff656ad25509a1571c0b6acee71ad81e39 100755 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py @@ -377,7 +377,7 @@ class PipelineLayer(Layer): for virtual_pp_rank in range(self._num_virtual_pipeline_stages): # Mapping the virtual pipeline stage to the real pipeline stage. # start_idx marks the start of a new virtual pp stage. - start_idx = virtual_pp_rank * self._num_virtual_pipeline_stages + start_idx = virtual_pp_rank * self._num_stages for stage in range(self._num_stages): # stage mark the real pp stage if self.segment_parts[start_idx + @@ -483,7 +483,7 @@ class PipelineLayer(Layer): ", ".join(str(arg) for arg in self.segment_parts)) for i in range(self._stage_id, self._total_stages_with_virtual_stages, - self._num_virtual_pipeline_stages): + self._num_stages): # If there are 2 real pp stages and 2 virtual pp stages, and the model has 8 layers. # Layers [0, 1], [4, 5] will be assigned to the first real pp stage. # Layers [2, 3], [6, 7] will be assigned to the second real pp stage. @@ -528,7 +528,7 @@ class PipelineLayer(Layer): stage_to_virtual_stage_info = "stage {} contains virtual stages: ".format( stage) for i in range(stage, self._total_stages_with_virtual_stages, - self._num_virtual_pipeline_stages): + self._num_stages): stage_to_virtual_stage_info += " {},".format(i) logger.info(stage_to_virtual_stage_info)