未验证 提交 2c922d63 编写于 作者: J JZ-LIANG 提交者: GitHub

[Dygraph 4D Parallel] Sharding Support MP-PP-DP Parallelism (#35580)

* sharding support dp

* sharding support mp

* sharding support pp
上级 49e243c9
......@@ -940,6 +940,14 @@ class Fleet(object):
distributed_model = ShardingParallel(
model, self._hcg, strategy=self._user_defined_strategy)
elif self._hcg.get_parallel_mode() == ParallelMode.DATA_PARALLEL:
# NOTE (JZ-LIANG) init parameters broadcast within sharding group
# normally it should be done inside DataParallel
if self.sharding_degree > 1:
from paddle.distributed.fleet.utils.hybrid_parallel_util import broadcast_mp_parameters, broadcast_sharding_parameters
assert self.sharding_degree == self._hcg.get_sharding_parallel_world_size(
)
broadcast_sharding_parameters(model, self._hcg)
distributed_model = paddle.DataParallel(
model,
comm_buffer_size=self._user_defined_strategy.
......
......@@ -111,7 +111,7 @@ class HybridParallelOptimizer:
@imperative_base.no_grad
@framework.dygraph_only
def step(self):
# Here should use global parameter list
if self._sharding_enable:
sharding_reduce_gradients(
list(self._inner_opt._parameter_list), self._hcg)
......@@ -131,14 +131,13 @@ class HybridParallelOptimizer:
parameter_list = parameters if parameters \
else self._inner_opt._parameter_list
# Here should use global parameter list
# Here shardinng should use global parameter list
if self._sharding_enable:
sharding_reduce_gradients(
list(self._inner_opt._parameter_list), self._hcg)
if not self._use_dp_mode and self._need_dp:
fused_allreduce_gradients(list(parameter_list), self._hcg)
return self._inner_opt.minimize(loss, startup_program, parameter_list,
no_grad_set)
......
......@@ -224,15 +224,21 @@ class PipelineLayer(Layer):
self.get_stage_from_index(idx) for idx in shared_layers)
self._dp_degree = self._topo.get_dim('data')
self._mp_degree = self._topo.get_dim('model')
self._sharding_degree = self._topo.get_dim('sharding')
shared_ranks = []
for dp in range(self._dp_degree):
for sharding in range(self._sharding_degree):
for mp in range(self._mp_degree):
shared_ranks = []
for s in sorted(shared_stages):
shared_ranks.append(
self._topo.get_rank_from_stage(
self.global_rank, pipe=s, data=dp, model=mp))
self.global_rank,
pipe=s,
data=dp,
sharding=sharding,
model=mp))
group = paddle.distributed.new_group(ranks=shared_ranks)
if self.global_rank in shared_ranks:
......@@ -241,7 +247,8 @@ class PipelineLayer(Layer):
shared_comm[key] = {
'ranks': shared_ranks,
'group': group,
'weight_attr': self.shared_weight_attrs[key],
'weight_attr':
self.shared_weight_attrs[key],
'layer': self.shared_layers[key],
}
return shared_comm
......
......@@ -19,6 +19,7 @@ from .parallel_layers.pp_layers import PipelineLayer
from ..utils.hybrid_parallel_util import broadcast_mp_parameters
from ..utils.hybrid_parallel_util import broadcast_dp_parameters
from ..utils.hybrid_parallel_util import broadcast_sharding_parameters
from ..utils.log_util import logger
from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer, HybridParallelGradScaler
from .pp_utils import p2p_communication as p2p
......@@ -34,6 +35,8 @@ class PipelineParallel(MetaParallelBase):
super(PipelineParallel, self).__init__(layers, hcg, strategy)
self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1
self.use_sharding_parallel = self._hcg.get_sharding_parallel_world_size(
) > 1
self.total_loss = None
......@@ -66,6 +69,10 @@ class PipelineParallel(MetaParallelBase):
logger.info("start broadcast mp parameters")
broadcast_mp_parameters(self._layers, self._hcg)
if self.use_sharding_parallel:
logger.info("start broadcast sharding parameters")
broadcast_sharding_parameters(self._layers, self._hcg)
if self.use_data_parallel:
logger.info("start broadcast dp parameters")
broadcast_dp_parameters(self._layers, self._hcg)
......
......@@ -16,7 +16,7 @@ from paddle.fluid.dygraph.layers import Layer
from .meta_parallel_base import MetaParallelBase
from ..utils.hybrid_parallel_util import broadcast_dp_parameters
from ..utils.hybrid_parallel_util import broadcast_input_data
from ..utils.hybrid_parallel_util import broadcast_mp_parameters
from ..utils.hybrid_parallel_util import broadcast_mp_parameters, broadcast_sharding_parameters
from ..utils.log_util import logger
__all__ = []
......@@ -30,6 +30,10 @@ class TensorParallel(MetaParallelBase):
logger.info("start broadcast mp parameters")
broadcast_mp_parameters(self._layers, self._hcg)
if self._hcg.get_sharding_parallel_world_size() > 1:
logger.info("start broadcast sharding parameters")
broadcast_sharding_parameters(self._layers, self._hcg)
logger.info("start broadcast dp parameters")
broadcast_dp_parameters(self._layers, self._hcg)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册