Unverified commit 2c922d63, authored by JZ-LIANG, committed by GitHub

[Dygraph 4D Parallel] Sharding Support MP-PP-DP Parallelism (#35580)

* sharding support dp

* sharding support mp

* sharding support pp
Parent 49e243c9
@@ -940,6 +940,14 @@ class Fleet(object):
             distributed_model = ShardingParallel(
                 model, self._hcg, strategy=self._user_defined_strategy)
         elif self._hcg.get_parallel_mode() == ParallelMode.DATA_PARALLEL:
+
+            # NOTE (JZ-LIANG) init parameters broadcast within sharding group
+            # normally it should be done inside DataParallel
+            if self.sharding_degree > 1:
+                from paddle.distributed.fleet.utils.hybrid_parallel_util import broadcast_mp_parameters, broadcast_sharding_parameters
+                assert self.sharding_degree == self._hcg.get_sharding_parallel_world_size(
+                )
+                broadcast_sharding_parameters(model, self._hcg)
             distributed_model = paddle.DataParallel(
                 model,
                 comm_buffer_size=self._user_defined_strategy.
...
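With this change, fleet.distributed_model() broadcasts the model parameters within the sharding group before wrapping the model in paddle.DataParallel, so every rank of a sharding group starts training from identical weights. Below is a minimal, hedged sketch of how a user might configure the dygraph hybrid strategy to reach this code path; the exact hybrid_configs keys (in particular "sharding_degree") are assumptions inferred from this patch and may differ across Paddle versions, and the Linear model is just a stand-in.

```python
# Minimal sketch (assumption: "sharding_degree" is accepted by
# hybrid_configs in this Paddle version; run under a distributed launcher).
import paddle
from paddle.distributed import fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 2,        # data-parallel replicas
    "mp_degree": 2,        # tensor (model) parallel degree
    "pp_degree": 2,        # pipeline stages
    "sharding_degree": 2,  # optimizer-state sharding group size (assumed key)
}
fleet.init(is_collective=True, strategy=strategy)

model = paddle.nn.Linear(8, 8)          # stand-in for a real network
optimizer = paddle.optimizer.AdamW(parameters=model.parameters())

# distributed_model() now broadcasts parameters inside the sharding group
# (the block added above) before applying DataParallel.
model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)
```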
@@ -111,7 +111,7 @@ class HybridParallelOptimizer:
     @imperative_base.no_grad
     @framework.dygraph_only
     def step(self):
-
+        # Here should use global parameter list
         if self._sharding_enable:
             sharding_reduce_gradients(
                 list(self._inner_opt._parameter_list), self._hcg)
...
@@ -131,14 +131,13 @@ class HybridParallelOptimizer:
         parameter_list = parameters if parameters \
             else self._inner_opt._parameter_list

-        # Here should use global parameter list
+        # Here sharding should use global parameter list
        if self._sharding_enable:
            sharding_reduce_gradients(
                list(self._inner_opt._parameter_list), self._hcg)

        if not self._use_dp_mode and self._need_dp:
            fused_allreduce_gradients(list(parameter_list), self._hcg)

        return self._inner_opt.minimize(loss, startup_program, parameter_list,
                                        no_grad_set)
...
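The comments added here flag that, under sharding, gradient reduction must run over the global parameter list rather than only the parameters whose optimizer state lives on the local rank. Conceptually, sharding_reduce_gradients sends each parameter's gradient to the rank that owns that parameter's optimizer state inside the sharding group. The sketch below illustrates that reduce-to-owner pattern only; it is not the real helper, and the owner_of mapping is hypothetical (Paddle keeps the equivalent bookkeeping inside the hybrid communicate group).

```python
# Conceptual sketch of reduce-to-owner gradient handling under sharding;
# not the actual sharding_reduce_gradients implementation.
import paddle.distributed as dist

def reduce_grads_to_owners(parameter_list, sharding_group, owner_of):
    # owner_of(param) -> rank that holds this parameter's optimizer state
    # (hypothetical mapping, for illustration only).
    for param in parameter_list:
        if param.grad is not None:
            # assumes param.grad is a Tensor in this dygraph version
            dist.reduce(param.grad, dst=owner_of(param), group=sharding_group)
```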
@@ -224,26 +224,33 @@ class PipelineLayer(Layer):
             self.get_stage_from_index(idx) for idx in shared_layers)
         self._dp_degree = self._topo.get_dim('data')
         self._mp_degree = self._topo.get_dim('model')
+        self._sharding_degree = self._topo.get_dim('sharding')

         shared_ranks = []
         for dp in range(self._dp_degree):
-            for mp in range(self._mp_degree):
-                shared_ranks = []
-                for s in sorted(shared_stages):
-                    shared_ranks.append(
-                        self._topo.get_rank_from_stage(
-                            self.global_rank, pipe=s, data=dp, model=mp))
-
-                group = paddle.distributed.new_group(ranks=shared_ranks)
-                if self.global_rank in shared_ranks:
-                    assert key in self.shared_layers
-                    if key in self.shared_layers:
-                        shared_comm[key] = {
-                            'ranks': shared_ranks,
-                            'group': group,
-                            'weight_attr': self.shared_weight_attrs[key],
-                            'layer': self.shared_layers[key],
-                        }
+            for sharding in range(self._sharding_degree):
+                for mp in range(self._mp_degree):
+                    shared_ranks = []
+                    for s in sorted(shared_stages):
+                        shared_ranks.append(
+                            self._topo.get_rank_from_stage(
+                                self.global_rank,
+                                pipe=s,
+                                data=dp,
+                                sharding=sharding,
+                                model=mp))
+
+                    group = paddle.distributed.new_group(ranks=shared_ranks)
+                    if self.global_rank in shared_ranks:
+                        assert key in self.shared_layers
+                        if key in self.shared_layers:
+                            shared_comm[key] = {
+                                'ranks': shared_ranks,
+                                'group': group,
+                                'weight_attr':
+                                self.shared_weight_attrs[key],
+                                'layer': self.shared_layers[key],
+                            }
         return shared_comm

     def _synchronize_shared_weights(self):
...
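Because the topology now contains a sharding axis, the shared-weight communication groups (used to keep tied layers such as shared embeddings synchronized across pipeline stages) have to be enumerated per (data, sharding, model) coordinate; otherwise ranks from different sharding slices would be mixed into one group. The toy snippet below only illustrates why the extra loop level is needed; the rank arithmetic is a stand-in with an assumed axis order, while the real mapping comes from self._topo.get_rank_from_stage().

```python
# Toy illustration: one shared-weight group per (data, sharding, model)
# coordinate, spanning the pipeline stages that hold the shared layer.
dp_degree, sharding_degree, mp_degree, pp_degree = 2, 2, 2, 2
shared_stages = [0, pp_degree - 1]  # e.g. tied embedding on first/last stage

def rank_of(dp, sharding, mp, pp):
    # assumed axis order data > pipe > sharding > model, model fastest-varying
    return ((dp * pp_degree + pp) * sharding_degree + sharding) * mp_degree + mp

for dp in range(dp_degree):
    for sharding in range(sharding_degree):
        for mp in range(mp_degree):
            ranks = [rank_of(dp, sharding, mp, pp) for pp in sorted(shared_stages)]
            print(f"dp={dp} sharding={sharding} mp={mp} -> shared ranks {ranks}")
```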
@@ -19,6 +19,7 @@ from .parallel_layers.pp_layers import PipelineLayer
 from ..utils.hybrid_parallel_util import broadcast_mp_parameters
 from ..utils.hybrid_parallel_util import broadcast_dp_parameters
+from ..utils.hybrid_parallel_util import broadcast_sharding_parameters
 from ..utils.log_util import logger
 from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer, HybridParallelGradScaler
 from .pp_utils import p2p_communication as p2p
@@ -34,6 +35,8 @@ class PipelineParallel(MetaParallelBase):
         super(PipelineParallel, self).__init__(layers, hcg, strategy)
         self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
         self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1
+        self.use_sharding_parallel = self._hcg.get_sharding_parallel_world_size(
+        ) > 1

         self.total_loss = None
@@ -66,6 +69,10 @@ class PipelineParallel(MetaParallelBase):
             logger.info("start broadcast mp parameters")
             broadcast_mp_parameters(self._layers, self._hcg)

+        if self.use_sharding_parallel:
+            logger.info("start broadcast sharding parameters")
+            broadcast_sharding_parameters(self._layers, self._hcg)
+
         if self.use_data_parallel:
             logger.info("start broadcast dp parameters")
             broadcast_dp_parameters(self._layers, self._hcg)
...
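PipelineParallel now broadcasts within the sharding group right after the mp broadcast and before the dp broadcast, so all members of a sharding group hold the source rank's weights before the data-parallel replicas are aligned. At its core this is a per-parameter broadcast over the sharding communication group, roughly like the sketch below; the actual broadcast_sharding_parameters helper in hybrid_parallel_util may coalesce tensors or use fused buffers instead.

```python
# Rough sketch of broadcasting a layer's parameters inside one
# communication group; not the real hybrid_parallel_util helper.
import paddle.distributed as dist

def broadcast_parameters_in_group(layers, group, src_rank):
    # src_rank is the global rank chosen as the source of truth
    for param in layers.parameters():
        dist.broadcast(param, src=src_rank, group=group)
```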
@@ -16,7 +16,7 @@ from paddle.fluid.dygraph.layers import Layer
 from .meta_parallel_base import MetaParallelBase
 from ..utils.hybrid_parallel_util import broadcast_dp_parameters
 from ..utils.hybrid_parallel_util import broadcast_input_data
-from ..utils.hybrid_parallel_util import broadcast_mp_parameters
+from ..utils.hybrid_parallel_util import broadcast_mp_parameters, broadcast_sharding_parameters
 from ..utils.log_util import logger

 __all__ = []
@@ -30,6 +30,10 @@ class TensorParallel(MetaParallelBase):
         logger.info("start broadcast mp parameters")
         broadcast_mp_parameters(self._layers, self._hcg)

+        if self._hcg.get_sharding_parallel_world_size() > 1:
+            logger.info("start broadcast sharding parameters")
+            broadcast_sharding_parameters(self._layers, self._hcg)
+
         logger.info("start broadcast dp parameters")
         broadcast_dp_parameters(self._layers, self._hcg)
...