diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py
old mode 100644
new mode 100755
index ceb1cf4e0347ec08174726e5a2d94e1a4ca8b2c7..53c617daf005e3d2e0a5400f3cc6614bcb978a42
--- a/python/paddle/distributed/fleet/base/fleet_base.py
+++ b/python/paddle/distributed/fleet/base/fleet_base.py
@@ -940,6 +940,14 @@ class Fleet(object):
             distributed_model = ShardingParallel(
                 model, self._hcg, strategy=self._user_defined_strategy)
         elif self._hcg.get_parallel_mode() == ParallelMode.DATA_PARALLEL:
+
+            # NOTE(JZ-LIANG): init parameter broadcast within sharding group;
+            # normally it should be done inside DataParallel
+            if self.sharding_degree > 1:
+                from paddle.distributed.fleet.utils.hybrid_parallel_util import broadcast_sharding_parameters
+                assert self.sharding_degree == self._hcg.get_sharding_parallel_world_size(
+                )
+                broadcast_sharding_parameters(model, self._hcg)
             distributed_model = paddle.DataParallel(
                 model,
                 comm_buffer_size=self._user_defined_strategy.
diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
old mode 100644
new mode 100755
index e3a5947bf60fc1aa152dd1ecfd89689cc204536e..72af527896152c446ab97753cf9304febb2d2870
--- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
@@ -111,7 +111,7 @@ class HybridParallelOptimizer:
     @imperative_base.no_grad
     @framework.dygraph_only
     def step(self):
-        # Here should use global parameter list
+
         if self._sharding_enable:
             sharding_reduce_gradients(
                 list(self._inner_opt._parameter_list), self._hcg)
@@ -131,14 +131,13 @@ class HybridParallelOptimizer:
         parameter_list = parameters if parameters \
             else self._inner_opt._parameter_list

-        # Here should use global parameter list
+        # Here sharding should use the global parameter list
         if self._sharding_enable:
             sharding_reduce_gradients(
                 list(self._inner_opt._parameter_list), self._hcg)

         if not self._use_dp_mode and self._need_dp:
             fused_allreduce_gradients(list(parameter_list), self._hcg)
-
         return self._inner_opt.minimize(loss, startup_program, parameter_list,
                                         no_grad_set)

diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
old mode 100644
new mode 100755
index 5ea3659bed110213341a0a9e9aef333d11a92164..db6fc964895ffcfd9d69cbe7557141c2d8586382
--- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
+++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
@@ -224,26 +224,33 @@ class PipelineLayer(Layer):
                 self.get_stage_from_index(idx) for idx in shared_layers)
             self._dp_degree = self._topo.get_dim('data')
             self._mp_degree = self._topo.get_dim('model')
+            self._sharding_degree = self._topo.get_dim('sharding')

             shared_ranks = []
             for dp in range(self._dp_degree):
-                for mp in range(self._mp_degree):
-                    shared_ranks = []
-                    for s in sorted(shared_stages):
-                        shared_ranks.append(
-                            self._topo.get_rank_from_stage(
-                                self.global_rank, pipe=s, data=dp, model=mp))
-
-                    group = paddle.distributed.new_group(ranks=shared_ranks)
-                    if self.global_rank in shared_ranks:
-                        assert key in self.shared_layers
-                        if key in self.shared_layers:
-                            shared_comm[key] = {
-                                'ranks': shared_ranks,
-                                'group': group,
-                                'weight_attr': self.shared_weight_attrs[key],
-                                'layer': self.shared_layers[key],
-                            }
+                for sharding in range(self._sharding_degree):
+                    for mp in range(self._mp_degree):
+                        shared_ranks = []
+                        for s in sorted(shared_stages):
+                            shared_ranks.append(
+                                self._topo.get_rank_from_stage(
+                                    self.global_rank,
+                                    pipe=s,
+                                    data=dp,
+                                    sharding=sharding,
+                                    model=mp))
+
+                        group = paddle.distributed.new_group(ranks=shared_ranks)
+                        if self.global_rank in shared_ranks:
+                            assert key in self.shared_layers
+                            if key in self.shared_layers:
+                                shared_comm[key] = {
+                                    'ranks': shared_ranks,
+                                    'group': group,
+                                    'weight_attr':
+                                    self.shared_weight_attrs[key],
+                                    'layer': self.shared_layers[key],
+                                }
         return shared_comm

     def _synchronize_shared_weights(self):
diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
old mode 100644
new mode 100755
index 706d64d8d35b6112f3feb270c4952a7c4276b00f..ddd4b6c6bb685737f66bcd8f6aea4e330475f90b
--- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -19,6 +19,7 @@ from .parallel_layers.pp_layers import PipelineLayer

 from ..utils.hybrid_parallel_util import broadcast_mp_parameters
 from ..utils.hybrid_parallel_util import broadcast_dp_parameters
+from ..utils.hybrid_parallel_util import broadcast_sharding_parameters
 from ..utils.log_util import logger
 from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer, HybridParallelGradScaler
 from .pp_utils import p2p_communication as p2p
@@ -34,6 +35,8 @@ class PipelineParallel(MetaParallelBase):
         super(PipelineParallel, self).__init__(layers, hcg, strategy)
         self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
         self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1
+        self.use_sharding_parallel = self._hcg.get_sharding_parallel_world_size(
+        ) > 1

         self.total_loss = None

@@ -66,6 +69,10 @@ class PipelineParallel(MetaParallelBase):
             logger.info("start broadcast mp parameters")
             broadcast_mp_parameters(self._layers, self._hcg)

+        if self.use_sharding_parallel:
+            logger.info("start broadcast sharding parameters")
+            broadcast_sharding_parameters(self._layers, self._hcg)
+
         if self.use_data_parallel:
             logger.info("start broadcast dp parameters")
             broadcast_dp_parameters(self._layers, self._hcg)
diff --git a/python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py b/python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py
old mode 100644
new mode 100755
index 1dbf668d6e13a01c29e14b6687106683af0e9d97..171df7cf033be218848010301dd526eb66dba831
--- a/python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py
@@ -16,7 +16,7 @@ from paddle.fluid.dygraph.layers import Layer
 from .meta_parallel_base import MetaParallelBase
 from ..utils.hybrid_parallel_util import broadcast_dp_parameters
 from ..utils.hybrid_parallel_util import broadcast_input_data
-from ..utils.hybrid_parallel_util import broadcast_mp_parameters
+from ..utils.hybrid_parallel_util import broadcast_mp_parameters, broadcast_sharding_parameters
 from ..utils.log_util import logger

 __all__ = []
@@ -30,6 +30,10 @@ class TensorParallel(MetaParallelBase):
         logger.info("start broadcast mp parameters")
         broadcast_mp_parameters(self._layers, self._hcg)

+        if self._hcg.get_sharding_parallel_world_size() > 1:
+            logger.info("start broadcast sharding parameters")
+            broadcast_sharding_parameters(self._layers, self._hcg)
+
         logger.info("start broadcast dp parameters")
         broadcast_dp_parameters(self._layers, self._hcg)
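Note: `broadcast_sharding_parameters` itself lives in `python/paddle/distributed/fleet/utils/hybrid_parallel_util.py` and is not part of this diff. As a rough sketch of what the call sites above rely on, assuming it mirrors the existing `broadcast_dp_parameters`/`broadcast_mp_parameters` helpers and reuses the same-module `sync_params_buffers` routine together with `hcg.get_sharding_parallel_group()` and `hcg.get_sharding_parallel_group_src_rank()` (treat these helper names as assumptions, not confirmed API):

    # Sketch only (not part of this patch): broadcast every parameter from the
    # source rank of the sharding group so all ranks in the group start from
    # identical weights; helper names are assumed from the dp/mp variants.
    def broadcast_sharding_parameters(model, hcg):
        sharding_parallel_group = hcg.get_sharding_parallel_group()
        src_rank = hcg.get_sharding_parallel_group_src_rank()
        sync_params_buffers(
            model, sharding_parallel_group, src_rank, is_model_parallel=False)

With something along these lines in place, a dygraph hybrid-parallel job that sets a sharding degree gets the same start-up guarantee the dp/mp paths already provide: parameters are synchronized within each sharding group before the first training step.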