Unverified commit f275ad2b authored by ShenLiang, committed by GitHub

[Distributed] Support dp/sharding overlap in virtual pp (#55651)

* Add virtual pp and dp overlap

* add sharding/dp overlap

* add dp/vpp overlap

* fix code

* fix log
Parent 8520a5b3
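At a high level, this change replaces the flat `self._comm_buffers` list with a per-chunk mapping (`self._chunk_2_comm_buffers`), so the fused dp/sharding gradient communication of each virtual-pipeline chunk can be launched as soon as that chunk's backward passes finish and overlap with the remaining pipeline work. The sketch below only illustrates the per-chunk bookkeeping; `ToyBuffer` and the chunk count are hypothetical stand-ins, not Paddle's `FusedCommBuffer`.

```python
from collections import defaultdict

# Hypothetical stand-in for Paddle's FusedCommBuffer, for illustration only.
class ToyBuffer:
    def __init__(self, chunk_idx, group_idx):
        self.chunk_idx, self.group_idx = chunk_idx, group_idx

    def comm_grads(self):
        print(f"launch fused allreduce: chunk {self.chunk_idx}, group {self.group_idx}")

# Buffers are keyed by virtual-pipeline chunk instead of kept in one flat list,
# so a finished chunk's communication can overlap with the remaining backward work.
chunk_2_comm_buffers = defaultdict(list)
for chunk_idx in range(2):  # assume 2 virtual pp chunks
    chunk_2_comm_buffers[chunk_idx].append(ToyBuffer(chunk_idx, group_idx=0))

# When chunk 1 finishes its backward steps, only its buffers communicate:
for buffer in chunk_2_comm_buffers[1]:
    buffer.comm_grads()
```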
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
import os
import sys
from collections import defaultdict
import paddle
from paddle import framework
@@ -181,7 +183,7 @@ class PipelineParallel(MetaParallelBase):
self._dp_comm_overlap and self._sharding_comm_overlap
), "Cannot use dp pp overlap and sharding pp overlap at the same time."
self._comm_buffers = []
self._chunk_2_comm_buffers = defaultdict(list)
self._comm_overlap = (
self._dp_comm_overlap or self._sharding_comm_overlap
)
@@ -255,7 +257,9 @@
return fused_allreduce
def register_allreduce_overlap_hook(self, model, comm_group, acc_steps, dp):
def register_allreduce_overlap_hook(
self, model, comm_group, acc_steps, dp, group_size=128 * 1024 * 1024
):
if model.get_num_virtual_stages() > 1:
models = model.get_model_chunks()
else:
@@ -272,7 +276,7 @@
else HOOK_ACTION.REDUCE
)
for model in models:
for chunk_idx, model in enumerate(models):
# For virtual pipeline: parameters from different chunks are separated into
# different groups to get the best performance.
@@ -301,12 +305,12 @@
if not dp:
# parse the relative dst rank to absolute dst rank for sharding
dst = comm_group.ranks[dst]
var_groups = assign_group_by_size(parameter_list)
var_groups = assign_group_by_size(parameter_list, group_size)
for group_idx, parameters in var_groups.items():
buffer = FusedCommBuffer(
group_idx, parameters, comm_group, acc_steps, act, dst
)
self._comm_buffers.append(buffer)
self._chunk_2_comm_buffers[chunk_idx].append(buffer)
for param in parameters:
param._register_backward_hook(
self.bw_hook_func(buffer, param)
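The new `group_size` argument above is forwarded to `assign_group_by_size`, which packs parameters into fused communication groups by byte size. Below is a simplified, stand-alone sketch of that grouping idea (a toy version under assumed parameter sizes, not Paddle's helper). With the default 128 MiB a chunk may be split into several groups, while the interleave subclass later passes `sys.maxsize`, so each chunk collapses into a single fused buffer.

```python
import sys

# Toy size-based parameter grouping (not Paddle's assign_group_by_size).
# Parameters are packed into a group until its total size reaches `group_size` bytes.
def group_by_size(params, group_size):
    groups, current, current_bytes = {}, [], 0
    for name, nbytes in params:
        current.append(name)
        current_bytes += nbytes
        if current_bytes >= group_size:
            groups[len(groups)] = current
            current, current_bytes = [], 0
    if current:
        groups[len(groups)] = current
    return groups

# Assumed example: three 100 MiB parameters.
params = [("w0", 100 << 20), ("w1", 100 << 20), ("w2", 100 << 20)]
print(group_by_size(params, 128 * 1024 * 1024))  # {0: ['w0', 'w1'], 1: ['w2']}
print(group_by_size(params, sys.maxsize))        # {0: ['w0', 'w1', 'w2']} -- one fused group
```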
@@ -402,9 +406,12 @@
p2p.send_backward(input_tensor_grad, self.is_pipeline_first_stage())
if self._comm_overlap:
assert len(self._comm_buffers) > 0
for buffer in self._comm_buffers:
buffer.scale_and_split_grads()
assert (
len(self._chunk_2_comm_buffers) > 0
), "comm buffers should be created"
for _, buffers in self._chunk_2_comm_buffers.items():
for buffer in buffers:
buffer.scale_and_split_grads()
if self._enable_timer:
self.timers("allreduce_shared_weight_gradients").start()
@@ -445,7 +452,7 @@
self._layers.train()
if self._sharding_comm_overlap and len(self._comm_buffers) == 0:
if self._sharding_comm_overlap and len(self._chunk_2_comm_buffers) == 0:
self.register_allreduce_overlap_hook(
self._layers, self.sharding_group, self.accumulate_steps, False
)
@@ -766,6 +773,40 @@ class PipelineParallelWithInterleave(PipelineParallel):
return output_tensor
def _overlap_comm_grads(self):
if self._comm_overlap:
self._backward_step_count += 1
sync_step = self._backward_step_count - self.stage_id
if sync_step > 0 and sync_step % self.accumulate_steps == 0:
chunk_idx = self._virtual_pp_world_size - (
sync_step // self.accumulate_steps
)
for buffer in self._chunk_2_comm_buffers[chunk_idx]:
buffer.comm_grads()
if self.stage_id != 0:
if (
self._backward_step_count
== self.accumulate_steps * self._virtual_pp_world_size
):
for buffer in self._chunk_2_comm_buffers[0]:
buffer.comm_grads()
def _sync_overlap_grads(self):
if self._comm_overlap:
assert (
self._backward_step_count
== self.accumulate_steps * self._virtual_pp_world_size
), "backward step count should be equal to accumulate steps * "
"virtual pp world size, but get {}, excepted result is {}".format(
self._backward_step_count,
self.accumulate_steps * self._virtual_pp_world_size,
)
for _, buffers in self._chunk_2_comm_buffers.items():
for buffer in buffers:
buffer.scale_and_split_grads()
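The chunk selection in `_overlap_comm_grads` can be hard to read from the diff alone. The following stand-alone trace reproduces the same arithmetic under assumed values (`accumulate_steps=4`, `virtual_pp_world_size=2`, `stage_id=1`); it is illustrative only and not part of the patch.

```python
# Stand-alone trace of the chunk-selection arithmetic in _overlap_comm_grads
# (illustrative only; values below are assumed).
accumulate_steps = 4
virtual_pp_world_size = 2
stage_id = 1

backward_step_count = 0
for _ in range(accumulate_steps * virtual_pp_world_size):
    backward_step_count += 1
    sync_step = backward_step_count - stage_id
    if sync_step > 0 and sync_step % accumulate_steps == 0:
        chunk_idx = virtual_pp_world_size - sync_step // accumulate_steps
        print(f"backward step {backward_step_count}: comm for chunk {chunk_idx}")

# On non-zero stages chunk 0 never reaches a sync_step that is a multiple of
# accumulate_steps inside the loop, so its communication is launched after the
# final backward step -- the special case handled by the `stage_id != 0` branch.
if stage_id != 0 and backward_step_count == accumulate_steps * virtual_pp_world_size:
    print(f"backward step {backward_step_count}: comm for chunk 0")
```

With these values, chunk 1 (whose backward finishes first) communicates at backward step 5 and chunk 0 is deferred to the final step 8.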
def _backward_step_helper(self, micro_step):
virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=False)
self.set_virtual_pipeline_rank(virtual_pp_rank)
@@ -786,8 +827,24 @@ class PipelineParallelWithInterleave(PipelineParallel):
input_tensor, output_tensor, output_tensor_grad
)
self._overlap_comm_grads()
return input_tensor_grad
def bw_hook_func(self, buffer, param):
# For pipeline with interleave, gradients are added to the buffer without triggering communication here.
# Communication is issued by the scheduler at appropriate points to avoid conflicts between dp communication and pp scheduling.
@paddle.autograd.no_grad()
def fused_allreduce(*_):
buffer.add_grad(param, use_comm=False)
return fused_allreduce
def register_allreduce_overlap_hook(self, model, comm_group, acc_steps, dp):
super().register_allreduce_overlap_hook(
model, comm_group, acc_steps, dp, group_size=sys.maxsize
)
def forward_backward_pipeline(
self, data, scaler, forward_only=False, compute_loss=True
):
@@ -805,6 +862,9 @@ class PipelineParallelWithInterleave(PipelineParallel):
self.micro_batch_id = 0
self._forward_only = forward_only
# store the number of backward steps
self._backward_step_count = 0
# init some data buffers for interleave scheduler
self.input_tensors = [[] for _ in range(self.num_model_chunks)]
self.output_tensors = [[] for _ in range(self.num_model_chunks)]
@@ -1011,10 +1071,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
)
)
if self._comm_overlap:
assert len(self._comm_buffers) > 0
for buffer in self._comm_buffers:
buffer.scale_and_split_grads()
self._sync_overlap_grads()
if self._enable_timer:
self.timers("allreduce_shared_weight_gradients").start()
@@ -218,7 +218,7 @@ class FusedCommBuffer:
and len(self._params_step_dict) == 0
)
def add_grad(self, param):
def add_grad(self, param, use_comm=True):
assert param.name in self._params_step_dict
current_ptr = (
param.main_grad.data_ptr()
@@ -239,12 +239,17 @@
self._params_checked_in += 1
self._params_step_dict.pop(param.name)
if self._all_params_checked_in:
self._comm_grads()
if self._all_params_checked_in and use_comm:
self.comm_grads()
@imperative_base.no_grad
def _comm_grads(self):
assert self._all_params_checked_in
def comm_grads(self):
assert self._all_params_checked_in, (
"Not all params checked in."
"Parameter number: {}, Check-in number: {}".format(
len(self._params), self._params_checked_in
)
)
if self._act == HOOK_ACTION.ALL_REDUCE:
task = paddle.distributed.all_reduce(
@@ -263,9 +268,8 @@
@imperative_base.no_grad
def scale_and_split_grads(self):
assert self._task is not None
assert self._task is not None, "Task is not initialized. "
self._task.wait()
scale_factor = 1.0 / self._comm_group.nranks
self.grad_storage.scale_(scale_factor)
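Taken together, the `FusedCommBuffer` changes split gradient handling into explicit phases that the interleave scheduler can drive: check-in without communication (`add_grad(..., use_comm=False)`), an explicit per-chunk `comm_grads()`, and a final `scale_and_split_grads()` that waits on the async task and rescales. A hypothetical mock of that lifecycle (`MockBuffer` is not Paddle's class; the rank count is assumed):

```python
# Hypothetical mock of the buffer lifecycle after this change (not Paddle's class).
class MockBuffer:
    def __init__(self, params, nranks):
        self._pending = set(params)
        self._nranks = nranks
        self._task_launched = False

    def add_grad(self, param, use_comm=True):
        # The interleave schedule passes use_comm=False: only record the check-in.
        self._pending.discard(param)
        if not self._pending and use_comm:
            self.comm_grads()

    def comm_grads(self):
        # Launched explicitly by _overlap_comm_grads once a chunk's backward is done.
        assert not self._pending, "Not all params checked in."
        self._task_launched = True
        print("allreduce/reduce launched")

    def scale_and_split_grads(self):
        # Called from _sync_overlap_grads: wait for the async task, then scale by 1/nranks.
        assert self._task_launched, "Task is not initialized."
        print(f"wait + scale grads by 1/{self._nranks}")

buf = MockBuffer(["w0", "w1"], nranks=8)
buf.add_grad("w0", use_comm=False)
buf.add_grad("w1", use_comm=False)  # no communication yet; the scheduler decides when
buf.comm_grads()                    # issued per chunk by the pipeline scheduler
buf.scale_and_split_grads()         # synchronize before the optimizer step
```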