From c17e6af83cb7d3ef85613a57b6b3a9eba7e22ce5 Mon Sep 17 00:00:00 2001
From: Haohongxiang <86215757+haohongxiang@users.noreply.github.com>
Date: Tue, 16 Aug 2022 14:21:43 +0800
Subject: [PATCH] [Fleet] Reconstruct of Fleet API in Dygraph Mode (#44922)

* reconstruct_of_fleet_api

* update
---
 .../parallel_layers/mp_layers.py         | 38 +++++++++++--------
 .../fleet/utils/hybrid_parallel_util.py  |  7 +++-
 2 files changed, 27 insertions(+), 18 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py
index c9d7c71dbbb..6cb69bc73ce 100644
--- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py
+++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py
@@ -41,14 +41,16 @@ class VocabParallelEmbedding(Layer):
                  num_embeddings,
                  embedding_dim,
                  weight_attr=None,
+                 mp_group=None,
                  name=None):
         super(VocabParallelEmbedding, self).__init__()
 
         self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
-        )
+        ) if mp_group is None else mp_group
         self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
-        )
-        self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank()
+        ) if mp_group is None else mp_group.nranks
+        self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank(
+        ) if mp_group is None else mp_group.rank
 
         self.origin_num_embeddings = num_embeddings
         self.is_mp = (self.world_size > 1)
@@ -108,14 +110,15 @@ class ColumnParallelLinear(Layer):
                  weight_attr=None,
                  has_bias=None,
                  gather_output=True,
-                 name=None,
-                 fuse_matmul_bias=False):
+                 fuse_matmul_bias=False,
+                 mp_group=None,
+                 name=None):
         super(ColumnParallelLinear, self).__init__()
 
         self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
-        )
+        ) if mp_group is None else mp_group
         self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
-        )
+        ) if mp_group is None else mp_group.nranks
         self._name = name
         self.is_mp = (self.world_size > 1)
 
@@ -197,8 +200,9 @@ class RowParallelLinear(Layer):
                  weight_attr=None,
                  has_bias=True,
                  input_is_parallel=False,
-                 name=None,
-                 fuse_matmul_bias=False):
+                 fuse_matmul_bias=False,
+                 mp_group=None,
+                 name=None):
         super(RowParallelLinear, self).__init__()
 
         self.in_features = in_features
@@ -209,10 +213,11 @@ class RowParallelLinear(Layer):
         self._name = name
 
         self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
-        )
+        ) if mp_group is None else mp_group
         self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
-        )
-        self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank()
+        ) if mp_group is None else mp_group.nranks
+        self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank(
+        ) if mp_group is None else mp_group.rank
 
         self.is_mp = (self.world_size > 1)
         assert in_features % self.world_size == 0, (
@@ -288,14 +293,15 @@ class RowParallelLinear(Layer):
 
 class ParallelCrossEntropy(Layer):
 
-    def __init__(self, name=None):
+    def __init__(self, mp_group=None, name=None):
         super(ParallelCrossEntropy, self).__init__()
         self.name = name
         self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
-        )
+        ) if mp_group is None else mp_group
         self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
-        )
-        self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank()
+        ) if mp_group is None else mp_group.nranks
+        self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank(
+        ) if mp_group is None else mp_group.rank
 
     def forward(self, input, label):
         loss = paddle.distributed.collective._c_softmax_with_cross_entropy(
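The mp_layers.py changes above let callers pass an explicit communication group instead of relying on the global hybrid-parallel group set up by fleet.init(). Below is a minimal usage sketch, not part of the patch, assuming a script launched on two GPUs with paddle.distributed.launch and assuming the layer is exposed as fleet.meta_parallel.ColumnParallelLinear on this branch; the rank list and layer sizes are illustrative.

import paddle
import paddle.distributed as dist
from paddle.distributed import fleet

dist.init_parallel_env()

# Build an explicit process group; the returned Group exposes .nranks and
# .rank, which is exactly what the layers above read when mp_group is given.
mp_group = dist.new_group(ranks=[0, 1])

# When mp_group is passed, the layer skips the tp._HYBRID_PARALLEL_GROUP
# lookup and issues its collectives on this group instead.
linear = fleet.meta_parallel.ColumnParallelLinear(64,
                                                  128,
                                                  weight_attr=None,
                                                  has_bias=True,
                                                  gather_output=True,
                                                  mp_group=mp_group)

x = paddle.randn([4, 64])
y = linear(x)  # each rank holds 128 // mp_group.nranks columns of the weight

Leaving mp_group=None keeps the previous lookup through tp._HYBRID_PARALLEL_GROUP, so existing callers are unaffected.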
diff --git a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
index 161f4d3262a..c3b0693d7eb 100644
--- a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
+++ b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
@@ -109,16 +109,19 @@ def _broadcast_data_help(data, shape, dtype, hcg):
 
 
 def broadcast_input_data(hcg, *inputs, **kwargs):
+    cur_device = paddle.get_device()
     for v in inputs:
-        if isinstance(v, core.VarBase):
+        if isinstance(v, (core.VarBase, core.eager.Tensor)):
             with framework.no_grad():
+                v = v.cuda() if "gpu" in cur_device else v
                 _broadcast_data_help(v, v.shape, v.dtype, hcg)
         else:
             logger.error("it doesn't support data type {}".format(type(v)))
 
     for k, v in kwargs.items():
-        if isinstance(v, core.VarBase):
+        if isinstance(v, (core.VarBase, core.eager.Tensor)):
             with framework.no_grad():
+                v = v.cuda() if "gpu" in cur_device else v
                 _broadcast_data_help(v, v.shape, v.dtype, hcg)
             kwargs[k] = v
         else:
--
GitLab
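The hybrid_parallel_util.py change teaches broadcast_input_data to accept eager-mode tensors (core.eager.Tensor) and to move CPU-resident inputs to the current GPU before broadcasting them. A rough usage sketch follows; it is not part of the patch and assumes the process was started with paddle.distributed.launch so that the hybrid configuration below (mp_degree=2) matches the number of launched ranks. The tensor values are illustrative.

import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.utils.hybrid_parallel_util import broadcast_input_data

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 1}
fleet.init(is_collective=True, strategy=strategy)

hcg = fleet.get_hybrid_communicate_group()

# A batch that still lives on the CPU, e.g. produced by a DataLoader worker.
tokens = paddle.to_tensor([[1, 2, 3, 4]], place=paddle.CPUPlace())

# With this patch the CPU tensor is first copied to the current GPU (when the
# process runs on one) and then broadcast via _broadcast_data_help; previously
# only core.VarBase inputs were handled.
broadcast_input_data(hcg, tokens, labels=tokens)

The keyword-argument branch behaves the same way and writes the (possibly device-moved) tensor back into kwargs, as the kwargs[k] = v context line above shows.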