Unverified commit 6fde2056, authored by zhenhailiu, committed by GitHub

Simplify virtual pipeline scheduling logic (#54003)

* unify code

* remove useless code

* polish

* python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py

* polish

* polish
Parent 3794d171
@@ -557,6 +557,9 @@ class PipelineParallelWithInterleave(PipelineParallel):
assert (
framework.in_dynamic_mode()
), "virtual pipeline stage with interleave only support eager dygraph mode"
assert (
self.accumulate_steps % self.num_stages == 0
), "accumulate_steps should be evenly divisible by num_stages for pipeline with interleave"
# setup for interleave scheduler
self.num_model_chunks = layers.get_num_virtual_stages()
self.model_chunks = layers.get_model_chunks()
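The new precondition backs the removal of the `all_startup_steps` special case further down: once `accumulate_steps` is a multiple of `num_stages`, the single startup formula covers every configuration. A minimal, self-contained sketch of the check, with hypothetical config values:

```python
# Minimal sketch of the new precondition; all values are hypothetical.
num_stages = 4          # pipeline parallel degree (pp_degree)
num_model_chunks = 2    # virtual stages per rank
accumulate_steps = 8    # micro-batches per global step

# The interleaved scheduler requires this divisibility; otherwise the
# steady 1F1B phase would misalign across virtual stages.
assert accumulate_steps % num_stages == 0, (
    "accumulate_steps should be evenly divisible by num_stages "
    "for pipeline with interleave"
)

num_steps = accumulate_steps * num_model_chunks  # 16 micro-steps per rank
```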
@@ -583,12 +586,9 @@ class PipelineParallelWithInterleave(PipelineParallel):
assert hasattr(self, 'output_tensors')
if not self._forward_only:
assert hasattr(self, 'output_tensor_grads')
if self.is_pipeline_first_stage():
if len(self.input_tensors[virtual_pp_rank]) == len(
self.output_tensors[virtual_pp_rank]
):
self.input_tensors[virtual_pp_rank].append(None)
assert len(self.input_tensors[virtual_pp_rank]) == (
len(self.output_tensors[virtual_pp_rank]) + 1
)
input_tensor = self.input_tensors[virtual_pp_rank][-1]
output_tensor = self._forward_step(input_tensor, virtual_pp_rank)
self.output_tensors[virtual_pp_rank].append(output_tensor)
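The first-stage special case is gone: the receive paths now append `None` placeholders unconditionally (see the comments added in later hunks), so every virtual stage holds exactly one un-consumed input when its forward step runs, and the helper only asserts that invariant. A toy version of the bookkeeping, independent of Paddle:

```python
# Toy bookkeeping: one list pair per virtual stage (chunk); values are made up.
num_model_chunks = 2
input_tensors = [[] for _ in range(num_model_chunks)]
output_tensors = [[] for _ in range(num_model_chunks)]

def forward_step_helper(chunk, recv_value):
    # The recv path appends first, even when it received None (first stage).
    input_tensors[chunk].append(recv_value)
    # Invariant asserted by the new code: exactly one pending input.
    assert len(input_tensors[chunk]) == len(output_tensors[chunk]) + 1
    output_tensors[chunk].append(("out", input_tensors[chunk][-1]))

forward_step_helper(0, None)     # first stage: placeholder input
forward_step_helper(0, "act_0")  # later micro-step with a real activation
```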
@@ -609,9 +609,12 @@ class PipelineParallelWithInterleave(PipelineParallel):
assert hasattr(self, 'output_tensors')
assert hasattr(self, 'output_tensor_grads')
if self.is_pipeline_last_stage():
if len(self.output_tensor_grads[virtual_pp_rank]) == 0:
self.output_tensor_grads[virtual_pp_rank].append(None)
assert (
len(self.output_tensor_grads[virtual_pp_rank]) == 1
), f"output_tensor_grads is empty for virtual_pp_rank {virtual_pp_rank}"
assert len(self.input_tensors[virtual_pp_rank]) > 0
assert len(self.output_tensors[virtual_pp_rank]) > 0
input_tensor = self.input_tensors[virtual_pp_rank].pop(0)
output_tensor = self.output_tensors[virtual_pp_rank].pop(0)
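Symmetrically, the last stage no longer appends a placeholder gradient lazily here; the communication path in the startup loop below appends `None` unconditionally, so the helper can assert that exactly one gradient is queued before popping the oldest forward pair. A toy illustration of the pop order (names are made up):

```python
# Schematic of the backward bookkeeping; names and values are made up.
input_tensors = {0: ["in_a", "in_b"]}
output_tensors = {0: ["out_a", "out_b"]}
output_tensor_grads = {0: [None]}  # appended by the recv path, may be None

virtual_pp_rank = 0
assert len(output_tensor_grads[virtual_pp_rank]) == 1
assert len(input_tensors[virtual_pp_rank]) > 0
assert len(output_tensors[virtual_pp_rank]) > 0

# Backward consumes the oldest forward pair (FIFO), unlike forward's [-1].
input_tensor = input_tensors[virtual_pp_rank].pop(0)
output_tensor = output_tensors[virtual_pp_rank].pop(0)
```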
@@ -646,18 +649,17 @@ class PipelineParallelWithInterleave(PipelineParallel):
self.output_tensor_grads = [[] for _ in range(self.num_model_chunks)]
num_steps = self.accumulate_steps * self.num_model_chunks
all_startup_steps = False
if forward_only:
# Forward only: there is no backward pass, so all steps are startup steps
startup_steps = num_steps
else:
if self.accumulate_steps == self.num_stages:
startup_steps = num_steps
all_startup_steps = True
else:
startup_steps = (self.num_stages - self.stage_id - 1) * 2
startup_steps += (self.num_model_chunks - 1) * self.num_stages
startup_steps = min(startup_steps, num_steps)
# actually startup_steps is calculated from two numbers:
# first_forward_cross_to_end = (self.num_stages - self.stage_id - 1) + (self.num_model_chunks - 1) * self.num_stages
# end_to_first_backward_cross = (self.num_stages - self.stage_id - 1)
# startup_steps = first_forward_cross_to_end + end_to_first_backward_cross
startup_steps = (self.num_stages - self.stage_id - 1) * 2
startup_steps += (self.num_model_chunks - 1) * self.num_stages
startup_steps = min(startup_steps, num_steps)
steady_steps = num_steps - startup_steps
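The warm-up size can be read off the new comment: the first micro-batch must cross to the end of the pipeline, and the first gradient must travel back, before steady 1F1B can start. A worked example with hypothetical sizes:

```python
# Worked example of the startup_steps formula; sizes are hypothetical.
num_stages = 4        # pipeline ranks
num_model_chunks = 2  # virtual stages per rank
accumulate_steps = 8
num_steps = accumulate_steps * num_model_chunks  # 16

for stage_id in range(num_stages):
    # forward of the first micro-batch reaching the pipeline end ...
    first_forward_cross_to_end = (num_stages - stage_id - 1) + (
        num_model_chunks - 1
    ) * num_stages
    # ... plus the first gradient travelling back to this stage
    end_to_first_backward_cross = num_stages - stage_id - 1
    startup_steps = min(
        first_forward_cross_to_end + end_to_first_backward_cross, num_steps
    )
    print(stage_id, startup_steps)  # stage 0 -> 10, 1 -> 8, 2 -> 6, 3 -> 4
```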
@@ -687,11 +689,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
if self.is_pipeline_last_stage():
output_tensor = None
if (
micro_step == (startup_steps - 1)
and not forward_only
and not all_startup_steps
):
if micro_step == (startup_steps - 1) and not forward_only:
input_tensor_grad = None
recv_next = True
if self.is_pipeline_last_stage(ignore_virtual=True):
@@ -707,6 +705,8 @@ class PipelineParallelWithInterleave(PipelineParallel):
recv_prev=recv_prev,
recv_next=recv_next,
)
# output_tensor_grad is not None if recv_next is True
# append output_tensor_grad whether it is None or not
self.output_tensor_grads[self.num_model_chunks - 1].append(
output_tensor_grad
)
@@ -714,6 +714,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
input_tensor = p2p.send_forward_recv_forward(
output_tensor, recv_prev=recv_prev
)
# append input_tensor whether it is None or not
self.input_tensors[next_virtual_pp_rank].append(input_tensor)
# run 1f1b steady steps
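The comment makes the contract explicit: on ranks where `recv_prev` is False the p2p call returns `None`, and that placeholder is still stored so each virtual stage's input queue stays in lockstep with the schedule; the first-stage `None` is exactly the pending input that `_forward_step_helper`'s assert expects. A self-contained sketch, with `fake_recv` as a stand-in for `p2p.send_forward_recv_forward`:

```python
# fake_recv is a stand-in for p2p.send_forward_recv_forward; like the real
# call on the first stage, it yields None when recv_prev is False.
def fake_recv(output_tensor, recv_prev):
    return "activation" if recv_prev else None

num_model_chunks = 2
input_tensors = [[] for _ in range(num_model_chunks)]

next_virtual_pp_rank = 0
input_tensor = fake_recv(output_tensor=None, recv_prev=False)
# append whether it is None or not: the None placeholder is the
# "pending input" consumed later by _forward_step_helper.
input_tensors[next_virtual_pp_rank].append(input_tensor)
assert input_tensors[0] == [None]
```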
@@ -752,18 +753,14 @@ class PipelineParallelWithInterleave(PipelineParallel):
# determine whether to recv input tensor from upstream
recv_prev = True
if self.is_pipeline_first_stage(ignore_virtual=True):
next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
forward_micro_step_id - (self.num_stages - 1), forward=True
)
if next_forward_virtual_pp_rank == (self.num_model_chunks - 1):
# first pp stage and first virtual stage
recv_prev = False
next_forward_virtual_pp_rank += 1
else:
next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
forward_micro_step_id + 1, forward=True
)
next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
forward_micro_step_id + 1, forward=True
)
if self.is_pipeline_first_stage(ignore_virtual=True) and (
next_forward_virtual_pp_rank == 0
):
# first pp stage and first virtual stage
recv_prev = False
# last iteration doesn't need recv from upstream
if micro_step == (steady_steps - 1):
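The old code computed a different `next_forward_virtual_pp_rank` on the first stage (offset by `num_stages - 1`) and patched it with `+= 1`; the rewrite computes the next chunk uniformly and only then suppresses the receive when the first pipeline stage is about to wrap around to chunk 0. A self-contained sketch of that decision, using a simplified stand-in for `_get_virtual_pp_rank`:

```python
num_stages, num_model_chunks = 4, 2

def get_virtual_pp_rank(micro_step):
    # simplified stand-in for self._get_virtual_pp_rank(micro_step, forward=True)
    return (micro_step % (num_stages * num_model_chunks)) // num_stages

def need_recv_prev(forward_micro_step_id, is_first_stage):
    next_chunk = get_virtual_pp_rank(forward_micro_step_id + 1)
    # first pp stage about to wrap around to virtual chunk 0: nothing upstream
    return not (is_first_stage and next_chunk == 0)

print(need_recv_prev(7, is_first_stage=True))  # next step is chunk 0 -> False
print(need_recv_prev(3, is_first_stage=True))  # next step is chunk 1 -> True
```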
@@ -771,19 +768,14 @@ class PipelineParallelWithInterleave(PipelineParallel):
# determine whether to recv grad from downstream
recv_next = True
if self.is_pipeline_last_stage(ignore_virtual=True):
next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
backward_micro_step_id - (self.num_stages - 1),
forward=False,
)
if next_backward_virtual_pp_rank == 0:
# last pp stage and last virtual stage
recv_next = False
next_backward_virtual_pp_rank -= 1
else:
next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
backward_micro_step_id + 1, forward=False
)
next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
backward_micro_step_id + 1, forward=False
)
if self.is_pipeline_last_stage(ignore_virtual=True) and (
next_backward_virtual_pp_rank == (self.num_model_chunks - 1)
):
# last pp stage and last virtual stage
recv_next = False
(
input_tensor,
@@ -794,25 +786,17 @@ class PipelineParallelWithInterleave(PipelineParallel):
recv_prev=recv_prev,
recv_next=recv_next,
)
if recv_prev:
self.input_tensors[next_forward_virtual_pp_rank].append(
input_tensor
)
if recv_next:
self.output_tensor_grads[next_backward_virtual_pp_rank].append(
output_tensor_grad
)
# append input_tensor whether it is None or not
self.input_tensors[next_forward_virtual_pp_rank].append(
input_tensor
)
# append output_tensor_grad whether it is None or not
self.output_tensor_grads[next_backward_virtual_pp_rank].append(
output_tensor_grad
)
# remaining backward steps
if not forward_only:
if all_startup_steps:
self.output_tensor_grads[self.num_model_chunks - 1].append(
p2p.recv_backward(
self.is_pipeline_last_stage(), sync_recv=False
)
)
for micro_step in range(steady_steps, num_steps):
# cooldown loop
input_tensor_grad = self._backward_step_helper(micro_step)
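During cooldown the remaining `num_steps - steady_steps` micro-steps are pure backward steps, and the chunk order mirrors forward: the last virtual stage drains first. A hedged reconstruction of the micro-step-to-chunk mapping these loops rely on (the real logic lives in `self._get_virtual_pp_rank`, not shown in this diff):

```python
# Hedged reconstruction of the micro-step -> virtual stage mapping; the
# real helper is self._get_virtual_pp_rank. Sizes are hypothetical.
num_stages, num_model_chunks = 4, 2

def get_virtual_pp_rank(micro_step, forward):
    chunk = (micro_step % (num_stages * num_model_chunks)) // num_stages
    if not forward:
        # backward visits the chunks in reverse order
        chunk = num_model_chunks - 1 - chunk
    return chunk

# forward order over 8 steps:  [0, 0, 0, 0, 1, 1, 1, 1]
print([get_virtual_pp_rank(s, True) for s in range(8)])
# backward order over 8 steps: [1, 1, 1, 1, 0, 0, 0, 0]
print([get_virtual_pp_rank(s, False) for s in range(8)])
```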
@@ -829,7 +813,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
if micro_step == (num_steps - 1):
recv_next = False
# append output_tensor_grad whether it is None or not
self.output_tensor_grads[next_backward_virtual_pp_rank].append(
p2p.send_backward_recv_backward(
input_tensor_grad, recv_next=recv_next
......
@@ -60,6 +60,7 @@ class TestPipeLayerAPI(unittest.TestCase):
"mp_degree": 1,
"pp_degree": self.pipeline_parallel_size,
}
strategy.pipeline_configs = {"accumulate_steps": 2}
fleet.init(is_collective=True, strategy=strategy)
self.rank = fleet.worker_index()
self.hcg = fleet.get_hybrid_communicate_group()
......
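The test sets `pp_degree` to `self.pipeline_parallel_size` (2 in this suite), so `accumulate_steps = 2` satisfies the new divisibility assert (2 % 2 == 0). A sketch of the strategy wiring the updated test exercises, mirroring its visible fields (it must run under a distributed launcher, e.g. `python -m paddle.distributed.launch`):

```python
import paddle.distributed.fleet as fleet

# Sketch of the test's strategy setup; dp_degree is assumed to be 1 as in
# the elided part of the config.
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 2,  # pipeline_parallel_size in the test
}
# accumulate_steps must be divisible by pp_degree for interleave: 2 % 2 == 0
strategy.pipeline_configs = {"accumulate_steps": 2}
fleet.init(is_collective=True, strategy=strategy)
```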