diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py
index c9d7c71dbbb18e874a05f11eb61021d243c24794..6cb69bc73ce617e1de8abe0b3e05c3002d14152a 100644
--- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py
+++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py
@@ -41,14 +41,16 @@ class VocabParallelEmbedding(Layer):
                  num_embeddings,
                  embedding_dim,
                  weight_attr=None,
+                 mp_group=None,
                  name=None):
         super(VocabParallelEmbedding, self).__init__()
 
         self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
-        )
+        ) if mp_group is None else mp_group
         self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
-        )
-        self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank()
+        ) if mp_group is None else mp_group.nranks
+        self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank(
+        ) if mp_group is None else mp_group.rank
 
         self.origin_num_embeddings = num_embeddings
         self.is_mp = (self.world_size > 1)
@@ -108,14 +110,15 @@ class ColumnParallelLinear(Layer):
                  weight_attr=None,
                  has_bias=None,
                  gather_output=True,
-                 name=None,
-                 fuse_matmul_bias=False):
+                 fuse_matmul_bias=False,
+                 mp_group=None,
+                 name=None):
         super(ColumnParallelLinear, self).__init__()
 
         self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
-        )
+        ) if mp_group is None else mp_group
         self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
-        )
+        ) if mp_group is None else mp_group.nranks
         self._name = name
         self.is_mp = (self.world_size > 1)
 
@@ -197,8 +200,9 @@ class RowParallelLinear(Layer):
                  weight_attr=None,
                  has_bias=True,
                  input_is_parallel=False,
-                 name=None,
-                 fuse_matmul_bias=False):
+                 fuse_matmul_bias=False,
+                 mp_group=None,
+                 name=None):
         super(RowParallelLinear, self).__init__()
 
         self.in_features = in_features
@@ -209,10 +213,11 @@ class RowParallelLinear(Layer):
         self._name = name
 
         self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
-        )
+        ) if mp_group is None else mp_group
         self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
-        )
-        self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank()
+        ) if mp_group is None else mp_group.nranks
+        self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank(
+        ) if mp_group is None else mp_group.rank
         self.is_mp = (self.world_size > 1)
 
         assert in_features % self.world_size == 0, (
@@ -288,14 +293,15 @@ class RowParallelLinear(Layer):
 
 class ParallelCrossEntropy(Layer):
 
-    def __init__(self, name=None):
+    def __init__(self, mp_group=None, name=None):
         super(ParallelCrossEntropy, self).__init__()
         self.name = name
         self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
-        )
+        ) if mp_group is None else mp_group
         self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
-        )
-        self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank()
+        ) if mp_group is None else mp_group.nranks
+        self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank(
+        ) if mp_group is None else mp_group.rank
 
     def forward(self, input, label):
         loss = paddle.distributed.collective._c_softmax_with_cross_entropy(
diff --git a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
index 161f4d3262ab173dfa3380b2b4195a3f27579960..c3b0693d7ebd0b8f8ce4340fb4bbd7f2bb21c1cf 100644
--- a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
+++ b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
@@ -109,16 +109,19 @@ def _broadcast_data_help(data, shape, dtype, hcg):
 
 
 def broadcast_input_data(hcg, *inputs, **kwargs):
+    cur_device = paddle.get_device()
     for v in inputs:
-        if isinstance(v, core.VarBase):
+        if isinstance(v, (core.VarBase, core.eager.Tensor)):
             with framework.no_grad():
+                v = v.cuda() if "gpu" in cur_device else v
                 _broadcast_data_help(v, v.shape, v.dtype, hcg)
         else:
             logger.error("it doesn't support data type {}".format(type(v)))
 
     for k, v in kwargs.items():
-        if isinstance(v, core.VarBase):
+        if isinstance(v, (core.VarBase, core.eager.Tensor)):
             with framework.no_grad():
+                v = v.cuda() if "gpu" in cur_device else v
                 _broadcast_data_help(v, v.shape, v.dtype, hcg)
             kwargs[k] = v
         else:
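Usage note: a minimal sketch of how the new mp_group argument might be exercised, assuming a two-card run launched with python -m paddle.distributed.launch; the rank list, layer sizes, and variable names below are illustrative and not part of the patch.

    # Hypothetical two-rank setup; when mp_group is omitted the layers keep
    # falling back to tp._HYBRID_PARALLEL_GROUP as before.
    import paddle.distributed as dist
    from paddle.distributed.fleet.meta_parallel import (
        ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding)

    dist.init_parallel_env()
    mp_group = dist.new_group(ranks=[0, 1])  # assumed 2-way tensor parallelism

    # Vocab and feature sizes are placeholders, chosen divisible by nranks.
    emb = VocabParallelEmbedding(1024, 64, mp_group=mp_group)
    fc1 = ColumnParallelLinear(64, 256, has_bias=True, gather_output=False,
                               mp_group=mp_group)
    fc2 = RowParallelLinear(256, 64, has_bias=True, input_is_parallel=True,
                            mp_group=mp_group)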
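Usage note: the hybrid_parallel_util change accepts eager-mode core.eager.Tensor inputs and moves CPU tensors onto the current GPU before the collective broadcast. A hedged sketch of a call site, assuming fleet has been initialized with a model-parallel degree of 2; shapes and names are placeholders.

    import paddle
    from paddle.distributed import fleet
    from paddle.distributed.fleet.utils import hybrid_parallel_util as hpu

    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 1}
    fleet.init(is_collective=True, strategy=strategy)
    hcg = fleet.get_hybrid_communicate_group()

    # CPU-resident eager tensors are now accepted; on a GPU device they are
    # moved with v.cuda() before being broadcast within the model-parallel group.
    tokens = paddle.to_tensor([[1, 2, 3, 4]], dtype="int64", place=paddle.CPUPlace())
    labels = paddle.to_tensor([[5, 6, 7, 8]], dtype="int64", place=paddle.CPUPlace())
    hpu.broadcast_input_data(hcg, tokens, labels)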