From 82dd6b15773e9b5c57a0fe3ed32c88879ecc4c3e Mon Sep 17 00:00:00 2001
From: Yuang Liu
Date: Mon, 5 Jun 2023 17:26:09 +0800
Subject: [PATCH] [Hybrid Performance] Sharding stage 1 PP/VP overlap (#54312)

* sharding pp overlap

* bug fix

* update

* rename function

* update code logic
---
 .../framework/distributed_strategy.proto     |  1 +
 .../fleet/meta_parallel/pipeline_parallel.py | 84 ++++++++++++++-----
 .../fleet/meta_parallel/pp_utils/utils.py    | 48 +++++++++--
 python/paddle/distributed/fleet/optimizer.py |  5 ++
 4 files changed, 111 insertions(+), 27 deletions(-)

diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto
index 4082bf2ec55..85bafbef2b6 100755
--- a/paddle/fluid/framework/distributed_strategy.proto
+++ b/paddle/fluid/framework/distributed_strategy.proto
@@ -62,6 +62,7 @@ message PpConfig {
   optional bool dp_comm_overlap = 1 [ default = false ];
   optional bool delay_scale_loss = 2 [ default = false ];
   optional bool enable_timer = 3 [ default = false ];
+  optional bool sharding_comm_overlap = 4 [ default = false ];
 }
 
 message HybridConfig {
diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
index 49902a44987..c08ce5a2549 100755
--- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -25,7 +25,7 @@ from ..utils.log_util import logger
 from .meta_parallel_base import MetaParallelBase
 from .parallel_layers.pp_layers import PipelineLayer
 from .pp_utils import p2p_communication as p2p
-from .pp_utils.utils import FusedAllReduceBuffer, assign_group_by_size
+from .pp_utils.utils import HOOK_ACTION, FusedCommBuffer, assign_group_by_size
 
 __all__ = []
 
@@ -62,6 +62,7 @@ class PipelineParallel(MetaParallelBase):
         self.stage_id = self._hcg.get_stage_id()
         self.pp_group = self._hcg.get_pipe_parallel_group()
         self.dp_group = self._hcg.get_data_parallel_group()
+        self.sharding_group = self._hcg.get_sharding_parallel_group()
 
         self._virtual_pp_world_size = None
         self._virtual_pp_rank = None
@@ -76,14 +77,28 @@ class PipelineParallel(MetaParallelBase):
         self._dp_comm_overlap = self._strategy.hybrid_configs[
             "pp_configs"
         ].dp_comm_overlap
+        self._sharding_comm_overlap = self._strategy.hybrid_configs[
+            "pp_configs"
+        ].sharding_comm_overlap
         self._enable_timer = self._strategy.hybrid_configs[
             "pp_configs"
         ].enable_timer
-        self._dp_comm_buffers = []
 
         if self._dp_comm_overlap:
             assert self.use_data_parallel and self.num_stages > 1
 
+        if self._sharding_comm_overlap:
+            assert self.use_sharding_parallel and self.num_stages > 1
+
+        assert not (
+            self._dp_comm_overlap and self._sharding_comm_overlap
+        ), "Cannot use dp pp overlap and sharding pp overlap at the same time."
+
+        self._comm_buffers = []
+        self._comm_overlap = (
+            self._dp_comm_overlap or self._sharding_comm_overlap
+        )
+
         if self._enable_timer:
             if not timer.is_timer_initialized():
                 timer.set_timers()
@@ -121,7 +136,7 @@ class PipelineParallel(MetaParallelBase):
 
         if self._dp_comm_overlap:
             self.register_allreduce_overlap_hook(
-                self._layers, self.dp_group, self.accumulate_steps
+                self._layers, self.dp_group, self.accumulate_steps, True
             )
 
     def is_pipeline_first_stage(self, ignore_virtual=False):
@@ -153,12 +168,21 @@ class PipelineParallel(MetaParallelBase):
 
         return fused_allreduce
 
-    def register_allreduce_overlap_hook(self, model, comm_group, acc_steps):
+    def register_allreduce_overlap_hook(self, model, comm_group, acc_steps, dp):
         if model.get_num_virtual_stages() > 1:
             models = model.get_model_chunks()
         else:
             models = [model]
 
+        if not dp:
+            assert hasattr(self, "optimizer")
+            assert hasattr(self.optimizer, "_param2rank")
+            _param2rank = self.optimizer._param2rank
+
+        act = HOOK_ACTION.ALL_REDUCE if dp else HOOK_ACTION.REDUCE
+
+        fused_parameter_group = {}
+
         for model in models:
             # For virtual pipeline. Will separate parameters in different chunk into
             # different groups to get the best performance.
@@ -168,16 +192,33 @@ class PipelineParallel(MetaParallelBase):
             if len(parameter_list) < 1:
                 return
 
-            var_groups = assign_group_by_size(parameter_list)
-            for group_idx, parameters in var_groups.items():
-                buffer = FusedAllReduceBuffer(
-                    group_idx, parameters, comm_group, acc_steps
-                )
-                self._dp_comm_buffers.append(buffer)
-                for param in parameters:
-                    param._register_backward_hook(
-                        self.bw_hook_func(buffer, param)
+            if dp:
+                fused_parameter_group[-1] = parameter_list
+            else:
+                # Sort parameters for sharding, since they have different dst rank
+                for p in parameter_list:
+                    assert p.name in _param2rank
+                    dst_rank = _param2rank[p.name]
+                    if dst_rank in fused_parameter_group:
+                        fused_parameter_group[dst_rank].append(p)
+                    else:
+                        fused_parameter_group[dst_rank] = [p]
+
+            for dst in fused_parameter_group:
+                parameter_list = fused_parameter_group[dst]
+                if not dp:
+                    # parse the relative dst rank to absolute dst rank for sharding
+                    dst = comm_group.ranks[dst]
+                var_groups = assign_group_by_size(parameter_list)
+                for group_idx, parameters in var_groups.items():
+                    buffer = FusedCommBuffer(
+                        group_idx, parameters, comm_group, acc_steps, act, dst
                     )
+                    self._comm_buffers.append(buffer)
+                    for param in parameters:
+                        param._register_backward_hook(
+                            self.bw_hook_func(buffer, param)
+                        )
 
     def timer_printer(self):
         if not self._enable_timer:
@@ -263,9 +304,9 @@ class PipelineParallel(MetaParallelBase):
             )
             p2p.send_backward(input_tensor_grad, self.is_pipeline_first_stage())
 
-        if self._dp_comm_overlap:
-            assert len(self._dp_comm_buffers) > 0
-            for buffer in self._dp_comm_buffers:
+        if self._comm_overlap:
+            assert len(self._comm_buffers) > 0
+            for buffer in self._comm_buffers:
                 buffer.scale_and_split_grads()
 
         if self._enable_timer:
@@ -307,6 +348,11 @@ class PipelineParallel(MetaParallelBase):
 
         self._layers.train()
 
+        if self._sharding_comm_overlap and len(self._comm_buffers) == 0:
+            self.register_allreduce_overlap_hook(
+                self._layers, self.sharding_group, self.accumulate_steps, False
+            )
+
         return data
 
     def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
@@ -834,9 +880,9 @@ class PipelineParallelWithInterleave(PipelineParallel):
                 )
             )
 
-        if self._dp_comm_overlap:
-            assert len(self._dp_comm_buffers) > 0
-            for buffer in self._dp_comm_buffers:
+        if self._comm_overlap:
+            assert len(self._comm_buffers) > 0
+            for buffer in self._comm_buffers:
                 buffer.scale_and_split_grads()
 
         if self._enable_timer:
diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
index b9967ca202c..8ae20c91bb4 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -24,6 +24,12 @@ from paddle.framework import base as imperative_base
 
 __all__ = []
 
+
+class HOOK_ACTION:
+    ALL_REDUCE = 0
+    REDUCE = 1
+
+
 FLOAT_TYPE_DICT = {
     paddle.float16: "float16",
     paddle.float32: "float32",
@@ -114,8 +120,16 @@ def _all_gather(tensor, group=None, use_calc_stream=True):
     )
 
 
-class FusedAllReduceBuffer:
-    def __init__(self, id, params, comm_group, acc_steps=1):
+class FusedCommBuffer:
+    def __init__(
+        self,
+        id,
+        params,
+        comm_group,
+        acc_steps=1,
+        act=None,
+        dst=-1,
+    ):
         self._id = id
         self._params = params
         self._acc_steps = acc_steps
@@ -127,6 +141,17 @@ class FusedAllReduceBuffer:
         self._params_checked_in = 0
         self._coalesced_grads_and_grad_vars = []
 
+        self._act = act
+        if self._act == HOOK_ACTION.ALL_REDUCE:
+            assert dst == -1
+        elif self._act == HOOK_ACTION.REDUCE:
+            assert dst != -1
+        else:
+            raise ValueError(
+                "The act should be allreduce for dp or reduce for sharding."
+            )
+        self._dst = dst
+
         self._init_step_dict()
 
     def _init_step_dict(self):
@@ -164,10 +189,10 @@ class FusedAllReduceBuffer:
             self._params_step_dict.pop(param.name)
 
         if self._all_params_checked_in:
-            self._fused_allreduce_grads()
+            self._fused_comm_grads()
 
     @imperative_base.no_grad
-    def _fused_allreduce_grads(self):
+    def _fused_comm_grads(self):
         assert self._all_params_checked_in
         flattened_vars = []
         g_var_shapes = []
@@ -184,11 +209,18 @@ class FusedAllReduceBuffer:
             )
 
         for coalesced_grad, _, _ in self._coalesced_grads_and_grad_vars:
-            self._tasks.append(
-                paddle.distributed.all_reduce(
+            if self._act == HOOK_ACTION.ALL_REDUCE:
+                task = paddle.distributed.all_reduce(
                     coalesced_grad, group=self._comm_group, sync_op=False
                 )
-            )
+            elif self._act == HOOK_ACTION.REDUCE:
+                task = paddle.distributed.reduce(
+                    coalesced_grad,
+                    dst=self._dst,
+                    group=self._comm_group,
+                    sync_op=False,
+                )
+            self._tasks.append(task)
 
     @imperative_base.no_grad
     def scale_and_split_grads(self):
diff --git a/python/paddle/distributed/fleet/optimizer.py b/python/paddle/distributed/fleet/optimizer.py
index 6fc414d0a65..9e693e670f3 100755
--- a/python/paddle/distributed/fleet/optimizer.py
+++ b/python/paddle/distributed/fleet/optimizer.py
@@ -71,6 +71,11 @@ def _dygraph_distributed_optimizer(optimizer, strategy=None):
         ].dp_comm_overlap:
             hp_optim._dp_enable = False
 
+        if fleet_env._user_defined_strategy.hybrid_configs[
+            "pp_configs"
+        ].sharding_comm_overlap:
+            hp_optim._sharding_enable = False
+
         return hp_optim
     else:
         return HeterParallelOptimizer(
-- 
GitLab
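
A minimal usage sketch for the new pp_configs.sharding_comm_overlap switch, assuming the usual fleet dygraph hybrid-parallel setup (fleet.DistributedStrategy plus fleet.distributed_model / fleet.distributed_optimizer). The toy model, parallel degrees, batch shapes, and launch command below are illustrative placeholders, not part of the patch; the pp_configs access pattern mirrors how the patched code itself reads the flag, and exact config keys or defaults may differ between Paddle versions.

# Sketch: enable sharding stage 1 + PP overlap on 4 GPUs, launched with e.g.
#   python -m paddle.distributed.launch --gpus 0,1,2,3 demo.py
# The ToyPipe model, degrees, and random data are hypothetical placeholders.
import paddle
import paddle.nn as nn
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import PipelineLayer

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 2,
    "sharding_degree": 2,  # sharding stage 1
}
strategy.pipeline_configs = {"micro_batch_size": 4, "accumulate_steps": 2}
# The new flag added by this patch; it is mutually exclusive with
# dp_comm_overlap (see the assert added in PipelineParallel.__init__),
# so dp_comm_overlap is left at its default of False.
strategy.hybrid_configs["pp_configs"].sharding_comm_overlap = True

fleet.init(is_collective=True, strategy=strategy)
hcg = fleet.get_hybrid_communicate_group()


class ToyPipe(PipelineLayer):
    # A toy 8-layer MLP split uniformly over the two pipeline stages.
    def __init__(self, topology):
        layers = [nn.Linear(64, 64) for _ in range(8)]
        super().__init__(layers=layers, topology=topology, loss_fn=nn.MSELoss())


model = ToyPipe(topology=hcg.topology())
optimizer = paddle.optimizer.AdamW(
    learning_rate=1e-3, parameters=model.parameters()
)

# fleet wraps the model in PipelineParallel and the optimizer in the hybrid
# (sharding) optimizer. With sharding_comm_overlap on, the reduce hooks are
# registered lazily on the first train_batch call, once optimizer._param2rank
# is available, so each fused gradient buffer is reduced to its parameter's
# destination rank while backward of other micro batches is still running.
model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)

x = paddle.randn([8, 64])  # micro_batch_size * accumulate_steps samples
y = paddle.randn([8, 64])
loss = model.train_batch([x, y], optimizer)
print(loss)

With the flag enabled, HybridParallelOptimizer skips its own sharding reduce step (hp_optim._sharding_enable = False in optimizer.py), since the FusedCommBuffer hooks already issue the per-destination paddle.distributed.reduce calls during backward.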