未验证 提交 b51c3ff8 编写于 作者: Y Yuang Liu 提交者: GitHub

bug fix for virtual pipeline parallel (#45922)

上级 38edea9a
...@@ -377,7 +377,7 @@ class PipelineLayer(Layer): ...@@ -377,7 +377,7 @@ class PipelineLayer(Layer):
for virtual_pp_rank in range(self._num_virtual_pipeline_stages): for virtual_pp_rank in range(self._num_virtual_pipeline_stages):
# Mapping the virtual pipeline stage to the real pipeline stage. # Mapping the virtual pipeline stage to the real pipeline stage.
# start_idx marks the start of a new virtual pp stage. # start_idx marks the start of a new virtual pp stage.
start_idx = virtual_pp_rank * self._num_virtual_pipeline_stages start_idx = virtual_pp_rank * self._num_stages
for stage in range(self._num_stages): for stage in range(self._num_stages):
# stage marks the real pp stage # stage marks the real pp stage
if self.segment_parts[start_idx + if self.segment_parts[start_idx +
...@@ -483,7 +483,7 @@ class PipelineLayer(Layer): ...@@ -483,7 +483,7 @@ class PipelineLayer(Layer):
", ".join(str(arg) for arg in self.segment_parts)) ", ".join(str(arg) for arg in self.segment_parts))
for i in range(self._stage_id, self._total_stages_with_virtual_stages, for i in range(self._stage_id, self._total_stages_with_virtual_stages,
self._num_virtual_pipeline_stages): self._num_stages):
# For example, with 2 real pp stages, 2 virtual pp stages, and a model of 8 layers: # For example, with 2 real pp stages, 2 virtual pp stages, and a model of 8 layers:
# Layers [0, 1], [4, 5] will be assigned to the first real pp stage. # Layers [0, 1], [4, 5] will be assigned to the first real pp stage.
# Layers [2, 3], [6, 7] will be assigned to the second real pp stage. # Layers [2, 3], [6, 7] will be assigned to the second real pp stage.
...@@ -528,7 +528,7 @@ class PipelineLayer(Layer): ...@@ -528,7 +528,7 @@ class PipelineLayer(Layer):
stage_to_virtual_stage_info = "stage {} contains virtual stages: ".format( stage_to_virtual_stage_info = "stage {} contains virtual stages: ".format(
stage) stage)
for i in range(stage, self._total_stages_with_virtual_stages, for i in range(stage, self._total_stages_with_virtual_stages,
self._num_virtual_pipeline_stages): self._num_stages):
stage_to_virtual_stage_info += " {},".format(i) stage_to_virtual_stage_info += " {},".format(i)
logger.info(stage_to_virtual_stage_info) logger.info(stage_to_virtual_stage_info)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册