未验证 提交 82dd6b15 编写于 作者: Y Yuang Liu 提交者: GitHub

[Hybrid Performance] Sharding stage 1 PP/VP overlap (#54312)

* sharding pp overlap

* bug fix

* update

* rename function

* update code logic
上级 ed604569
......@@ -62,6 +62,7 @@ message PpConfig {
optional bool dp_comm_overlap = 1 [ default = false ];
optional bool delay_scale_loss = 2 [ default = false ];
optional bool enable_timer = 3 [ default = false ];
optional bool sharding_comm_overlap = 4 [ default = false ];
}
message HybridConfig {
......
......@@ -25,7 +25,7 @@ from ..utils.log_util import logger
from .meta_parallel_base import MetaParallelBase
from .parallel_layers.pp_layers import PipelineLayer
from .pp_utils import p2p_communication as p2p
from .pp_utils.utils import FusedAllReduceBuffer, assign_group_by_size
from .pp_utils.utils import HOOK_ACTION, FusedCommBuffer, assign_group_by_size
__all__ = []
......@@ -62,6 +62,7 @@ class PipelineParallel(MetaParallelBase):
self.stage_id = self._hcg.get_stage_id()
self.pp_group = self._hcg.get_pipe_parallel_group()
self.dp_group = self._hcg.get_data_parallel_group()
self.sharding_group = self._hcg.get_sharding_parallel_group()
self._virtual_pp_world_size = None
self._virtual_pp_rank = None
......@@ -76,14 +77,28 @@ class PipelineParallel(MetaParallelBase):
self._dp_comm_overlap = self._strategy.hybrid_configs[
"pp_configs"
].dp_comm_overlap
self._sharding_comm_overlap = self._strategy.hybrid_configs[
"pp_configs"
].sharding_comm_overlap
self._enable_timer = self._strategy.hybrid_configs[
"pp_configs"
].enable_timer
self._dp_comm_buffers = []
if self._dp_comm_overlap:
assert self.use_data_parallel and self.num_stages > 1
if self._sharding_comm_overlap:
assert self.use_sharding_parallel and self.num_stages > 1
assert not (
self._dp_comm_overlap and self._sharding_comm_overlap
), "Cannot use dp pp overlap and sharding pp overlap at the same time."
self._comm_buffers = []
self._comm_overlap = (
self._dp_comm_overlap or self._sharding_comm_overlap
)
if self._enable_timer:
if not timer.is_timer_initialized():
timer.set_timers()
......@@ -121,7 +136,7 @@ class PipelineParallel(MetaParallelBase):
if self._dp_comm_overlap:
self.register_allreduce_overlap_hook(
self._layers, self.dp_group, self.accumulate_steps
self._layers, self.dp_group, self.accumulate_steps, True
)
def is_pipeline_first_stage(self, ignore_virtual=False):
......@@ -153,12 +168,21 @@ class PipelineParallel(MetaParallelBase):
return fused_allreduce
def register_allreduce_overlap_hook(self, model, comm_group, acc_steps):
def register_allreduce_overlap_hook(self, model, comm_group, acc_steps, dp):
if model.get_num_virtual_stages() > 1:
models = model.get_model_chunks()
else:
models = [model]
if not dp:
assert hasattr(self, "optimizer")
assert hasattr(self.optimizer, "_param2rank")
_param2rank = self.optimizer._param2rank
act = HOOK_ACTION.ALL_REDUCE if dp else HOOK_ACTION.REDUCE
fused_parameter_group = {}
for model in models:
# For virtual pipeline. Will separate parameters in different chunk into
# different groups to get the best performance.
......@@ -168,12 +192,29 @@ class PipelineParallel(MetaParallelBase):
if len(parameter_list) < 1:
return
if dp:
fused_parameter_group[-1] = parameter_list
else:
# Sort parameters for sharding, since they have different dst rank
for p in parameter_list:
assert p.name in _param2rank
dst_rank = _param2rank[p.name]
if dst_rank in fused_parameter_group:
fused_parameter_group[dst_rank].append(p)
else:
fused_parameter_group[dst_rank] = [p]
for dst in fused_parameter_group:
parameter_list = fused_parameter_group[dst]
if not dp:
# parse the relative dst rank to absolute dst rank for sharding
dst = comm_group.ranks[dst]
var_groups = assign_group_by_size(parameter_list)
for group_idx, parameters in var_groups.items():
buffer = FusedAllReduceBuffer(
group_idx, parameters, comm_group, acc_steps
buffer = FusedCommBuffer(
group_idx, parameters, comm_group, acc_steps, act, dst
)
self._dp_comm_buffers.append(buffer)
self._comm_buffers.append(buffer)
for param in parameters:
param._register_backward_hook(
self.bw_hook_func(buffer, param)
......@@ -263,9 +304,9 @@ class PipelineParallel(MetaParallelBase):
)
p2p.send_backward(input_tensor_grad, self.is_pipeline_first_stage())
if self._dp_comm_overlap:
assert len(self._dp_comm_buffers) > 0
for buffer in self._dp_comm_buffers:
if self._comm_overlap:
assert len(self._comm_buffers) > 0
for buffer in self._comm_buffers:
buffer.scale_and_split_grads()
if self._enable_timer:
......@@ -307,6 +348,11 @@ class PipelineParallel(MetaParallelBase):
self._layers.train()
if self._sharding_comm_overlap and len(self._comm_buffers) == 0:
self.register_allreduce_overlap_hook(
self._layers, self.sharding_group, self.accumulate_steps, False
)
return data
def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
......@@ -834,9 +880,9 @@ class PipelineParallelWithInterleave(PipelineParallel):
)
)
if self._dp_comm_overlap:
assert len(self._dp_comm_buffers) > 0
for buffer in self._dp_comm_buffers:
if self._comm_overlap:
assert len(self._comm_buffers) > 0
for buffer in self._comm_buffers:
buffer.scale_and_split_grads()
if self._enable_timer:
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -24,6 +24,12 @@ from paddle.framework import base as imperative_base
__all__ = []
class HOOK_ACTION:
ALL_REDUCE = 0
REDUCE = 1
FLOAT_TYPE_DICT = {
paddle.float16: "float16",
paddle.float32: "float32",
......@@ -114,8 +120,16 @@ def _all_gather(tensor, group=None, use_calc_stream=True):
)
class FusedAllReduceBuffer:
def __init__(self, id, params, comm_group, acc_steps=1):
class FusedCommBuffer:
def __init__(
self,
id,
params,
comm_group,
acc_steps=1,
act=None,
dst=-1,
):
self._id = id
self._params = params
self._acc_steps = acc_steps
......@@ -127,6 +141,17 @@ class FusedAllReduceBuffer:
self._params_checked_in = 0
self._coalesced_grads_and_grad_vars = []
self._act = act
if self._act == HOOK_ACTION.ALL_REDUCE:
assert dst == -1
elif self._act == HOOK_ACTION.REDUCE:
assert dst != -1
else:
raise ValueError(
"The act should be allreudce for dp or reduce for sharding."
)
self._dst = dst
self._init_step_dict()
def _init_step_dict(self):
......@@ -164,10 +189,10 @@ class FusedAllReduceBuffer:
self._params_step_dict.pop(param.name)
if self._all_params_checked_in:
self._fused_allreduce_grads()
self._fused_comm_grads()
@imperative_base.no_grad
def _fused_allreduce_grads(self):
def _fused_comm_grads(self):
assert self._all_params_checked_in
flattened_vars = []
g_var_shapes = []
......@@ -184,11 +209,18 @@ class FusedAllReduceBuffer:
)
for coalesced_grad, _, _ in self._coalesced_grads_and_grad_vars:
self._tasks.append(
paddle.distributed.all_reduce(
if self._act == HOOK_ACTION.ALL_REDUCE:
task = paddle.distributed.all_reduce(
coalesced_grad, group=self._comm_group, sync_op=False
)
elif self._act == HOOK_ACTION.REDUCE:
task = paddle.distributed.reduce(
coalesced_grad,
dst=self._dst,
group=self._comm_group,
sync_op=False,
)
self._tasks.append(task)
@imperative_base.no_grad
def scale_and_split_grads(self):
......
......@@ -71,6 +71,11 @@ def _dygraph_distributed_optimizer(optimizer, strategy=None):
].dp_comm_overlap:
hp_optim._dp_enable = False
if fleet_env._user_defined_strategy.hybrid_configs[
"pp_configs"
].sharding_comm_overlap:
hp_optim._sharding_enable = False
return hp_optim
else:
return HeterParallelOptimizer(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册