Unverified commit 7b7ec08e, authored by ShenLiang, committed by GitHub

[BugFix] Fix bug in vpp + sharding/dp overlap (#55890)

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug
Parent 42ab2c34
@@ -777,9 +777,9 @@ class PipelineParallelWithInterleave(PipelineParallel):
         if self._comm_overlap:
             self._backward_step_count += 1
             sync_step = self._backward_step_count - self.stage_id
-            if sync_step > 0 and sync_step % self.accumulate_steps == 0:
+            if sync_step > 0 and sync_step % self.num_stages == 0:
                 chunk_idx = self._virtual_pp_world_size - (
-                    sync_step // self.accumulate_steps
+                    sync_step // self.num_stages
                 )
                 for buffer in self._chunk_2_comm_buffers[chunk_idx]:
                     buffer.comm_grads()
@@ -787,7 +787,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
             if self.stage_id != 0:
                 if (
                     self._backward_step_count
-                    == self.accumulate_steps * self._virtual_pp_world_size
+                    == self.num_stages * self.num_model_chunks
                 ):
                     for buffer in self._chunk_2_comm_buffers[0]:
                         buffer.comm_grads()
@@ -796,11 +796,10 @@ class PipelineParallelWithInterleave(PipelineParallel):
         if self._comm_overlap:
             assert (
                 self._backward_step_count
-                == self.accumulate_steps * self._virtual_pp_world_size
-            ), "backward step count should be equal to accumulate steps * "
-            "virtual pp world size, but get {}, excepted result is {}".format(
-                self._backward_step_count,
-                self.accumulate_steps * self._virtual_pp_world_size,
+                == self.num_stages * self.num_model_chunks
+            ), (
+                "backward step count should be equal to accumulate steps * virtual pp world size,"
+                f" but get {self._backward_step_count}, excepted result is {self.num_stages * self.num_model_chunks}"
             )

             for _, buffers in self._chunk_2_comm_buffers.items():
@@ -863,7 +862,18 @@ class PipelineParallelWithInterleave(PipelineParallel):
         self._forward_only = forward_only

         # store the number of backward steps
-        self._backward_step_count = 0
+        assert (
+            self.accumulate_steps % self.num_stages == 0
+        ), "accumulate_steps({}) should be evenly divisible by num_stages({}) for pipeline with interleave".format(
+            self.accumulate_steps, self.num_stages
+        )
+        per_stage_accumulate_steps = self.accumulate_steps // self.num_stages
+        self._backward_step_count = (
+            -(per_stage_accumulate_steps - 1)
+            * self.num_stages
+            * self.num_model_chunks
+        )

         # init some data buffers for interleave scheduler
         self.input_tensors = [[] for _ in range(self.num_model_chunks)]
...
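The new initialization seeds the counter at a negative multiple so that, after all accumulate_steps * num_model_chunks backward micro-steps, it lands exactly on the value the earlier assertion expects. A quick arithmetic check with assumed example values (not taken from the patch):

# Assumed example configuration, for illustration only.
num_stages = 4          # pipeline-parallel degree
num_model_chunks = 2    # virtual pipeline chunks per stage
accumulate_steps = 8    # gradient accumulation steps (divisible by num_stages)

per_stage_accumulate_steps = accumulate_steps // num_stages                         # 2
initial_count = -(per_stage_accumulate_steps - 1) * num_stages * num_model_chunks   # -8

# Each micro-batch runs one backward pass per virtual chunk.
total_backward_steps = accumulate_steps * num_model_chunks                          # 16

final_count = initial_count + total_backward_steps                                  # 8
assert final_count == num_stages * num_model_chunks   # matches the new assertion target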