From 2c922d63f275d09fe62b8db4577ad9d1eb484834 Mon Sep 17 00:00:00 2001
From: JZ-LIANG <38102074+JZ-LIANG@users.noreply.github.com>
Date: Fri, 10 Sep 2021 18:42:21 +0800
Subject: [PATCH] [Dygraph 4D Parallel] Sharding Support MP-PP-DP Parallelism
 (#35580)

* sharding support dp
* sharding support mp
* sharding support pp
---
 .../distributed/fleet/base/fleet_base.py     |  8 ++++
 .../hybrid_parallel_optimizer.py             |  5 +--
 .../parallel_layers/pp_layers.py             | 41 +++++++++++--------
 .../fleet/meta_parallel/pipeline_parallel.py |  7 ++++
 .../fleet/meta_parallel/tensor_parallel.py   |  6 ++-
 5 files changed, 46 insertions(+), 21 deletions(-)
 mode change 100644 => 100755 python/paddle/distributed/fleet/base/fleet_base.py
 mode change 100644 => 100755 python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
 mode change 100644 => 100755 python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
 mode change 100644 => 100755 python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
 mode change 100644 => 100755 python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py

diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py
old mode 100644
new mode 100755
index ceb1cf4e034..53c617daf00
--- a/python/paddle/distributed/fleet/base/fleet_base.py
+++ b/python/paddle/distributed/fleet/base/fleet_base.py
@@ -940,6 +940,14 @@ class Fleet(object):
             distributed_model = ShardingParallel(
                 model, self._hcg, strategy=self._user_defined_strategy)
         elif self._hcg.get_parallel_mode() == ParallelMode.DATA_PARALLEL:
+
+            # NOTE (JZ-LIANG) init parameters broadcast within sharding group
+            # normally it should be done inside DataParallel
+            if self.sharding_degree > 1:
+                from paddle.distributed.fleet.utils.hybrid_parallel_util import broadcast_mp_parameters, broadcast_sharding_parameters
+                assert self.sharding_degree == self._hcg.get_sharding_parallel_world_size(
+                )
+                broadcast_sharding_parameters(model, self._hcg)
             distributed_model = paddle.DataParallel(
                 model,
                 comm_buffer_size=self._user_defined_strategy.
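The fleet_base.py hunk above makes the DATA_PARALLEL path broadcast parameters inside the sharding group before the model is wrapped in paddle.DataParallel, so all ranks of a sharding group start from identical weights. The sketch below is a conceptual illustration of such a broadcast, not the implementation of broadcast_sharding_parameters; the two hcg accessors are assumed to mirror the existing dp/mp ones, and the helper name is hypothetical.

# Conceptual sketch only: broadcast every parameter from the first rank of the
# sharding group so that all ranks in the group start from the same weights.
# The two hcg accessors are assumed to mirror the existing dp/mp ones; the
# helper name sync_params_in_sharding_group is hypothetical.
import paddle.distributed as dist


def sync_params_in_sharding_group(model, hcg):
    group = hcg.get_sharding_parallel_group()              # assumed accessor
    src_rank = hcg.get_sharding_parallel_group_src_rank()  # assumed accessor
    for param in model.parameters():
        # In-place broadcast of each parameter tensor from the source rank.
        dist.broadcast(param, src=src_rank, group=group)

In the actual code path this synchronization is delegated to broadcast_sharding_parameters from hybrid_parallel_util, which is why only the guard and the call appear in the hunk.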
diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
old mode 100644
new mode 100755
index e3a5947bf60..72af5278961
--- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
@@ -111,7 +111,7 @@ class HybridParallelOptimizer:
     @imperative_base.no_grad
     @framework.dygraph_only
     def step(self):
-        # Here should use global parameter list
+
         if self._sharding_enable:
             sharding_reduce_gradients(
                 list(self._inner_opt._parameter_list), self._hcg)
@@ -131,14 +131,13 @@ class HybridParallelOptimizer:
         parameter_list = parameters if parameters \
             else self._inner_opt._parameter_list
 
-        # Here should use global parameter list
+        # Here sharding should use global parameter list
         if self._sharding_enable:
             sharding_reduce_gradients(
                 list(self._inner_opt._parameter_list), self._hcg)
 
         if not self._use_dp_mode and self._need_dp:
             fused_allreduce_gradients(list(parameter_list), self._hcg)
-
         return self._inner_opt.minimize(loss, startup_program, parameter_list,
                                         no_grad_set)
 
diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
old mode 100644
new mode 100755
index 5ea3659bed1..db6fc964895
--- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
+++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
@@ -224,26 +224,33 @@ class PipelineLayer(Layer):
                 self.get_stage_from_index(idx) for idx in shared_layers)
             self._dp_degree = self._topo.get_dim('data')
             self._mp_degree = self._topo.get_dim('model')
+            self._sharding_degree = self._topo.get_dim('sharding')
 
             shared_ranks = []
             for dp in range(self._dp_degree):
-                for mp in range(self._mp_degree):
-                    shared_ranks = []
-                    for s in sorted(shared_stages):
-                        shared_ranks.append(
-                            self._topo.get_rank_from_stage(
-                                self.global_rank, pipe=s, data=dp, model=mp))
-
-                    group = paddle.distributed.new_group(ranks=shared_ranks)
-                    if self.global_rank in shared_ranks:
-                        assert key in self.shared_layers
-                        if key in self.shared_layers:
-                            shared_comm[key] = {
-                                'ranks': shared_ranks,
-                                'group': group,
-                                'weight_attr': self.shared_weight_attrs[key],
-                                'layer': self.shared_layers[key],
-                            }
+                for sharding in range(self._sharding_degree):
+                    for mp in range(self._mp_degree):
+                        shared_ranks = []
+                        for s in sorted(shared_stages):
+                            shared_ranks.append(
+                                self._topo.get_rank_from_stage(
+                                    self.global_rank,
+                                    pipe=s,
+                                    data=dp,
+                                    sharding=sharding,
+                                    model=mp))
+
+                        group = paddle.distributed.new_group(ranks=shared_ranks)
+                        if self.global_rank in shared_ranks:
+                            assert key in self.shared_layers
+                            if key in self.shared_layers:
+                                shared_comm[key] = {
+                                    'ranks': shared_ranks,
+                                    'group': group,
+                                    'weight_attr':
+                                    self.shared_weight_attrs[key],
+                                    'layer': self.shared_layers[key],
+                                }
         return shared_comm
 
     def _synchronize_shared_weights(self):
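The pp_layers.py hunk above inserts a sharding loop between the dp and mp loops, so one shared-weight communication group is built per (dp, sharding, mp) coordinate, holding the rank of every pipeline stage that owns the tied layer. The standalone sketch below reproduces that enumeration; the axis order [data, pipe, sharding, model] and the degrees of 2 are illustrative assumptions, not values taken from this patch.

# Standalone illustration of the enumeration added above: one shared-weight
# group per (dp, sharding, mp) coordinate, with one rank per pipeline stage
# that owns the tied layer.  Axis order and degrees are assumptions.
import itertools

dp_degree, pp_degree, sharding_degree, mp_degree = 2, 2, 2, 2
shared_stages = [0, 1]  # pipeline stages holding a copy of the shared weight


def rank_of(data, pipe, sharding, model):
    # Flatten a (data, pipe, sharding, model) coordinate into a global rank.
    return ((data * pp_degree + pipe) * sharding_degree +
            sharding) * mp_degree + model


for dp, sharding, mp in itertools.product(
        range(dp_degree), range(sharding_degree), range(mp_degree)):
    shared_ranks = [rank_of(dp, s, sharding, mp) for s in sorted(shared_stages)]
    print("dp=%d sharding=%d mp=%d -> ranks %s" % (dp, sharding, mp, shared_ranks))

Without the extra loop, ranks that differ only in their sharding coordinate would have been placed in the same group, which is why the sharding axis must appear in both the loop nest and the get_rank_from_stage call.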
diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
old mode 100644
new mode 100755
index 706d64d8d35..ddd4b6c6bb6
--- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -19,6 +19,7 @@ from .parallel_layers.pp_layers import PipelineLayer
 from ..utils.hybrid_parallel_util import broadcast_mp_parameters
 from ..utils.hybrid_parallel_util import broadcast_dp_parameters
+from ..utils.hybrid_parallel_util import broadcast_sharding_parameters
 from ..utils.log_util import logger
 from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer, HybridParallelGradScaler
 from .pp_utils import p2p_communication as p2p
@@ -34,6 +35,8 @@ class PipelineParallel(MetaParallelBase):
         super(PipelineParallel, self).__init__(layers, hcg, strategy)
         self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
         self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1
+        self.use_sharding_parallel = self._hcg.get_sharding_parallel_world_size(
+        ) > 1
 
         self.total_loss = None
 
@@ -66,6 +69,10 @@ class PipelineParallel(MetaParallelBase):
             logger.info("start broadcast mp parameters")
             broadcast_mp_parameters(self._layers, self._hcg)
 
+        if self.use_sharding_parallel:
+            logger.info("start broadcast sharding parameters")
+            broadcast_sharding_parameters(self._layers, self._hcg)
+
         if self.use_data_parallel:
             logger.info("start broadcast dp parameters")
             broadcast_dp_parameters(self._layers, self._hcg)
 
diff --git a/python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py b/python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py
old mode 100644
new mode 100755
index 1dbf668d6e1..171df7cf033
--- a/python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py
@@ -16,7 +16,7 @@ from paddle.fluid.dygraph.layers import Layer
 from .meta_parallel_base import MetaParallelBase
 from ..utils.hybrid_parallel_util import broadcast_dp_parameters
 from ..utils.hybrid_parallel_util import broadcast_input_data
-from ..utils.hybrid_parallel_util import broadcast_mp_parameters
+from ..utils.hybrid_parallel_util import broadcast_mp_parameters, broadcast_sharding_parameters
 from ..utils.log_util import logger
 
 __all__ = []
@@ -30,6 +30,10 @@ class TensorParallel(MetaParallelBase):
         logger.info("start broadcast mp parameters")
         broadcast_mp_parameters(self._layers, self._hcg)
 
+        if self._hcg.get_sharding_parallel_world_size() > 1:
+            logger.info("start broadcast sharding parameters")
+            broadcast_sharding_parameters(self._layers, self._hcg)
+
         logger.info("start broadcast dp parameters")
         broadcast_dp_parameters(self._layers, self._hcg)
 
--
GitLab
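Taken together, these changes let a dygraph training script combine sharding with data, tensor, and pipeline parallelism. The sketch below shows one way such a job might be configured; the sharding_degree key in hybrid_configs and the exact optimizer wrapping requirements are assumptions for this Paddle version, and the degrees are examples only.

# Hedged usage sketch for enabling dp + mp + sharding in dygraph mode.
# The "sharding_degree" key is assumed to be accepted by this Paddle version;
# all degrees are illustrative (2 * 2 * 2 = 8 ranks, pp_degree kept at 1 so a
# plain Layer can be used; pp_degree > 1 would require a PipelineLayer model).
import paddle
from paddle.distributed import fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 2,
    "mp_degree": 2,
    "pp_degree": 1,
    "sharding_degree": 2,
}
fleet.init(is_collective=True, strategy=strategy)

model = paddle.nn.Linear(8, 8)
optimizer = paddle.optimizer.SGD(
    learning_rate=0.01, parameters=model.parameters())

# With the patch above, fleet.distributed_model() also broadcasts parameters
# within the sharding group; depending on the Paddle version, the inner
# optimizer may additionally need to be wrapped by the dygraph sharding
# optimizer before this call.
model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)

Such a script would typically be started with python -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" train.py, where train.py is the hypothetical file holding the code above.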