Unverified commit 6fde2056, authored by zhenhailiu, committed by GitHub

Simplify virtual pipeline scheduling logic (#54003)

* unify code

* remove useless code

* polish

* python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py

* polish

* polish
Parent 3794d171
@@ -557,6 +557,9 @@ class PipelineParallelWithInterleave(PipelineParallel):
         assert (
             framework.in_dynamic_mode()
         ), "virtual pipeline stage with interleave only support eager dygraph mode"
+        assert (
+            self.accumulate_steps % self.num_stages == 0
+        ), "accumulate_steps should be evenly divisible by num_stages for pipeline with interleave"
         # setup for interleave scheduler
         self.num_model_chunks = layers.get_num_virtual_stages()
         self.model_chunks = layers.get_model_chunks()
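The new assert turns a previously implicit scheduling assumption into a hard configuration check. A minimal illustration with hypothetical values (none of these numbers come from this diff):

    # 4 pipeline stages, 8 micro-batches accumulated per optimizer step.
    num_stages = 4
    accumulate_steps = 8
    assert accumulate_steps % num_stages == 0  # holds: 8 % 4 == 0
    # accumulate_steps = 6 would now fail fast at setup time instead of
    # producing a mismatched interleave schedule later.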
@@ -583,12 +586,9 @@ class PipelineParallelWithInterleave(PipelineParallel):
         assert hasattr(self, 'output_tensors')
         if not self._forward_only:
             assert hasattr(self, 'output_tensor_grads')
-        if self.is_pipeline_first_stage():
-            if len(self.input_tensors[virtual_pp_rank]) == len(
-                self.output_tensors[virtual_pp_rank]
-            ):
-                self.input_tensors[virtual_pp_rank].append(None)
+        assert len(self.input_tensors[virtual_pp_rank]) == (
+            len(self.output_tensors[virtual_pp_rank]) + 1
+        )
         input_tensor = self.input_tensors[virtual_pp_rank][-1]
         output_tensor = self._forward_step(input_tensor, virtual_pp_rank)
         self.output_tensors[virtual_pp_rank].append(output_tensor)
@@ -609,9 +609,12 @@ class PipelineParallelWithInterleave(PipelineParallel):
         assert hasattr(self, 'output_tensors')
         assert hasattr(self, 'output_tensor_grads')
-        if self.is_pipeline_last_stage():
-            if len(self.output_tensor_grads[virtual_pp_rank]) == 0:
-                self.output_tensor_grads[virtual_pp_rank].append(None)
+        assert (
+            len(self.output_tensor_grads[virtual_pp_rank]) == 1
+        ), f"output_tensor_grads is empty for virtual_pp_rank {virtual_pp_rank}"
+        assert len(self.input_tensors[virtual_pp_rank]) > 0
+        assert len(self.output_tensors[virtual_pp_rank]) > 0
         input_tensor = self.input_tensors[virtual_pp_rank].pop(0)
         output_tensor = self.output_tensors[virtual_pp_rank].pop(0)
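Together with the forward-side assert above, the `if`-based `None` padding is replaced by hard invariants: the p2p receive that used to be faked with `None` must now actually have happened before a step runs. A minimal sketch of the per-chunk queue discipline these asserts encode (illustrative lists only, not the real scheduler state):

    # Forward: an input is received before each forward step, so the
    # input queue leads the output queue by exactly one entry.
    inputs, outputs, grads = [], [], []
    inputs.append("recv_x")            # p2p recv precedes forward
    assert len(inputs) == len(outputs) + 1
    outputs.append("fwd(recv_x)")      # forward step appends its output

    # Backward: exactly one grad has been received for the pending step,
    # which then consumes the oldest (input, output) pair.
    grads.append("recv_dy")
    assert len(grads) == 1
    inputs.pop(0)
    outputs.pop(0)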
@@ -646,18 +649,17 @@ class PipelineParallelWithInterleave(PipelineParallel):
         self.output_tensor_grads = [[] for _ in range(self.num_model_chunks)]
         num_steps = self.accumulate_steps * self.num_model_chunks
-        all_startup_steps = False
         if forward_only:
             # If only forward, since there is no backward during running, all steps are startup steps
             startup_steps = num_steps
         else:
-            if self.accumulate_steps == self.num_stages:
-                startup_steps = num_steps
-                all_startup_steps = True
-            else:
-                startup_steps = (self.num_stages - self.stage_id - 1) * 2
-                startup_steps += (self.num_model_chunks - 1) * self.num_stages
-                startup_steps = min(startup_steps, num_steps)
+            # actually startup_steps is calculated from two number:
+            # first_forward_cross_to_end = (self.num_stages - self.stage_id - 1) + (self.num_model_chunks - 1) * self.num_stages
+            # end_to_first_backward_cross = (self.num_stages - self.stage_id - 1)
+            # startup_steps = first_forward_cross_to_end + end_to_first_backward_cross
+            startup_steps = (self.num_stages - self.stage_id - 1) * 2
+            startup_steps += (self.num_model_chunks - 1) * self.num_stages
+            startup_steps = min(startup_steps, num_steps)
         steady_steps = num_steps - startup_steps
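A worked example of the formula in the new comment, using illustrative values num_stages = 4, num_model_chunks = 2, stage_id = 0, accumulate_steps = 4 (so num_steps = 8). For this first-stage case the min() cap reproduces the removed `accumulate_steps == num_stages` special case, where every step was a startup step:

    num_stages, num_model_chunks, stage_id = 4, 2, 0
    num_steps = 4 * num_model_chunks                       # 8
    startup_steps = (num_stages - stage_id - 1) * 2        # 6
    startup_steps += (num_model_chunks - 1) * num_stages   # 6 + 4 = 10
    startup_steps = min(startup_steps, num_steps)          # capped to 8
    steady_steps = num_steps - startup_steps               # 0: all startup

Later stages get a nonzero steady phase from the same formula, letting the unified 1F1B path replace the blanket `all_startup_steps` shortcut.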
@@ -687,11 +689,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
                 if self.is_pipeline_last_stage():
                     output_tensor = None
-                if (
-                    micro_step == (startup_steps - 1)
-                    and not forward_only
-                    and not all_startup_steps
-                ):
+                if micro_step == (startup_steps - 1) and not forward_only:
                     input_tensor_grad = None
                     recv_next = True
                     if self.is_pipeline_last_stage(ignore_virtual=True):
@@ -707,6 +705,8 @@ class PipelineParallelWithInterleave(PipelineParallel):
                         recv_prev=recv_prev,
                         recv_next=recv_next,
                     )
+                    # output_tensor_grad is not none if recv_next
+                    # append output_tensor_grad no matter none or not
                     self.output_tensor_grads[self.num_model_chunks - 1].append(
                         output_tensor_grad
                     )
@@ -714,6 +714,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
                     input_tensor = p2p.send_forward_recv_forward(
                         output_tensor, recv_prev=recv_prev
                     )
+                # append input_tensor no matter none or not
                 self.input_tensors[next_virtual_pp_rank].append(input_tensor)
         # run 1f1b steady steps
@@ -752,18 +753,14 @@ class PipelineParallelWithInterleave(PipelineParallel):
             # determine whether to recv input tensor from upstream
             recv_prev = True
-            if self.is_pipeline_first_stage(ignore_virtual=True):
-                next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
-                    forward_micro_step_id - (self.num_stages - 1), forward=True
-                )
-                if next_forward_virtual_pp_rank == (self.num_model_chunks - 1):
-                    # first pp stage and first virtual stage
-                    recv_prev = False
-                next_forward_virtual_pp_rank += 1
-            else:
-                next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
-                    forward_micro_step_id + 1, forward=True
-                )
+            next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
+                forward_micro_step_id + 1, forward=True
+            )
+            if self.is_pipeline_first_stage(ignore_virtual=True) and (
+                next_forward_virtual_pp_rank == 0
+            ):
+                # first pp stage and first virtual stage
+                recv_prev = False
             # last iteration doesn't need recv from upstream
             if micro_step == (steady_steps - 1):
@@ -771,19 +768,14 @@ class PipelineParallelWithInterleave(PipelineParallel):
             # determine whether to recv grad from downstream
             recv_next = True
-            if self.is_pipeline_last_stage(ignore_virtual=True):
-                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
-                    backward_micro_step_id - (self.num_stages - 1),
-                    forward=False,
-                )
-                if next_backward_virtual_pp_rank == 0:
-                    # last pp stage and last virtual stage
-                    recv_next = False
-                next_backward_virtual_pp_rank -= 1
-            else:
-                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
-                    backward_micro_step_id + 1, forward=False
-                )
+            next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
+                backward_micro_step_id + 1, forward=False
+            )
+            if self.is_pipeline_last_stage(ignore_virtual=True) and (
+                next_backward_virtual_pp_rank == (self.num_model_chunks - 1)
+            ):
+                # last pp stage and last virtual stage
+                recv_next = False
             (
                 input_tensor,
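Both simplified branches rely on `_get_virtual_pp_rank` to do the chunk wrap-around that the deleted code computed by hand with `- (self.num_stages - 1)` offsets. For reference, a sketch of that mapping, on the assumption that it mirrors the existing helper in pipeline_parallel.py (the helper itself is unchanged and not shown in this diff):

    def _get_virtual_pp_rank(self, micro_step, forward):
        # Index within one full round of num_stages * num_model_chunks
        # micro-steps, then the model-chunk index inside that round.
        virtual_pp_stage = micro_step % (self.num_stages * self.num_model_chunks)
        virtual_pp_stage = virtual_pp_stage // self.num_stages
        if not forward:
            # The backward pass visits model chunks in reverse order.
            virtual_pp_stage = self.num_model_chunks - virtual_pp_stage - 1
        return virtual_pp_stage

With this mapping, `next_forward_virtual_pp_rank == 0` on the first stage, and `next_backward_virtual_pp_rank == num_model_chunks - 1` on the last, mark the schedule boundaries where no upstream or downstream peer exists to receive from.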
@@ -794,25 +786,17 @@ class PipelineParallelWithInterleave(PipelineParallel):
                 recv_prev=recv_prev,
                 recv_next=recv_next,
             )
-            if recv_prev:
-                self.input_tensors[next_forward_virtual_pp_rank].append(
-                    input_tensor
-                )
-            if recv_next:
-                self.output_tensor_grads[next_backward_virtual_pp_rank].append(
-                    output_tensor_grad
-                )
+            # append input_tensor no matter none or not
+            self.input_tensors[next_forward_virtual_pp_rank].append(
+                input_tensor
+            )
+            # append output_tensor_grad no matter none or not
+            self.output_tensor_grads[next_backward_virtual_pp_rank].append(
+                output_tensor_grad
+            )
         # remaining backward steps
         if not forward_only:
-            if all_startup_steps:
-                self.output_tensor_grads[self.num_model_chunks - 1].append(
-                    p2p.recv_backward(
-                        self.is_pipeline_last_stage(), sync_recv=False
-                    )
-                )
             for micro_step in range(steady_steps, num_steps):
                 # cooldown loop
                 input_tensor_grad = self._backward_step_helper(micro_step)
@@ -829,7 +813,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
                 if micro_step == (num_steps - 1):
                     recv_next = False
+                # append output_tensor_grad no matter none or not
                 self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                     p2p.send_backward_recv_backward(
                         input_tensor_grad, recv_next=recv_next
......
@@ -60,6 +60,7 @@ class TestPipeLayerAPI(unittest.TestCase):
             "mp_degree": 1,
             "pp_degree": self.pipeline_parallel_size,
         }
+        strategy.pipeline_configs = {"accumulate_steps": 2}
         fleet.init(is_collective=True, strategy=strategy)
         self.rank = fleet.worker_index()
         self.hcg = fleet.get_hybrid_communicate_group()
......
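The added line satisfies the new divisibility assert when the test's pipeline_parallel_size is 2 (2 % 2 == 0). For context, a minimal sketch of the surrounding setup using the public fleet API; pp_degree = 2 and dp_degree = 1 are assumptions here, only mp_degree and pp_degree appear in the hunk:

    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": 1,
        "mp_degree": 1,
        "pp_degree": 2,  # self.pipeline_parallel_size in the test
    }
    # New: make accumulate_steps an explicit multiple of pp_degree.
    strategy.pipeline_configs = {"accumulate_steps": 2}
    fleet.init(is_collective=True, strategy=strategy)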