From fa97e5bad5829a079749f3947bf93af0673c8d5d Mon Sep 17 00:00:00 2001
From: wuhuachaocoding <77733235+wuhuachaocoding@users.noreply.github.com>
Date: Fri, 16 Sep 2022 10:53:58 +0800
Subject: [PATCH] refactor mp. (#45803)

* refactor mp.

* update setup.py.

* update mp_layers.py for compatibility.

* add documents for mp_layers.py

* update init.py

* update collective.py.

* update.

* update mp_ops.py

* update.

* update code style.

* update code style.
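
A minimal sketch (illustrative only, not part of the patch) of the
backward-compatible import paths after this refactor; the module paths
below are taken from the diff:

    # the old entry points in paddle.distributed.collective keep working ...
    from paddle.distributed import collective
    # ... and are now thin re-exports of the implementations moved to
    # fleet.layers.mpu.mp_ops and communication.comm_utils
    from paddle.distributed.fleet.layers.mpu import mp_ops
    from paddle.distributed.communication.comm_utils import ReduceOp

    assert collective.split is mp_ops.split
    assert collective.ReduceOp is ReduceOp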
---
 python/paddle/distributed/collective.py       | 790 +-----------------
 .../distributed/communication/comm_utils.py   |  50 ++
 .../distributed/fleet/layers/mpu/__init__.py  |  24 +
 .../distributed/fleet/layers/mpu/mp_layers.py | 466 +++++++++++
 .../distributed/fleet/layers/mpu/mp_ops.py    | 772 +++++++++++++++++
 .../distributed/fleet/layers/mpu/random.py    | 243 ++++++
 .../parallel_layers/mp_layers.py              | 299 +------
 .../meta_parallel/parallel_layers/random.py   | 234 +-----
 python/setup.py.in                            |   2 +
 9 files changed, 1581 insertions(+), 1299 deletions(-)
 create mode 100644 python/paddle/distributed/communication/comm_utils.py
 create mode 100644 python/paddle/distributed/fleet/layers/mpu/__init__.py
 create mode 100644 python/paddle/distributed/fleet/layers/mpu/mp_layers.py
 create mode 100644 python/paddle/distributed/fleet/layers/mpu/mp_ops.py
 create mode 100644 python/paddle/distributed/fleet/layers/mpu/random.py

diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py
index 57731b8ad0e..5ceb046f550 100644
--- a/python/paddle/distributed/collective.py
+++ b/python/paddle/distributed/collective.py
@@ -40,46 +40,23 @@ import paddle.fluid.core as core
 from paddle import _C_ops, _legacy_C_ops
 import paddle.fluid.dygraph_utils as dygraph_utils
 import contextlib
+from .fleet.layers.mpu.mp_ops import split
+from .fleet.layers.mpu.mp_ops import _c_identity
+from .fleet.layers.mpu.mp_ops import _c_concat
+from .fleet.layers.mpu.mp_ops import _c_split
+from .fleet.layers.mpu.mp_ops import _mp_allreduce
+from .fleet.layers.mpu.mp_ops import _c_lookup_table
+from .fleet.layers.mpu.mp_ops import _Linear
+from .fleet.layers.mpu.mp_ops import _set_var_distributed
+from .fleet.layers.mpu.mp_ops import _c_softmax_with_cross_entropy
+from .fleet.layers.mpu.mp_ops import _linear
+from .fleet.layers.mpu.mp_ops import _parallel_linear
+from .fleet.layers.mpu.mp_ops import _parallel_embedding
+from .communication.comm_utils import ReduceOp
 
 __all__ = []
 
 
-class ReduceOp:
-    """
-    Specify the type of operation used for element-wise reductions.
-    It should be one of the following values:
-
-        ReduceOp.SUM
-
-        ReduceOp.MAX
-
-        ReduceOp.MIN
-
-        ReduceOp.PROD
-
-    Examples:
-        .. code-block:: python
-
-            # required: distributed
-            import paddle
-            import paddle.distributed as dist
-
-            dist.init_parallel_env()
-            if dist.get_rank() == 0:
-                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
-            else:
-                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
-            dist.all_reduce(data, op=dist.ReduceOp.SUM)
-            print(data)
-            # [[5, 7, 9], [5, 7, 9]] (2 GPUs)
-    """
-    SUM = 0
-    MAX = 1
-    MIN = 2
-    PROD = 3
-    AVG = 4
-
-
 class Group():
     """
     The abstract representation of group.
@@ -1259,747 +1236,6 @@ def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
                      })
 
 
-def _c_identity(tensor, group=None):
-    """
-    Return a copy of the tensor, mainly used with model parallel.
-
-    Args:
-        tensor (Tensor): The input Tensor. Its data type
-            should be float16, float32, float64, int32 or int64.
-        group (int): The id of the process group to work on.
-
-    Returns:
-        Tensor.
-    """
-    if group is not None and not group.is_member():
-        return
-    ring_id = 0 if group is None else group.id
-
-    if _non_static_mode():
-        return _legacy_C_ops.c_identity(tensor, 'use_calc_stream', True,
-                                        'ring_id', ring_id,
-                                        'use_model_parallel', True)
-    op_type = 'c_identity'
-    helper = LayerHelper(op_type, **locals())
-    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
-
-    check_variable_and_dtype(
-        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
-        '_c_identity')
-
-    helper.append_op(type=op_type,
-                     inputs={'X': tensor},
-                     outputs={'Out': out},
-                     attrs={
-                         'ring_id': ring_id,
-                         'use_calc_stream': True,
-                         'use_model_parallel': True,
-                     })
-    return out
-
-
-def _c_concat(tensor, group=None):
-    """
-    Return allgather of the tensor, mainly used with model parallel.
-
-    Args:
-        tensor (Tensor): The input Tensor. Its data type
-            should be float16, float32, float64, int32 or int64.
-        group (int): The id of the process group to work on.
-
-    Returns:
-        Tensor.
-    """
-    if group is not None and not group.is_member():
-        return
-    group = _get_default_group() if group is None else group
-    ring_id = group.id
-
-    global_rank = _get_global_env().rank
-    rank = group.rank
-    nranks = group.nranks
-
-    if _non_static_mode():
-        return _legacy_C_ops.c_concat(tensor, 'ring_id', ring_id,
-                                      'use_calc_stream', True, 'rank', rank,
-                                      'nranks', nranks, 'use_model_parallel',
-                                      True)
-
-    op_type = 'c_concat'
-    helper = LayerHelper(op_type, **locals())
-    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
-
-    check_variable_and_dtype(
-        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
-        '_c_concat')
-
-    helper.append_op(type=op_type,
-                     inputs={'X': tensor},
-                     outputs={'Out': out},
-                     attrs={
-                         'ring_id': ring_id,
-                         'use_calc_stream': True,
-                         'use_model_parallel': True,
-                         'nranks': nranks,
-                         'rank': rank
-                     })
-    return out
-
-
-def _c_split(tensor, group=None):
-    """
-    Split tensor evenly among all members, mainly used with model parallel.
-
-    Args:
-        tensor (Tensor): The input Tensor. Its data type
-            should be float16, float32, float64, int32 or int64.
-        rank (int): The rank of the current process.
-        group (int): The id of the process group to work on.
-
-    Returns:
-        Tensor.
-    """
-    if group is not None and not group.is_member():
-        return
-    ring_id = 0 if group is None else group.id
-
-    global_rank = _get_global_env().rank
-    rank = global_rank if group is None else group.get_group_rank(global_rank)
-    nranks = _get_global_env().world_size if group is None else group.nranks
-
-    if _non_static_mode():
-        return _legacy_C_ops.c_split(tensor, 'use_calc_stream', True, 'ring_id',
-                                     ring_id, 'rank', rank, 'nranks', nranks,
-                                     'use_model_parallel', True)
-
-    op_type = 'c_split'
-    helper = LayerHelper(op_type, **locals())
-    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
-
-    check_variable_and_dtype(
-        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
-        '_c_split')
-
-    helper.append_op(type=op_type,
-                     inputs={'X': tensor},
-                     outputs={'Out': out},
-                     attrs={
-                         'ring_id': ring_id,
-                         'use_calc_stream': True,
-                         'rank': rank,
-                         'nranks': nranks,
-                         'use_model_parallel': True,
-                     })
-    return out
-
-
-def _mp_allreduce(tensor,
-                  op=ReduceOp.SUM,
-                  group=None,
-                  use_calc_stream=True,
-                  use_model_parallel=True):
-    """[it is same as allreduce above, but it supports model parallel. And it support inplace startegy]
-    """
-    if group is not None and not group.is_member():
-        return
-
-    if in_dygraph_mode():
-        group = _get_default_group() if group is None else group
-        assert op == ReduceOp.SUM, "Unknown parameter: {}.".format(op)
-
-        from paddle.autograd import PyLayer
-
-        class mp_allreduce_eager(PyLayer):
-
-            @staticmethod
-            def forward(ctx, tensor, group, use_calc_stream,
-                        use_model_parallel):
-                ctx.ring_id = group.id
-
-                if use_calc_stream:
-                    op_type = _get_reduce_op(op, "_mp_allreduce")
-                    group.process_group.allreduce_on_calc_stream(
-                        tensor, op_type)
-                    return tensor
-                else:
-                    return _legacy_C_ops.c_allreduce_sum_(
-                        tensor, 'use_calc_stream', use_calc_stream, 'ring_id',
-                        ring_id, "use_model_parallel", use_model_parallel)
-
-            @staticmethod
-            def backward(ctx, dy):
-                return _legacy_C_ops.c_identity(dy, 'use_calc_stream', True,
-                                                'ring_id', ctx.ring_id,
-                                                'use_model_parallel', True)
-
-        return mp_allreduce_eager.apply(tensor, group, use_calc_stream,
-                                        use_model_parallel)
-
-    ring_id = 0 if group is None else group.id
-    if _in_legacy_dygraph():
-        if op == ReduceOp.SUM:
-            return _legacy_C_ops.c_allreduce_sum_(tensor, 'use_calc_stream',
-                                                  use_calc_stream, 'ring_id',
-                                                  ring_id, "use_model_parallel",
-                                                  use_model_parallel)
-        else:
-            raise ValueError("Unknown parameter: {}.".format(op))
-
-    op_type = 'c_allreduce_sum'
-    helper = LayerHelper(op_type, **locals())
-    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
-
-    check_variable_and_dtype(
-        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
-        op_type)
-
-    helper.append_op(type=op_type,
-                     inputs={'X': tensor},
-                     outputs={'Out': out},
-                     attrs={
-                         'ring_id': ring_id,
-                         'use_calc_stream': use_calc_stream,
-                         'use_model_parallel': use_model_parallel,
-                     })
-    return out
-
-
-def _c_lookup_table(table, index, start_index=0, name=None):
-    """
-    Lookup table according to index.
-
-    Args:
-        table (Tensor): The input Tensor. Its data type
-            should be float16, float32, float64.
-        index (Tensor): The index to lookup table.
-        start_index (int): The initial index for table range.
-        name (string): The name of the api
-
-    Returns:
-        Tensor.
-    """
-    if _non_static_mode():
-        return _legacy_C_ops.c_embedding(table, index, "start_index",
-                                         start_index)
-
-    op_type = 'c_embedding'
-    helper = LayerHelper(op_type, **locals())
-    dtype = helper.input_dtype(input_param_name='table')
-    check_variable_and_dtype(index, 'input', ['int32', 'int64'], op_type)
-    tmp = helper.create_variable_for_type_inference(dtype)
-    helper.append_op(type='c_embedding',
-                     inputs={
-                         'Ids': index,
-                         'W': table
-                     },
-                     outputs={'Out': tmp},
-                     attrs={"start_index": start_index})
-    return tmp
-
-
-class _Linear(layers.Layer):
-    """
-    Linear
-    """
-
-    def __init__(self,
-                 in_features,
-                 out_features,
-                 weight_attr=None,
-                 bias_attr=None,
-                 name=None):
-        super(_Linear, self).__init__()
-        self._dtype = self._helper.get_default_dtype()
-        self._weight_attr = weight_attr
-        self._bias_attr = bias_attr
-        self.weight = self.create_parameter(shape=[in_features, out_features],
-                                            attr=self._weight_attr,
-                                            dtype=self._dtype,
-                                            is_bias=False)
-        self.bias = self.create_parameter(shape=[out_features],
-                                          attr=self._bias_attr,
-                                          dtype=self._dtype,
-                                          is_bias=True)
-        self.name = name
-
-    def forward(self, input):
-        out = _linear(x=input,
-                      weight=self.weight,
-                      bias=self.bias,
-                      name=self.name)
-        return out
-
-    def extra_repr(self):
-        name_str = ', name={}'.format(self.name) if self.name else ''
-        return 'in_features={}, out_features={}, dtype={}{}'.format(
-            self.weight.shape[0], self.weight.shape[1], self._dtype, name_str)
-
-
-def _c_softmax_with_cross_entropy(logits,
-                                  label,
-                                  group=None,
-                                  return_softmax=False):
-    if group is not None and not group.is_member():
-        return
-    ring_id = 0 if group is None else group.id
-    global_rank = _get_global_env().rank
-    rank = global_rank if group is None else group.get_group_rank(global_rank)
-    nranks = _get_global_env().world_size if group is None else group.nranks
-
-    input_dims = len(list(logits.shape))
-    label_dims = len(list(label.shape))
-    if input_dims - 1 != label_dims and input_dims != label_dims:
-        raise ValueError(
-            'Expected nput_dims - 1 = label_dims or input_dims == label_dims\
-             (got nput_dims{}, label_dims{})'.format(input_dims, label_dims))
-    if input_dims - 1 == label_dims:
-        label = paddle.unsqueeze(label, axis=-1)
-
-    if _non_static_mode():
-        softmax, loss = _legacy_C_ops.c_softmax_with_cross_entropy(
-            logits, label, 'ring_id', ring_id, 'rank', rank, 'nranks', nranks)
-        if not return_softmax:
-            return loss
-        else:
-            return loss, softmax
-
-    attrs = {
-        'ring_id': ring_id,
-        'rank': rank,
-        'nranks': nranks,
-    }
-    helper = LayerHelper('c_softmax_with_cross_entropy', **locals())
-    softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
-    loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
-    helper.append_op(type='c_softmax_with_cross_entropy',
-                     inputs={
-                         'Logits': logits,
-                         'Label': label
-                     },
-                     outputs={
-                         'Softmax': softmax,
-                         'Loss': loss
-                     },
-                     attrs=attrs)
-
-    if return_softmax:
-        return loss, softmax
-
-    return loss
-
-
-def _linear(x, weight, bias=None, name=None):
-    """
-    Fuction Linear
-    """
-    if _non_static_mode():
-        pre_bias = _varbase_creator(dtype=x.dtype)
-        _legacy_C_ops.matmul(x, weight, pre_bias, 'transpose_X', False,
-                             'transpose_Y', False, "alpha", 1)
-        return dygraph_utils._append_bias_in_dygraph(pre_bias,
-                                                     bias,
-                                                     axis=len(x.shape) - 1)
-    else:
-        helper = LayerHelper('linear', **locals())
-        dtype = x.dtype
-        assert len(
-            x.shape) < 4, "X latitude is not supported greater than 3 now."
-
-        check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
-                                 'linear')
-        check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear')
-
-        inputs = {'X': [x], 'Y': [weight]}
-        attrs = {
-            'transpose_X': False,
-            'transpose_Y': False,
-            'alpha': 1,
-        }
-        tmp = helper.create_variable_for_type_inference(dtype)
-        helper.append_op(type='matmul_v2',
-                         inputs=inputs,
-                         outputs={'Out': tmp},
-                         attrs=attrs)
-        if bias is not None:
-            res = helper.create_variable_for_type_inference(dtype)
-            helper.append_op(type='elementwise_add',
-                             inputs={
-                                 'X': [tmp],
-                                 'Y': [bias]
-                             },
-                             outputs={'Out': [res]},
-                             attrs={'axis': len(x.shape) - 1})
-        else:
-            res = tmp
-        return res
-
-
-def _set_var_distributed(var):
-    if var is None:
-        return
-
-    var.is_distributed = True
-
-    # NOTE: use current_block and find_var_recursive to support while_loop
-    startup_block = paddle.static.default_startup_program().current_block()
-    main_block = paddle.static.default_main_program().current_block()
-    startup_block._find_var_recursive(var.name).is_distributed = True
-    main_block._find_var_recursive(var.name).is_distributed = True
-
-
-def _parallel_linear(x,
-                     num_rows,
-                     num_cols,
-                     axis,
-                     param_attr,
-                     bias_attr,
-                     gather_out,
-                     inner_rank,
-                     nranks,
-                     split_tensor,
-                     name,
-                     group=None):
-    """
-    Parallel Linear
-
-    axis the dimension of the parameter of linear layer.
-    axis = 0: the row dimension
-    axis = 1: the col dimension
-
-    """
-    if group is not None and not group.is_member():
-        return
-    ring_id = 0 if group is None else group.id
-
-    if axis == 0:
-        if split_tensor:
-            x = _c_split(x, group=group)
-    else:
-        x = _c_identity(x, group=group)
-
-    linear = paddle.nn.Linear(num_rows,
-                              num_cols,
-                              weight_attr=param_attr,
-                              bias_attr=bias_attr,
-                              name=name)
-
-    # NOTE: npu linear function use matmul_v2 but linear use matmul
-    linear_function = _linear if core.is_compiled_with_npu()\
-        else paddle.nn.functional.linear
-    linear_out = linear_function(
-        x,
-        linear.weight,
-        # NOTE(wangxi): row split, bias need add after allreduce
-        None if axis == 0 else linear.bias,
-        linear.name)
-
-    _set_var_distributed(linear.weight)
-    # set is_distributed for splited bias
-    # if a linear layer is splited by row, each rank would hold a complete bias and they should be the same in each rank.
-    # if a linear layer is splited by col, the bias would also be split into each rank as its weight
-    if axis == 1 and linear._bias_attr != False:
-        _set_var_distributed(linear.bias)
-
-    if not gather_out: return linear_out
-
-    out_shape = list(linear_out.shape)
-    out_shape[0] *= 1 if axis == 0 else nranks
-    main_block = paddle.static.default_main_program().current_block()
-    out = main_block.create_var(
-        shape=out_shape,
-        dtype=linear_out.dtype,
-        type=linear_out.type,
-        lod_level=linear_out.lod_level,
-        persistable=False,
-        is_data=False,
-        need_check_feed=linear_out.desc.need_check_feed())
-    if axis == 0:
-        main_block.append_op(type='c_allreduce_sum',
-                             inputs={'X': linear_out},
-                             outputs={'Out': out},
-                             attrs={
-                                 'ring_id': ring_id,
-                                 'use_calc_stream': True,
-                                 'use_model_parallel': True
-                             })
-        if linear.bias is not None:
-            out = out + linear.bias
-    else:
-        main_block.append_op(type='c_concat',
-                             inputs={'X': linear_out},
-                             outputs={'Out': out},
-                             attrs={
-                                 'rank': inner_rank,
-                                 'ring_id': ring_id,
-                                 'nranks': nranks,
-                                 'use_calc_stream': True,
-                                 'use_model_parallel': True
-                             })
-    return out
-
-
-def _parallel_embedding(x,
-                        per_part_embeddings,
-                        origin_size,
-                        param_attr,
-                        inner_rank,
-                        num_partitions,
-                        name,
-                        group=None):
-    """
-    Parallel Embedding
-    """
-    if group is not None and not group.is_member():
-        return
-    ring_id = 0 if group is None else group.id
-
-    helper = LayerHelper("_parallel_embedding", **locals())
-
-    per_part_size = per_part_embeddings
-    rank = inner_rank
-
-    vocab_start_index = rank * per_part_size
-    dtype = helper.get_default_dtype()
-    size = [per_part_size, origin_size[1]]
-
-    weight = helper.create_parameter(attr=param_attr,
-                                     shape=size,
-                                     dtype=dtype,
-                                     is_bias=False)
-
-    if num_partitions == 1:
-        return paddle.nn.functional.embedding(x,
-                                              weight=weight,
-                                              padding_idx=None,
-                                              sparse=False,
-                                              name=name)
-
-    startup_block = paddle.static.default_startup_program().global_block()
-    main_block = paddle.static.default_main_program().global_block()
-    startup_block.vars[weight.name].is_distributed = True
-    main_block.vars[weight.name].is_distributed = True
-
-    output_parallel = paddle.distributed.collective._c_lookup_table(
-        weight, x, start_index=vocab_start_index, name=name)
-    out = paddle.distributed.collective._mp_allreduce(output_parallel,
-                                                      group=group,
-                                                      use_calc_stream=True,
-                                                      use_model_parallel=True)
-    return out
-
-
-def split(x,
-          size,
-          operation,
-          axis=0,
-          num_partitions=1,
-          gather_out=True,
-          weight_attr=None,
-          bias_attr=None,
-          name=None):
-    """
-
-    Split the weight of the specified operation into multiple devices
-    and do the computation in parallel.
-
-    Now the following three cases are supported.
-
-    Case 1: Parallel Embedding
-        The weight of the embedding operation is a NxM matrix with N rows and M columns.
-        With parallel embedding, the weight is split into num_partitions partitions, each
-        of which is a matrix with (N/num_partitions + 1) rows and M column where the last
-        row as the padding idx.
-
-        Suppose we split the NxM weight into two partitons on device_0 and device_1
-        respectively. Then, one each device, the final weight has (N/2 + 1) rows with the
-        index range from 0 to N/2. On device_0, all values in the input within [0, N/2 -1]
-        keep unchanged and all other values are changed to N/2 which is the padding index and
-        are mapped to all zeros after embedding. In the same way, on device_1, the value V in the
-        input within [N/2, N-1] will be changed to (V - N/2), and all other values are changed
-        to N/2 and are mapped to all zeros after embedding. Finally, the results on the two
-        devices are sum-reduced.
-
-        The Embedding put on single card is as shown below:
-
-        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_embedding_single.png
-            :width: 800
-            :height: 350
-            :alt: single_embedding
-            :align: center
-
-        Parallel Embedding is shown as below:
-
-        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_embedding_split.png
-            :width: 800
-            :alt: split_embedding
-            :align: center
-
-    Case 2: Row Parallel Linear
-        The weight of the linear operation is a NxM matrix with N rows and M columns.
-        With row parallel linear, the weight is split into num_partitions partitions, each
-        of which is a matrix with N/num_partitions rows and M column.
-
-        The linear layer put on single card is shown as below, the input variable is represented by X,
-        the weight matrix is represented by W and the output vaiable is O. The linear layer on single card is
-        simple matrix multiplication operation, O = X * W.
-
-        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_single.png
-            :width: 800
-            :alt: single_linear
-            :align: center
-
-        Row Parallel Linear is shown as below. As the name suggests, Row Parallel Linear splits the weight matrix W into
-        [[W_row1], [W_row2]] along the row. And accordingly the input is splitted along the column into [X_col1, X_col2] and multiply their
-        respective weight matrices. Finally apply AllReduce on the output from each card to get the final output.
-
-        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_row.png
-            :width: 800
-            :alt: split_row
-            :align: center
-
-    Case 3: Column Parallel Linear
-        The weight of the linear operation is a NxM matrix with N rows and M columns.
-        With column parallel linear, the weight is split into num_paratitions partitions, each
-        of which is a matrix with N rows and M/num_partitions column.
-
-        The linear layer put on single card has been illustrated on case 2 and Column Parallel Linear
-        is shown as below. The Column Parallel Linear splits the weight matrix W into [W_col1, W_col2] along the column and
-        these splitted matrices respectively multiply the input. Finally apply AllGather on the output from each card to get the final output.
-
-        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_col.png
-            :width: 800
-            :alt: split_col
-            :align: center
-
-    As observed, the column parallel linear and row parallel linear can be combined to skip one ALLGATHER communication
-    operator. Furthermore the Attention and MLP can be combined to imporve the performance as shown below.
-
-    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_col_row.png
-            :width: 800
-            :alt: split_col_row
-            :align: center
-
-    Args:
-        x (Tensor): Input tensor. It's data type should be float16, float32, float64, int32 or int64.
-        size (list|tuple): A list or tuple with two elements indicating the shape of the weight.
-        operation (str): The name of the operation. The supported operations are 'linear' and 'embedding'.
-        axis (int, Optional): Indicate along which axis to split the weight. Default: 0.
-        num_partitions (int, Optional): How many parts the weight is partitioned. Default: 1.
-        gather_out (bool, Optional): Whether to gather the output after computation. By default, the output
-            on each partitions will be gathered after computation. Default: True.
-        weight_attr (ParamAttr, Optional): The parameter attribute for the learnable
-            weights(Parameter) of the specified operation. Default: None.
-        bias_attr (ParamAttr, Optional): The parameter attribute for the bias
-            of the specified operation. Default: None.
-        name (str, Optional): The default value is None. Normally there is no need for user to set this
-            property. Default: None. For more information, please refer to :ref:`api_guide_Name`.
-
-    Returns:
-        Tensor.
-
-    Examples:
-        .. code-block:: python
-
-            # required: distributed
-            import paddle
-            import paddle.distributed.fleet as fleet
-
-            paddle.enable_static()
-            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
-            fleet.init(is_collective=True)
-            data = paddle.randint(0, 8, shape=[10,4])
-            emb_out = paddle.distributed.split(
-                data,
-                (8, 8),
-                operation="embedding",
-                num_partitions=2)
-
-    """
-    assert isinstance(
-        size,
-        (list, tuple)), ("The type of size for "
-                         "paddle.distributed.split must be list or tuple.")
-    assert len(size) == 2, ("Number of elements in size of "
-                            "paddle.distributed.split must be two.")
-    assert isinstance(operation, str), ("The type of operation for "
-                                        "paddle.distributed.split must be str.")
-    supported_operations = [
-        'linear',
-        'embedding',
-    ]
-    assert operation in supported_operations, (
-        "The operation for "
-        "paddle.distributed.split must be one of {}.".format(
-            supported_operations))
-    if _non_static_mode():
-        raise ValueError(
-            "paddle.distributed.split cannot be used in dynamic "
-            "graph mode, plese use ParallelEmbedding, ParallelRowLinear, "
-            "ParallelColumnLinear instead.")
-    else:
-        from .fleet import fleet
-        assert fleet._role_maker, ("To use paddle.distributed.split, "
-                                   "you must call fleet.init() firstly.")
-        rank = fleet.worker_index()
-        nranks = fleet.worker_num()
-
-    # rank within a model parallel group
-    inner_rank = rank % num_partitions
-
-    if operation == "embedding":
-        assert axis == 0, ("We only support to split the weight of embedding "
-                           "along the first axis now.")
-        assert size[0] % num_partitions == 0, \
-            "The length of the vocabulary must be divisible by num_partitions " \
-            "but received vocabulary={} num_partitions={}".format(size[0], num_partitions)
-
-        per_part_size = size[0] // num_partitions
-        emb_out = _parallel_embedding(x,
-                                      per_part_size,
-                                      size,
-                                      weight_attr,
-                                      inner_rank,
-                                      num_partitions,
-                                      name,
-                                      group=None)
-        return emb_out
-    else:
-        should_split = False
-        if axis == 0:
-            assert size[0] % num_partitions == 0, (
-                "Number of rows of the weight for linear ({}) must be"
-                " divisible by num_partitions ({})".format(
-                    size[0], num_partitions))
-            per_part_size = size[0] // num_partitions
-            linear_size = (per_part_size, size[1])
-            if x.shape[-1] == size[0]: should_split = True
-
-        elif axis == 1:
-            assert size[1] % num_partitions == 0, (
-                "Number of column of the weight for linear ({}) must be"
-                " divisible by num_partitions ({})".format(
-                    size[1], num_partitions))
-            per_part_size = size[1] // num_partitions
-            linear_size = (size[0], per_part_size)
-        else:
-            raise ValueError("The value of axis must be 0 or 1, but the value "
-                             "given is {}.".format(axis))
-
-        linear_out = _parallel_linear(x,
-                                      linear_size[0],
-                                      linear_size[1],
-                                      axis,
-                                      weight_attr,
-                                      bias_attr,
-                                      gather_out,
-                                      inner_rank,
-                                      num_partitions,
-                                      should_split,
-                                      name=name,
-                                      group=None)
-        return linear_out
-
-
 def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True):
     """
     Scatter tensors in in_tensor_list to all participators averagely and gather the result tensors in out_tensor_list.
diff --git a/python/paddle/distributed/communication/comm_utils.py b/python/paddle/distributed/communication/comm_utils.py
new file mode 100644
index 00000000000..62e1bcb4cca
--- /dev/null
+++ b/python/paddle/distributed/communication/comm_utils.py
@@ -0,0 +1,50 @@
+#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class ReduceOp:
+    """
+
+    Specify the type of operation used for element-wise reductions.
+    It should be one of the following values:
+
+        ReduceOp.SUM
+
+        ReduceOp.MAX
+
+        ReduceOp.MIN
+
+        ReduceOp.PROD
+
+    Examples:
+        .. code-block:: python
+
+            # required: distributed
+            import paddle
+            import paddle.distributed as dist
+
+            dist.init_parallel_env()
+            if dist.get_rank() == 0:
+                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
+            else:
+                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
+            dist.all_reduce(data, op=dist.ReduceOp.SUM)
+            print(data)
+            # [[5, 7, 9], [5, 7, 9]] (2 GPUs)
+    """
+    SUM = 0
+    MAX = 1
+    MIN = 2
+    PROD = 3
+    AVG = 4
diff --git a/python/paddle/distributed/fleet/layers/mpu/__init__.py b/python/paddle/distributed/fleet/layers/mpu/__init__.py
new file mode 100644
index 00000000000..11b69702650
--- /dev/null
+++ b/python/paddle/distributed/fleet/layers/mpu/__init__.py
@@ -0,0 +1,24 @@
+#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .mp_layers import VocabParallelEmbedding
+from .mp_layers import ColumnParallelLinear
+from .mp_layers import RowParallelLinear
+from .mp_layers import ParallelCrossEntropy
+
+from .random import RNGStatesTracker
+from .random import get_rng_state_tracker
+from .random import model_parallel_random_seed
+from .random import determinate_seed
+from .random import dropout
diff --git a/python/paddle/distributed/fleet/layers/mpu/mp_layers.py b/python/paddle/distributed/fleet/layers/mpu/mp_layers.py
new file mode 100644
index 00000000000..2ba9ce9ed76
--- /dev/null
+++ b/python/paddle/distributed/fleet/layers/mpu/mp_layers.py
@@ -0,0 +1,466 @@
+#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from . import mp_ops
+from paddle.fluid import core
+from paddle.fluid.dygraph.layers import Layer
+from .random import get_rng_state_tracker
+from paddle.nn import functional as F
+from paddle import framework
+from paddle.autograd import PyLayer
+from ...base import topology as tp
+
+__all__ = []
+
+# This file is implemented following the paper:
+# Shoeybi M, Patwary M, Puri R, et al. Megatron-lm: Training multi-billion parameter
+# language models using model parallelism[J]. arXiv preprint arXiv:1909.08053, 2019. (https://arxiv.org/abs/1909.08053)
+
+
+def is_fused_matmul_bias_supported():
+    if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm():
+        return hasattr(core.ops, 'fused_gemm_epilogue')
+    else:
+        return False
+
+
+class VocabParallelEmbedding(Layer):
+    """Embedding mp parallelized in the vocabulary dimension.
+    this class is used for splitting embedding in mp group.
+
+    Args:
+        num_embeddings(int): The size of the dictionary of embeddings.
+        embedding_dim(int): The size of each embedding vector.
+        weight_attr(ParamAttr|None): To specify the weight parameter property. Default: None, which means the
+            default weight parameter property is used. See usage for details in :ref:`api_ParamAttr` . In addition,
+            user-defined or pre-trained word vectors can be loaded with the :attr:`weight_attr` parameter.
+            The local word vector needs to be transformed into numpy format, and the shape of local word
+            vector should be consistent with :attr:`num_embeddings` . Then :ref:`api_initializer_NumpyArrayInitializer`
+            is used to load custom or pre-trained word vectors. See code example for details.
+        mp_group(Group): The tensor parallel group.
+        name(str, optional): For detailed information, please refer
+               to :ref:`api_guide_Name`. Usually name is no need to set and
+               None by default.
+
+    Examples:
+        .. code-block:: python
+        import paddle
+        from paddle.distributed import fleet
+
+        class SimpleMPNet(paddle.nn.Layer):
+           def __init__(self, vocab_size, hidden_size, inner_size, output_size):
+              super(SimpleMPNet, self).__init__()
+              self.linear1 = fleet.meta_parallel.ColumnParallelLinear(
+                    hidden_size,
+                    inner_size,
+                    gather_output=False,
+                    has_bias=True)
+
+              self.linear2 = fleet.meta_parallel.RowParallelLinear(
+                    inner_size,
+                    hidden_size,
+                    input_is_parallel=True,
+                    has_bias=True)
+
+              self.linear3 = paddle.nn.Linear(hidden_size, output_size)
+
+              self.embedding = fleet.meta_parallel.VocabParallelEmbedding(
+                                vocab_size,
+                                hidden_size)
+
+           def forward(self, x):
+              x = self.embedding(x)
+              x = self.linear1(x)
+              x = self.linear2(x)
+              x = self.linear3(x)
+              return x
+    """
+
+    def __init__(self,
+                 num_embeddings,
+                 embedding_dim,
+                 weight_attr=None,
+                 mp_group=None,
+                 name=None):
+        super(VocabParallelEmbedding, self).__init__()
+
+        self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
+        ) if mp_group is None else mp_group
+        self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
+        ) if mp_group is None else mp_group.nranks
+        self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank(
+        ) if mp_group is None else mp_group.rank
+
+        self.origin_num_embeddings = num_embeddings
+        self.is_mp = (self.world_size > 1)
+
+        assert num_embeddings % self.world_size == 0, (
+            "The length of the vocabulary must be divisible by the parallelism degree of MP"
+        )
+
+        per_part_size = num_embeddings // self.world_size
+
+        self.vocab_start_index = self.rank * per_part_size
+        self._dtype = self._helper.get_default_dtype()
+        self._size = [per_part_size, embedding_dim]
+        self._weight_attr = weight_attr
+        self._name = name
+
+        if self.is_mp and paddle.in_dynamic_mode():
+            with get_rng_state_tracker().rng_state():
+                self.weight = self.create_parameter(attr=self._weight_attr,
+                                                    shape=self._size,
+                                                    dtype=self._dtype,
+                                                    is_bias=False)
+        else:
+            self.weight = self.create_parameter(attr=self._weight_attr,
+                                                shape=self._size,
+                                                dtype=self._dtype,
+                                                is_bias=False)
+
+        self.weight.is_distributed = True if self.is_mp else False
+
+    def forward(self, x):
+        if self.is_mp:
+            output_parallel = mp_ops._c_lookup_table(
+                self.weight,
+                x,
+                start_index=self.vocab_start_index,
+                name=self._name)
+            output = mp_ops._mp_allreduce(output_parallel,
+                                          group=self.model_parallel_group,
+                                          use_calc_stream=True,
+                                          use_model_parallel=True)
+        else:
+            output = F.embedding(x,
+                                 weight=self.weight,
+                                 padding_idx=None,
+                                 sparse=False,
+                                 name=self._name)
+        return output
+
+
+class ColumnParallelLinear(Layer):
+    """Linear layer with mp parallelized(column).
+    this class is used for splitting Linear Layer in mp group, column split the weight of the Linear layer.
+
+    Args:
+        in_features(int): The number of input units.
+        out_features(int): The number of output units.
+        weight_attr(ParamAttr|None): The attribute for the learnable weight of this layer. The default value is None,
+            which means the default weight parameter property is used. For detailed information, please refer to paddle.ParamAttr.
+        has_bias(bool): whether to add bias.
+        gather_output(bool): whether to allgather the output of each rank.
+        fuse_matmul_bias(bool): whether to fuse matmul and bias.
+        mp_group(Group): The tensor parallel group.
+        name(str, optional): Normally there is no need for user to set this parameter.
+            For detailed information, please refer to :ref:`api_guide_Name` .
+
+    Examples:
+        .. code-block:: python
+        import paddle
+        from paddle.distributed import fleet
+
+        class SimpleMPNet(paddle.nn.Layer):
+           def __init__(self, vocab_size, hidden_size, inner_size, output_size):
+              super(SimpleMPNet, self).__init__()
+              self.linear1 = fleet.meta_parallel.ColumnParallelLinear(
+                    hidden_size,
+                    inner_size,
+                    gather_output=False,
+                    has_bias=True)
+
+              self.linear2 = fleet.meta_parallel.RowParallelLinear(
+                    inner_size,
+                    hidden_size,
+                    input_is_parallel=True,
+                    has_bias=True)
+
+              self.linear3 = paddle.nn.Linear(hidden_size, output_size)
+
+              self.embedding = fleet.meta_parallel.VocabParallelEmbedding(
+                                vocab_size,
+                                hidden_size)
+
+           def forward(self, x):
+              x = self.embedding(x)
+              x = self.linear1(x)
+              x = self.linear2(x)
+              x = self.linear3(x)
+              return x
+    """
+
+    def __init__(self,
+                 in_features,
+                 out_features,
+                 weight_attr=None,
+                 has_bias=None,
+                 gather_output=True,
+                 fuse_matmul_bias=False,
+                 mp_group=None,
+                 name=None):
+        super(ColumnParallelLinear, self).__init__()
+
+        self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
+        ) if mp_group is None else mp_group
+        self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
+        ) if mp_group is None else mp_group.nranks
+        self._name = name
+        self.is_mp = (self.world_size > 1)
+
+        self.gather_output = gather_output
+        assert out_features % self.world_size == 0, (
+            "Number of column of the weight for linear ({}) must be"
+            " divisible by model parallel size ({})".format(
+                out_features, self.world_size))
+        self.output_size_per_partition = out_features // self.world_size
+
+        self._weight_attr = weight_attr
+        self._dtype = self._helper.get_default_dtype()
+
+        if self.is_mp and paddle.in_dynamic_mode():
+            with get_rng_state_tracker().rng_state():
+                self.weight = self.create_parameter(
+                    shape=[in_features, self.output_size_per_partition],
+                    attr=self._weight_attr,
+                    dtype=self._dtype,
+                    is_bias=False)
+        else:
+            self.weight = self.create_parameter(
+                shape=[in_features, self.output_size_per_partition],
+                attr=self._weight_attr,
+                dtype=self._dtype,
+                is_bias=False)
+
+        self.weight.is_distributed = True if self.is_mp else False
+
+        if has_bias:
+            # initialize bias to zero like Megatron
+            self.bias = self.create_parameter(
+                shape=[self.output_size_per_partition],
+                attr=paddle.nn.initializer.Constant(value=0.0),
+                dtype=self._dtype,
+                is_bias=True)
+            self.bias.is_distributed = True if self.is_mp else False
+        else:
+            self.bias = None
+
+        self.linear = F.linear
+
+        if fuse_matmul_bias:
+            if not is_fused_matmul_bias_supported():
+                raise NotImplementedError(
+                    "You set fuse_matmul_bias=True in ColumnParallelLinear, "
+                    "however, the paddle you are using not support this operation. "
+                    "Please set fuse_matmul_bias=False or use paddle compiled "
+                    "with cuda 11.6 or higher.")
+            from paddle.incubate.nn.functional import fused_linear
+            self.linear = fused_linear
+
+    def forward(self, x):
+        # use inner api to process identity
+        if self.is_mp:
+            input_parallel = mp_ops._c_identity(x,
+                                                group=self.model_parallel_group)
+        else:
+            input_parallel = x
+
+        output_parallel = self.linear(input_parallel,
+                                      self.weight,
+                                      self.bias,
+                                      name=self._name)
+
+        if self.gather_output and self.is_mp:
+            output = mp_ops._c_concat(output_parallel,
+                                      group=self.model_parallel_group)
+        else:
+            output = output_parallel
+        return output
+
+
+class RowParallelLinear(Layer):
+    """Linear layer with mp parallelized(row).
+    this class is used for splitting Linear Layer in mp group, row split the weight of the Linear layer.
+
+    Args:
+        in_features(int): The number of input units.
+        out_features(int): The number of output units.
+        weight_attr(ParamAttr|None): The attribute for the learnable weight of this layer. The default value is None,
+            which means the default weight parameter property is used. For detailed information, please refer to paddle.ParamAttr.
+        has_bias(bool): whether to add bias.
+        input_is_parallel(bool): whether the input has already been split across the mp group.
+        fuse_matmul_bias(bool): whether to fuse matmul and bias.
+        mp_group(Group): The tensor parallel group.
+        name(str, optional): Normally there is no need for user to set this parameter.
+            For detailed information, please refer to :ref:`api_guide_Name` .
+
+    Examples:
+        .. code-block:: python
+        import paddle
+        from paddle.distributed import fleet
+
+        class SimpleMPNet(paddle.nn.Layer):
+           def __init__(self, vocab_size, hidden_size, inner_size, output_size):
+              super(SimpleMPNet, self).__init__()
+              self.linear1 = fleet.meta_parallel.ColumnParallelLinear(
+                    hidden_size,
+                    inner_size,
+                    gather_output=False,
+                    has_bias=True)
+
+              self.linear2 = fleet.meta_parallel.RowParallelLinear(
+                    inner_size,
+                    hidden_size,
+                    input_is_parallel=True,
+                    has_bias=True)
+
+              self.linear3 = paddle.nn.Linear(hidden_size, output_size)
+
+              self.embedding = fleet.meta_parallel.VocabParallelEmbedding(
+                                vocab_size,
+                                hidden_size)
+
+           def forward(self, x):
+              x = self.embedding(x)
+              x = self.linear1(x)
+              x = self.linear2(x)
+              x = self.linear3(x)
+              return x
+    """
+
+    def __init__(self,
+                 in_features,
+                 out_features,
+                 weight_attr=None,
+                 has_bias=True,
+                 input_is_parallel=False,
+                 fuse_matmul_bias=False,
+                 mp_group=None,
+                 name=None):
+        super(RowParallelLinear, self).__init__()
+
+        self.in_features = in_features
+        self.out_features = out_features
+        self.input_is_parallel = input_is_parallel
+        self._weight_attr = weight_attr
+        self._dtype = self._helper.get_default_dtype()
+        self._name = name
+
+        self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
+        ) if mp_group is None else mp_group
+        self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
+        ) if mp_group is None else mp_group.nranks
+        self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank(
+        ) if mp_group is None else mp_group.rank
+
+        self.is_mp = (self.world_size > 1)
+        assert in_features % self.world_size == 0, (
+            "Number of row of the weight for linear ({}) must be"
+            " divisible by model parallel size ({})".format(
+                in_features, self.world_size))
+
+        self.input_size_per_partition = in_features // self.world_size
+
+        if self.is_mp and paddle.in_dynamic_mode():
+            with get_rng_state_tracker().rng_state():
+                self.weight = self.create_parameter(
+                    shape=[self.input_size_per_partition, self.out_features],
+                    attr=self._weight_attr,
+                    dtype=self._dtype,
+                    is_bias=False)
+        else:
+            self.weight = self.create_parameter(
+                shape=[self.input_size_per_partition, self.out_features],
+                attr=self._weight_attr,
+                dtype=self._dtype,
+                is_bias=False)
+
+        self.weight.is_distributed = True if self.is_mp else False
+
+        if has_bias:
+            self.bias = self.create_parameter(
+                shape=[self.out_features],
+                attr=paddle.nn.initializer.Constant(value=0.0),
+                dtype=self._dtype,
+                is_bias=True)
+        else:
+            self.bias = None
+
+        self.linear = F.linear
+
+        if fuse_matmul_bias:
+            if not is_fused_matmul_bias_supported():
+                raise NotImplementedError(
+                    "You set fuse_matmul_bias=True in RowParallelLinear, "
+                    "however, the paddle you are using not support this operation. "
+                    "Please set fuse_matmul_bias=False or use paddle compiled "
+                    "with cuda 11.6 or higher.")
+            from paddle.incubate.nn.functional import fused_linear
+            self.linear = fused_linear
+
+    def forward(self, x):
+        if self.input_is_parallel or (not self.is_mp):
+            input_parallel = x
+        else:
+            # split last dim
+            input_parallel = mp_ops._c_split(x, group=self.model_parallel_group)
+
+        if self.is_mp:
+            output_parallel = self.linear(input_parallel,
+                                          self.weight,
+                                          name=self._name)
+            output_ = mp_ops._mp_allreduce(output_parallel,
+                                           group=self.model_parallel_group,
+                                           use_calc_stream=True,
+                                           use_model_parallel=True)
+            output = output_ + self.bias if self.bias is not None else output_
+        else:
+            output = self.linear(input_parallel,
+                                 self.weight,
+                                 self.bias,
+                                 name=self._name)
+
+        return output
+
+
+class ParallelCrossEntropy(Layer):
+    """CrossEntropy with mp parallelized.
+    this class is used for splitting softmax cross entropy in mp group.
+
+    Args:
+        mp_group(Group): The tensor parallel group.
+        name(str, optional): Normally there is no need for user to set this parameter.
+            For detailed information, please refer to :ref:`api_guide_Name` .
+
+    Examples:
+        .. code-block:: python
+        loss_func = ParallelCrossEntropy()
+        loss = loss_func(img, label)
+    """
+
+    def __init__(self, mp_group=None, name=None):
+        super(ParallelCrossEntropy, self).__init__()
+        self.name = name
+        self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
+        ) if mp_group is None else mp_group
+        self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
+        ) if mp_group is None else mp_group.nranks
+        self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank(
+        ) if mp_group is None else mp_group.rank
+
+    def forward(self, input, label):
+        loss = mp_ops._c_softmax_with_cross_entropy(
+            input, label, group=self.model_parallel_group)
+        return loss
diff --git a/python/paddle/distributed/fleet/layers/mpu/mp_ops.py b/python/paddle/distributed/fleet/layers/mpu/mp_ops.py
new file mode 100644
index 00000000000..dc4dc05c7ba
--- /dev/null
+++ b/python/paddle/distributed/fleet/layers/mpu/mp_ops.py
@@ -0,0 +1,772 @@
+#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from paddle import _C_ops, _legacy_C_ops
+from paddle.fluid import core
+from paddle.fluid.framework import _non_static_mode
+from paddle.fluid.framework import _in_legacy_dygraph
+from paddle.fluid.framework import in_dygraph_mode
+from paddle.fluid.framework import _varbase_creator
+from paddle.fluid.layer_helper import LayerHelper
+from paddle.fluid.data_feeder import check_variable_and_dtype
+from paddle.fluid.dygraph import layers
+from paddle.distributed import collective
+from ....communication.comm_utils import ReduceOp
+from paddle.fluid.data_feeder import check_dtype
+import paddle.fluid.dygraph_utils as dygraph_utils
+
+
+def _c_identity(tensor, group=None):
+    """
+    Return a copy of the tensor, mainly used with model parallel.
+
+    Args:
+        tensor (Tensor): The input Tensor. Its data type
+            should be float16, float32, float64, int32 or int64.
+        group (Group, optional): The process group to work on. If None, the global group is used. Default: None.
+
+    Returns:
+        Tensor.
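+
+    Examples:
+        A minimal sketch; it assumes the dynamic-graph parallel environment has
+        already been initialized with ``paddle.distributed.init_parallel_env``.
+
+        .. code-block:: python
+
+            # required: distributed
+            import paddle
+            import paddle.distributed as dist
+            from paddle.distributed.fleet.layers.mpu import mp_ops
+
+            dist.init_parallel_env()
+            data = paddle.to_tensor([1.0, 2.0, 3.0])
+            # the forward pass returns the tensor unchanged on every rank
+            out = mp_ops._c_identity(data, group=None)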
+    """
+    if group is not None and not group.is_member():
+        return
+    ring_id = 0 if group is None else group.id
+
+    if _non_static_mode():
+        return _legacy_C_ops.c_identity(tensor, 'use_calc_stream', True,
+                                        'ring_id', ring_id,
+                                        'use_model_parallel', True)
+    op_type = 'c_identity'
+    helper = LayerHelper(op_type, **locals())
+    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
+
+    check_variable_and_dtype(
+        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
+        '_c_identity')
+
+    helper.append_op(type=op_type,
+                     inputs={'X': tensor},
+                     outputs={'Out': out},
+                     attrs={
+                         'ring_id': ring_id,
+                         'use_calc_stream': True,
+                         'use_model_parallel': True,
+                     })
+    return out
+
+
+def _c_concat(tensor, group=None):
+    """
+    Return allgather of the tensor, mainly used with model parallel.
+
+    Args:
+        tensor (Tensor): The input Tensor. Its data type
+            should be float16, float32, float64, int32 or int64.
+        group (Group, optional): The process group to work on. If None, the global group is used. Default: None.
+
+    Returns:
+        Tensor.
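+
+    Examples:
+        A minimal sketch; it assumes ``paddle.distributed.init_parallel_env``
+        has been called so that a default process group exists.
+
+        .. code-block:: python
+
+            # required: distributed
+            import paddle
+            import paddle.distributed as dist
+            from paddle.distributed.fleet.layers.mpu import mp_ops
+
+            dist.init_parallel_env()
+            # every rank holds a [2, 2] shard; the shards are gathered and
+            # concatenated along the last axis into a [2, 2 * nranks] tensor
+            shard = paddle.rand([2, 2])
+            out = mp_ops._c_concat(shard, group=None)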
+    """
+    if group is not None and not group.is_member():
+        return
+    group = collective._get_default_group() if group is None else group
+    ring_id = group.id
+
+    global_rank = collective._get_global_env().rank
+    rank = group.rank
+    nranks = group.nranks
+
+    if _non_static_mode():
+        return _legacy_C_ops.c_concat(tensor, 'ring_id', ring_id,
+                                      'use_calc_stream', True, 'rank', rank,
+                                      'nranks', nranks, 'use_model_parallel',
+                                      True)
+
+    op_type = 'c_concat'
+    helper = LayerHelper(op_type, **locals())
+    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
+
+    check_variable_and_dtype(
+        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
+        '_c_concat')
+
+    helper.append_op(type=op_type,
+                     inputs={'X': tensor},
+                     outputs={'Out': out},
+                     attrs={
+                         'ring_id': ring_id,
+                         'use_calc_stream': True,
+                         'use_model_parallel': True,
+                         'nranks': nranks,
+                         'rank': rank
+                     })
+    return out
+
+
+def _c_split(tensor, group=None):
+    """
+    Split tensor evenly among all members, mainly used with model parallel.
+
+    Args:
+        tensor (Tensor): The input Tensor. Its data type
+            should be float16, float32, float64, int32 or int64.
+        group (Group, optional): The process group to work on. If None, the
+            global group is used. Default: None.
+
+    Returns:
+        Tensor.
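+
+    Examples:
+        A minimal sketch; it assumes ``paddle.distributed.init_parallel_env``
+        has been called and that the last dimension is divisible by the number
+        of ranks.
+
+        .. code-block:: python
+
+            # required: distributed
+            import paddle
+            import paddle.distributed as dist
+            from paddle.distributed.fleet.layers.mpu import mp_ops
+
+            dist.init_parallel_env()
+            data = paddle.rand([2, 8])
+            # with 2 ranks, each rank keeps its own [2, 4] slice of the last axis
+            out = mp_ops._c_split(data, group=None)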
+    """
+    if group is not None and not group.is_member():
+        return
+    ring_id = 0 if group is None else group.id
+
+    global_rank = collective._get_global_env().rank
+    rank = global_rank if group is None else group.get_group_rank(global_rank)
+    nranks = collective._get_global_env(
+    ).world_size if group is None else group.nranks
+
+    if _non_static_mode():
+        return _legacy_C_ops.c_split(tensor, 'use_calc_stream', True, 'ring_id',
+                                     ring_id, 'rank', rank, 'nranks', nranks,
+                                     'use_model_parallel', True)
+
+    op_type = 'c_split'
+    helper = LayerHelper(op_type, **locals())
+    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
+
+    check_variable_and_dtype(
+        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
+        '_c_split')
+
+    helper.append_op(type=op_type,
+                     inputs={'X': tensor},
+                     outputs={'Out': out},
+                     attrs={
+                         'ring_id': ring_id,
+                         'use_calc_stream': True,
+                         'rank': rank,
+                         'nranks': nranks,
+                         'use_model_parallel': True,
+                     })
+    return out
+
+
+def _mp_allreduce(tensor,
+                  op=ReduceOp.SUM,
+                  group=None,
+                  use_calc_stream=True,
+                  use_model_parallel=True):
+    """[it is same as allreduce above, but it supports model parallel. And it support inplace startegy]
+    """
+    if group is not None and not group.is_member():
+        return
+
+    if in_dygraph_mode():
+        group = collective._get_default_group() if group is None else group
+        assert op == ReduceOp.SUM, "Unknown parameter: {}.".format(op)
+
+        from paddle.autograd import PyLayer
+
+        class mp_allreduce_eager(PyLayer):
+
+            @staticmethod
+            def forward(ctx, tensor, group, use_calc_stream,
+                        use_model_parallel):
+                ctx.ring_id = group.id
+
+                if use_calc_stream:
+                    op_type = collective._get_reduce_op(op, "_mp_allreduce")
+                    group.process_group.allreduce_on_calc_stream(
+                        tensor, op_type)
+                    return tensor
+                else:
+                    return _legacy_C_ops.c_allreduce_sum_(
+                        tensor, 'use_calc_stream', use_calc_stream, 'ring_id',
+                        ring_id, "use_model_parallel", use_model_parallel)
+
+            @staticmethod
+            def backward(ctx, dy):
+                return _legacy_C_ops.c_identity(dy, 'use_calc_stream', True,
+                                                'ring_id', ctx.ring_id,
+                                                'use_model_parallel', True)
+
+        return mp_allreduce_eager.apply(tensor, group, use_calc_stream,
+                                        use_model_parallel)
+
+    ring_id = 0 if group is None else group.id
+    if _in_legacy_dygraph():
+        if op == ReduceOp.SUM:
+            return _legacy_C_ops.c_allreduce_sum_(tensor, 'use_calc_stream',
+                                                  use_calc_stream, 'ring_id',
+                                                  ring_id, "use_model_parallel",
+                                                  use_model_parallel)
+        else:
+            raise ValueError("Unknown parameter: {}.".format(op))
+
+    op_type = 'c_allreduce_sum'
+    helper = LayerHelper(op_type, **locals())
+    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
+
+    check_variable_and_dtype(
+        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
+        op_type)
+
+    helper.append_op(type=op_type,
+                     inputs={'X': tensor},
+                     outputs={'Out': out},
+                     attrs={
+                         'ring_id': ring_id,
+                         'use_calc_stream': use_calc_stream,
+                         'use_model_parallel': use_model_parallel,
+                     })
+    return out
+
+
+def _c_lookup_table(table, index, start_index=0, name=None):
+    """
+    Lookup table according to index.
+
+    Args:
+        table (Tensor): The input Tensor. Its data type
+            should be float16, float32, float64.
+        index (Tensor): The index to lookup table.
+        start_index (int): The starting row index of this rank's shard in the full table.
+        name (str, optional): Name of the operation. Default: None.
+
+    Returns:
+        Tensor.
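+
+    Examples:
+        A minimal single-card sketch; ids outside this rank's range produce
+        zero rows, which is what the parallel embedding relies on.
+
+        .. code-block:: python
+
+            # required: gpu
+            import paddle
+            from paddle.distributed.fleet.layers.mpu import mp_ops
+
+            # this rank holds rows [100, 164) of the full embedding table
+            table = paddle.rand([64, 8])
+            index = paddle.to_tensor([3, 120, 200], dtype='int64')
+            out = mp_ops._c_lookup_table(table, index, start_index=100)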
+    """
+    if _non_static_mode():
+        return _legacy_C_ops.c_embedding(table, index, "start_index",
+                                         start_index)
+
+    op_type = 'c_embedding'
+    helper = LayerHelper(op_type, **locals())
+    dtype = helper.input_dtype(input_param_name='table')
+    check_variable_and_dtype(index, 'input', ['int32', 'int64'], op_type)
+    tmp = helper.create_variable_for_type_inference(dtype)
+    helper.append_op(type='c_embedding',
+                     inputs={
+                         'Ids': index,
+                         'W': table
+                     },
+                     outputs={'Out': tmp},
+                     attrs={"start_index": start_index})
+    return tmp
+
+
+class _Linear(layers.Layer):
+    """
+    A plain Linear layer implemented with the internal ``_linear`` function.
+    """
+
+    def __init__(self,
+                 in_features,
+                 out_features,
+                 weight_attr=None,
+                 bias_attr=None,
+                 name=None):
+        super(_Linear, self).__init__()
+        self._dtype = self._helper.get_default_dtype()
+        self._weight_attr = weight_attr
+        self._bias_attr = bias_attr
+        self.weight = self.create_parameter(shape=[in_features, out_features],
+                                            attr=self._weight_attr,
+                                            dtype=self._dtype,
+                                            is_bias=False)
+        self.bias = self.create_parameter(shape=[out_features],
+                                          attr=self._bias_attr,
+                                          dtype=self._dtype,
+                                          is_bias=True)
+        self.name = name
+
+    def forward(self, input):
+        out = _linear(x=input,
+                      weight=self.weight,
+                      bias=self.bias,
+                      name=self.name)
+        return out
+
+    def extra_repr(self):
+        name_str = ', name={}'.format(self.name) if self.name else ''
+        return 'in_features={}, out_features={}, dtype={}{}'.format(
+            self.weight.shape[0], self.weight.shape[1], self._dtype, name_str)
+
+
+def _c_softmax_with_cross_entropy(logits,
+                                  label,
+                                  group=None,
+                                  return_softmax=False):
+    if group is not None and not group.is_member():
+        return
+    ring_id = 0 if group is None else group.id
+    global_rank = collective._get_global_env().rank
+    rank = global_rank if group is None else group.get_group_rank(global_rank)
+    nranks = collective._get_global_env(
+    ).world_size if group is None else group.nranks
+
+    input_dims = len(list(logits.shape))
+    label_dims = len(list(label.shape))
+    if input_dims - 1 != label_dims and input_dims != label_dims:
+        raise ValueError(
+            'Expected input_dims - 1 == label_dims or input_dims == label_dims '
+            '(got input_dims = {}, label_dims = {})'.format(
+                input_dims, label_dims))
+    if input_dims - 1 == label_dims:
+        label = paddle.unsqueeze(label, axis=-1)
+
+    if _non_static_mode():
+        softmax, loss = _legacy_C_ops.c_softmax_with_cross_entropy(
+            logits, label, 'ring_id', ring_id, 'rank', rank, 'nranks', nranks)
+        if not return_softmax:
+            return loss
+        else:
+            return loss, softmax
+
+    attrs = {
+        'ring_id': ring_id,
+        'rank': rank,
+        'nranks': nranks,
+    }
+    helper = LayerHelper('c_softmax_with_cross_entropy', **locals())
+    softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
+    loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
+    helper.append_op(type='c_softmax_with_cross_entropy',
+                     inputs={
+                         'Logits': logits,
+                         'Label': label
+                     },
+                     outputs={
+                         'Softmax': softmax,
+                         'Loss': loss
+                     },
+                     attrs=attrs)
+
+    if return_softmax:
+        return loss, softmax
+
+    return loss
+
+
+def _linear(x, weight, bias=None, name=None):
+    """
+    Function linear: computes x * weight + bias.
+    """
+    if _non_static_mode():
+        pre_bias = _varbase_creator(dtype=x.dtype)
+        _legacy_C_ops.matmul(x, weight, pre_bias, 'transpose_X', False,
+                             'transpose_Y', False, "alpha", 1)
+        return dygraph_utils._append_bias_in_dygraph(pre_bias,
+                                                     bias,
+                                                     axis=len(x.shape) - 1)
+    else:
+        helper = LayerHelper('linear', **locals())
+        dtype = x.dtype
+        assert len(x.shape) < 4, \
+            "Tensors with more than 3 dimensions are not supported now."
+
+        check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
+                                 'linear')
+        check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear')
+
+        inputs = {'X': [x], 'Y': [weight]}
+        attrs = {
+            'transpose_X': False,
+            'transpose_Y': False,
+            'alpha': 1,
+        }
+        tmp = helper.create_variable_for_type_inference(dtype)
+        helper.append_op(type='matmul_v2',
+                         inputs=inputs,
+                         outputs={'Out': tmp},
+                         attrs=attrs)
+        if bias is not None:
+            res = helper.create_variable_for_type_inference(dtype)
+            helper.append_op(type='elementwise_add',
+                             inputs={
+                                 'X': [tmp],
+                                 'Y': [bias]
+                             },
+                             outputs={'Out': [res]},
+                             attrs={'axis': len(x.shape) - 1})
+        else:
+            res = tmp
+        return res
+
+
+def _set_var_distributed(var):
+    if var is None:
+        return
+
+    var.is_distributed = True
+
+    # NOTE: use current_block and find_var_recursive to support while_loop
+    startup_block = paddle.static.default_startup_program().current_block()
+    main_block = paddle.static.default_main_program().current_block()
+    startup_block._find_var_recursive(var.name).is_distributed = True
+    main_block._find_var_recursive(var.name).is_distributed = True
+
+
+def _parallel_linear(x,
+                     num_rows,
+                     num_cols,
+                     axis,
+                     param_attr,
+                     bias_attr,
+                     gather_out,
+                     inner_rank,
+                     nranks,
+                     split_tensor,
+                     name,
+                     group=None):
+    """
+    Parallel Linear
+
+    ``axis`` is the dimension of the linear layer's weight to split:
+    axis = 0 splits along the row dimension,
+    axis = 1 splits along the column dimension.
+
+    """
+    if group is not None and not group.is_member():
+        return
+    ring_id = 0 if group is None else group.id
+
+    if axis == 0:
+        if split_tensor:
+            x = _c_split(x, group=group)
+    else:
+        x = _c_identity(x, group=group)
+
+    linear = paddle.nn.Linear(num_rows,
+                              num_cols,
+                              weight_attr=param_attr,
+                              bias_attr=bias_attr,
+                              name=name)
+
+    # NOTE: npu linear function use matmul_v2 but linear use matmul
+    linear_function = _linear if core.is_compiled_with_npu()\
+        else paddle.nn.functional.linear
+    linear_out = linear_function(
+        x,
+        linear.weight,
+        # NOTE(wangxi): row split, bias need add after allreduce
+        None if axis == 0 else linear.bias,
+        linear.name)
+
+    _set_var_distributed(linear.weight)
+    # set is_distributed for the split bias
+    # if a linear layer is split by row, each rank holds a complete bias and the biases must stay identical across ranks.
+    # if a linear layer is split by column, the bias is split across ranks together with its weight
+    if axis == 1 and linear._bias_attr != False:
+        _set_var_distributed(linear.bias)
+
+    if not gather_out: return linear_out
+
+    out_shape = list(linear_out.shape)
+    out_shape[0] *= 1 if axis == 0 else nranks
+    main_block = paddle.static.default_main_program().current_block()
+    out = main_block.create_var(
+        shape=out_shape,
+        dtype=linear_out.dtype,
+        type=linear_out.type,
+        lod_level=linear_out.lod_level,
+        persistable=False,
+        is_data=False,
+        need_check_feed=linear_out.desc.need_check_feed())
+    if axis == 0:
+        main_block.append_op(type='c_allreduce_sum',
+                             inputs={'X': linear_out},
+                             outputs={'Out': out},
+                             attrs={
+                                 'ring_id': ring_id,
+                                 'use_calc_stream': True,
+                                 'use_model_parallel': True
+                             })
+        if linear.bias is not None:
+            out = out + linear.bias
+    else:
+        main_block.append_op(type='c_concat',
+                             inputs={'X': linear_out},
+                             outputs={'Out': out},
+                             attrs={
+                                 'rank': inner_rank,
+                                 'ring_id': ring_id,
+                                 'nranks': nranks,
+                                 'use_calc_stream': True,
+                                 'use_model_parallel': True
+                             })
+    return out
+
+
+def _parallel_embedding(x,
+                        per_part_embeddings,
+                        origin_size,
+                        param_attr,
+                        inner_rank,
+                        num_partitions,
+                        name,
+                        group=None):
+    """
+    Parallel Embedding
+    """
+    if group is not None and not group.is_member():
+        return
+    ring_id = 0 if group is None else group.id
+
+    helper = LayerHelper("_parallel_embedding", **locals())
+
+    per_part_size = per_part_embeddings
+    rank = inner_rank
+
+    vocab_start_index = rank * per_part_size
+    dtype = helper.get_default_dtype()
+    size = [per_part_size, origin_size[1]]
+
+    weight = helper.create_parameter(attr=param_attr,
+                                     shape=size,
+                                     dtype=dtype,
+                                     is_bias=False)
+
+    if num_partitions == 1:
+        return paddle.nn.functional.embedding(x,
+                                              weight=weight,
+                                              padding_idx=None,
+                                              sparse=False,
+                                              name=name)
+
+    startup_block = paddle.static.default_startup_program().global_block()
+    main_block = paddle.static.default_main_program().global_block()
+    startup_block.vars[weight.name].is_distributed = True
+    main_block.vars[weight.name].is_distributed = True
+
+    output_parallel = _c_lookup_table(weight,
+                                      x,
+                                      start_index=vocab_start_index,
+                                      name=name)
+    out = _mp_allreduce(output_parallel,
+                        group=group,
+                        use_calc_stream=True,
+                        use_model_parallel=True)
+    return out
+
+
+def split(x,
+          size,
+          operation,
+          axis=0,
+          num_partitions=1,
+          gather_out=True,
+          weight_attr=None,
+          bias_attr=None,
+          name=None):
+    """
+
+    Split the weight of the specified operation into multiple devices
+    and do the computation in parallel.
+
+    Now the following three cases are supported.
+
+    Case 1: Parallel Embedding
+        The weight of the embedding operation is an NxM matrix with N rows and M columns.
+        With parallel embedding, the weight is split into num_partitions partitions, each
+        of which is a matrix with (N/num_partitions + 1) rows and M columns, where the last
+        row is used as the padding idx.
+
+        Suppose we split the NxM weight into two partitions on device_0 and device_1
+        respectively. Then, on each device, the final weight has (N/2 + 1) rows, with
+        indices ranging from 0 to N/2. On device_0, all input values within [0, N/2 - 1]
+        are kept unchanged and all other values are changed to N/2, the padding index, and
+        are mapped to all zeros after embedding. In the same way, on device_1, the value V in the
+        input within [N/2, N-1] will be changed to (V - N/2), and all other values are changed
+        to N/2 and are mapped to all zeros after embedding. Finally, the results on the two
+        devices are sum-reduced.
+
+        The Embedding put on single card is as shown below:
+
+        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_embedding_single.png
+            :width: 800
+            :height: 350
+            :alt: single_embedding
+            :align: center
+
+        Parallel Embedding is shown as below:
+
+        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_embedding_split.png
+            :width: 800
+            :alt: split_embedding
+            :align: center
+
+    Case 2: Row Parallel Linear
+        The weight of the linear operation is an NxM matrix with N rows and M columns.
+        With row parallel linear, the weight is split into num_partitions partitions, each
+        of which is a matrix with N/num_partitions rows and M columns.
+
+        The linear layer put on a single card is shown as below. The input variable is represented by X,
+        the weight matrix is represented by W and the output variable is O. The linear layer on a single card is a
+        simple matrix multiplication operation, O = X * W.
+
+        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_single.png
+            :width: 800
+            :alt: single_linear
+            :align: center
+
+        Row Parallel Linear is shown as below. As the name suggests, Row Parallel Linear splits the weight matrix W into
+        [[W_row1], [W_row2]] along the row. Accordingly, the input is split along the column into [X_col1, X_col2], and each
+        part is multiplied by its respective weight matrix. Finally, AllReduce is applied to the outputs from all cards to get the final output.
+
+        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_row.png
+            :width: 800
+            :alt: split_row
+            :align: center
+
+    Case 3: Column Parallel Linear
+        The weight of the linear operation is an NxM matrix with N rows and M columns.
+        With column parallel linear, the weight is split into num_partitions partitions, each
+        of which is a matrix with N rows and M/num_partitions columns.
+
+        The linear layer put on a single card has been illustrated in Case 2, and Column Parallel Linear
+        is shown as below. Column Parallel Linear splits the weight matrix W into [W_col1, W_col2] along the column, and
+        the input is multiplied by each of these split matrices. Finally, AllGather is applied to the outputs from all cards to get the final output.
+
+        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_col.png
+            :width: 800
+            :alt: split_col
+            :align: center
+
+    As observed, the column parallel linear and the row parallel linear can be combined to skip one AllGather communication
+    operator. Furthermore, the Attention and MLP blocks can be combined in this way to improve performance, as shown below.
+
+    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_col_row.png
+            :width: 800
+            :alt: split_col_row
+            :align: center
+
+    Args:
+        x (Tensor): Input tensor. It's data type should be float16, float32, float64, int32 or int64.
+        size (list|tuple): A list or tuple with two elements indicating the shape of the weight.
+        operation (str): The name of the operation. The supported operations are 'linear' and 'embedding'.
+        axis (int, Optional): Indicate along which axis to split the weight. Default: 0.
+        num_partitions (int, Optional): The number of parts into which the weight is partitioned. Default: 1.
+        gather_out (bool, Optional): Whether to gather the output after computation. By default, the output
+            on each partition will be gathered after computation. Default: True.
+        weight_attr (ParamAttr, Optional): The parameter attribute for the learnable
+            weights(Parameter) of the specified operation. Default: None.
+        bias_attr (ParamAttr, Optional): The parameter attribute for the bias
+            of the specified operation. Default: None.
+        name (str, Optional): The default value is None. Normally there is no need for user to set this
+            property. Default: None. For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        Tensor.
+
+    Examples:
+        .. code-block:: python
+
+            # required: distributed
+            import paddle
+            import paddle.distributed.fleet as fleet
+
+            paddle.enable_static()
+            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
+            fleet.init(is_collective=True)
+            data = paddle.randint(0, 8, shape=[10,4])
+            emb_out = paddle.distributed.split(
+                data,
+                (8, 8),
+                operation="embedding",
+                num_partitions=2)
+
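+        The following sketch splits the weight of a linear layer by column across
+        two devices; the sizes below are only illustrative.
+
+        .. code-block:: python
+
+            # required: distributed
+            import paddle
+            import paddle.distributed.fleet as fleet
+
+            paddle.enable_static()
+            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
+            fleet.init(is_collective=True)
+            data = paddle.rand([4, 8])
+            linear_out = paddle.distributed.split(
+                data,
+                (8, 16),
+                operation="linear",
+                axis=1,
+                num_partitions=2)
+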
+    """
+    assert isinstance(
+        size,
+        (list, tuple)), ("The type of size for "
+                         "paddle.distributed.split must be list or tuple.")
+    assert len(size) == 2, ("Number of elements in size of "
+                            "paddle.distributed.split must be two.")
+    assert isinstance(operation, str), ("The type of operation for "
+                                        "paddle.distributed.split must be str.")
+    supported_operations = [
+        'linear',
+        'embedding',
+    ]
+    assert operation in supported_operations, (
+        "The operation for "
+        "paddle.distributed.split must be one of {}.".format(
+            supported_operations))
+    if _non_static_mode():
+        raise ValueError(
+            "paddle.distributed.split cannot be used in dynamic "
+            "graph mode, please use VocabParallelEmbedding, RowParallelLinear, "
+            "ColumnParallelLinear instead.")
+    else:
+        from paddle.distributed.fleet import fleet
+        assert fleet._role_maker, ("To use paddle.distributed.split, "
+                                   "you must call fleet.init() firstly.")
+        rank = fleet.worker_index()
+        nranks = fleet.worker_num()
+
+    # rank within a model parallel group
+    inner_rank = rank % num_partitions
+
+    if operation == "embedding":
+        assert axis == 0, ("We only support to split the weight of embedding "
+                           "along the first axis now.")
+        assert size[0] % num_partitions == 0, \
+            "The length of the vocabulary must be divisible by num_partitions " \
+            "but received vocabulary={} num_partitions={}".format(size[0], num_partitions)
+
+        per_part_size = size[0] // num_partitions
+        emb_out = _parallel_embedding(x,
+                                      per_part_size,
+                                      size,
+                                      weight_attr,
+                                      inner_rank,
+                                      num_partitions,
+                                      name,
+                                      group=None)
+        return emb_out
+    else:
+        should_split = False
+        if axis == 0:
+            assert size[0] % num_partitions == 0, (
+                "Number of rows of the weight for linear ({}) must be"
+                " divisible by num_partitions ({})".format(
+                    size[0], num_partitions))
+            per_part_size = size[0] // num_partitions
+            linear_size = (per_part_size, size[1])
+            if x.shape[-1] == size[0]: should_split = True
+
+        elif axis == 1:
+            assert size[1] % num_partitions == 0, (
+                "Number of column of the weight for linear ({}) must be"
+                " divisible by num_partitions ({})".format(
+                    size[1], num_partitions))
+            per_part_size = size[1] // num_partitions
+            linear_size = (size[0], per_part_size)
+        else:
+            raise ValueError("The value of axis must be 0 or 1, but the value "
+                             "given is {}.".format(axis))
+
+        linear_out = _parallel_linear(x,
+                                      linear_size[0],
+                                      linear_size[1],
+                                      axis,
+                                      weight_attr,
+                                      bias_attr,
+                                      gather_out,
+                                      inner_rank,
+                                      num_partitions,
+                                      should_split,
+                                      name=name,
+                                      group=None)
+        return linear_out
diff --git a/python/paddle/distributed/fleet/layers/mpu/random.py b/python/paddle/distributed/fleet/layers/mpu/random.py
new file mode 100644
index 00000000000..7577be6253c
--- /dev/null
+++ b/python/paddle/distributed/fleet/layers/mpu/random.py
@@ -0,0 +1,243 @@
+#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import numpy as np
+import contextlib
+from paddle import _C_ops, _legacy_C_ops
+from paddle.fluid import core
+from paddle.fluid.data_feeder import check_variable_and_dtype
+from paddle.fluid.framework import _non_static_mode, default_main_program, Variable
+from paddle.fluid.layer_helper import LayerHelper
+
+__all__ = []
+
+MODEL_PARALLEL_RNG = 'model_parallel_rng'
+
+# This file is inspired by Megatron to control random states for MP:
+# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/mpu/random.py
+
+
+class RNGStatesTracker:
+    """
+    Tracks multiple CUDA RNG states by name so that different model parallel
+    branches can draw from independent random streams.
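+
+    A minimal usage sketch; the state name below is only illustrative.
+
+    .. code-block:: python
+
+        # required: gpu
+        import paddle
+        from paddle.distributed.fleet.layers.mpu.random import get_rng_state_tracker
+
+        tracker = get_rng_state_tracker()
+        tracker.add('local_seed', 2022)
+        with tracker.rng_state('local_seed'):
+            # drawn from the tracked 'local_seed' stream instead of the global one
+            noise = paddle.rand([2, 3])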
+    """
+
+    def __init__(self):
+        # Map from name to the rng state.
+        self.states_ = {}
+        self.seeds_ = set()
+
+    def reset(self):
+        self.states_ = {}
+        self.seeds_ = set()
+
+    def add(self, name, seed):
+        if seed in self.seeds_:
+            raise ValueError('seed {} already exists'.format(seed))
+        self.seeds_.add(seed)
+        if name in self.states_:
+            raise ValueError('state {} already exists'.format(name))
+        orig_rng_state = paddle.get_cuda_rng_state()
+        paddle.seed(seed)
+        self.states_[name] = paddle.get_cuda_rng_state()
+        paddle.set_cuda_rng_state(orig_rng_state)
+
+    def get_states_tracker(self):
+        states = {}
+        for name in self.states_:
+            states[name] = self.states_[name]
+        return states
+
+    def set_states_tracker(self, states):
+        self.states_ = states
+
+    @contextlib.contextmanager
+    def rng_state(self, name=MODEL_PARALLEL_RNG):
+        if name not in self.states_:
+            raise ValueError('state {} does not exist'.format(name))
+        orig_cuda_rng_state = paddle.get_cuda_rng_state()
+        paddle.set_cuda_rng_state(self.states_[name])
+        try:
+            yield
+        finally:
+            self.states_[name] = paddle.get_cuda_rng_state()
+            paddle.set_cuda_rng_state(orig_cuda_rng_state)
+
+
+RNG_STATE_TRACKER = RNGStatesTracker()
+
+
+def get_rng_state_tracker():
+    return RNG_STATE_TRACKER
+
+
+def model_parallel_random_seed(seed=None):
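+    """
+    Reset the RNG state tracker, register a per-rank ``model_parallel_rng``
+    state and seed the global RNG, so that random ops such as dropout behave
+    consistently under model parallelism. A minimal sketch; it assumes a hybrid
+    parallel fleet with a model parallel group has been initialized:
+
+    .. code-block:: python
+
+        # required: distributed
+        import paddle.distributed.fleet as fleet
+        from paddle.distributed.fleet.layers.mpu import random as mpu_random
+
+        strategy = fleet.DistributedStrategy()
+        strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 1}
+        fleet.init(is_collective=True, strategy=strategy)
+        mpu_random.model_parallel_random_seed(2022)
+    """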
+    import paddle.distributed.fleet as fleet
+    hcg = fleet.get_hybrid_communicate_group()
+    rank = hcg.get_model_parallel_rank()
+
+    if seed:
+        global_seed = seed
+        local_seed = seed * 1024 + rank * 100
+    else:
+        global_seed = np.random.randint(0, 655350)
+        local_seed = np.random.randint(rank * 10000, (rank + 1) * 10000 - 1)
+
+    RNG_STATE_TRACKER.reset()
+    RNG_STATE_TRACKER.add(MODEL_PARALLEL_RNG, local_seed)
+    paddle.seed(global_seed)
+
+
+def determinate_seed(rng_name):
+    assert rng_name is not None and rng_name != ""
+    helper = LayerHelper('seed', **locals())
+    out = helper.create_variable_for_type_inference(dtype=paddle.int32)
+    # set force_cpu to reduce sync copy from CPU->GPU->CPU, and reduce pipeline hang
+    helper.append_op(type='seed',
+                     outputs={'Out': out},
+                     attrs={
+                         'deterministic': True,
+                         'rng_name': rng_name,
+                         'force_cpu': True
+                     })
+    return out
+
+
+def dropout(x,
+            p=0.5,
+            axis=None,
+            rng_name=None,
+            training=True,
+            mode="upscale_in_train",
+            name=None):
+    """
+    Dropout is a regularization technique for reducing overfitting by preventing
+    neuron co-adaptation during training. The dropout operator randomly sets the
+    outputs of some units to zero, while upscaling the others according to the
+    given dropout probability.
+
+    Args:
+        x (Tensor): The input tensor. The data type is float32 or float64.
+        p (float|int): Probability of setting units to zero. Default 0.5.
+        axis (int|list|tuple): The axis along which the dropout is performed. Default None.
+        rng_name (str): The name of the random seed generator, which is used to obtain deterministic results. Default None.
+        training (bool): A flag indicating whether it is in the training phase or not. Default True.
+        mode(str): ['upscale_in_train'(default) | 'downscale_in_infer'].
+
+                           1. upscale_in_train(default), upscale the output at training time
+
+                              - train: out = input * mask / ( 1.0 - dropout_prob )
+                              - inference: out = input
+
+                           2. downscale_in_infer, downscale the output at inference
+
+                              - train: out = input * mask
+                              - inference: out = input * (1.0 - dropout_prob)
+        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        A Tensor representing the dropout, has same shape and data type as `x` .
+
+
+    Examples:
+        We use ``p=0.5`` in the following description for simplicity.
+
+        1. When ``axis=None`` , this is commonly used dropout, which dropout each element of x randomly.
+
+        ..  code-block:: text
+
+            Let's see a simple case when x is a 2d tensor with shape 2*3:
+            [[1 2 3]
+             [4 5 6]]
+            we generate mask with the same shape as x, which is 2*3. The value of mask is
+            sampled from a Bernoulli distribution randomly. For example, we may get such mask:
+            [[0 1 0]
+             [1 0 1]]
+            So the output is obtained from elementwise multiply of x and mask:
+            [[0 2 0]
+             [4 0 6]]
+            Using default setting, i.e. ``mode='upscale_in_train'`` ,
+            if in training phase, the final upscale output is:
+            [[0 4 0 ]
+             [8 0 12]]
+            if in test phase, the output is the same as input:
+            [[1 2 3]
+             [4 5 6]]
+            we can also set ``mode='downscale_in_infer'`` , then
+            if in training phase, the final output is:
+            [[0 2 0]
+             [4 0 6]]
+            if in test phase, the scale output is:
+            [[0.5 1.  1.5]
+             [2.  2.5 3. ]]
+
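+        A minimal dynamic-graph sketch; in static graph mode ``rng_name`` selects
+        a named deterministic seed generator, so the name below is only
+        illustrative.
+
+        ..  code-block:: python
+
+            import paddle
+            from paddle.distributed.fleet.layers.mpu import random as mpu_random
+
+            x = paddle.ones([2, 3], dtype='float32')
+            y_train = mpu_random.dropout(x, p=0.5, rng_name='local_seed')
+            y_test = mpu_random.dropout(x, p=0.5, rng_name='local_seed',
+                                        training=False)
+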
+    """
+    if rng_name is None:
+        return paddle.nn.functional.dropout(x, p, axis, training, mode, name)
+
+    if not isinstance(p, (float, int, Variable)):
+        raise TypeError("p argument should be a number(int|float) or Variable")
+
+    # fast return for p == 0
+    if isinstance(p, (int, float)) and p == 0: return x
+
+    assert 0 <= p <= 1, ValueError("p argument should be between 0 and 1")
+    assert mode in ('downscale_in_infer', 'upscale_in_train'), \
+        ValueError(
+            "mode argument should be 'downscale_in_infer' or 'upscale_in_train'")
+
+    assert axis is None, \
+        TypeError("axis is not supported when using a random seed generator")
+
+    # 'downscale_in_infer' is named 'downgrade_in_infer' at the op level
+    mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode
+
+    # dygraph using tracker, doesn't need determinate seed
+    if _non_static_mode():
+        out, mask = _legacy_C_ops.dropout(x, 'dropout_prob', p, 'is_test',
+                                          not training, 'fix_seed', False,
+                                          'seed', 0, 'dropout_implementation',
+                                          mode)
+        return out
+
+    seed = determinate_seed(rng_name)
+
+    if isinstance(p, Variable) and p.shape != [1]:
+        raise TypeError(
+            "Required p.shape == [1] if type(p) is Variable, but received p.shape = {}"
+            .format(p.shape))
+
+    helper = LayerHelper('dropout', **locals())
+    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
+                             'dropout')
+
+    out = helper.create_variable_for_type_inference(dtype=x.dtype)
+    mask = helper.create_variable_for_type_inference(
+        dtype=core.VarDesc.VarType.UINT8, stop_gradient=True)
+
+    helper.append_op(type='dropout',
+                     inputs={
+                         'X': [x],
+                         'Seed': seed
+                     },
+                     outputs={
+                         'Out': [out],
+                         'Mask': [mask]
+                     },
+                     attrs={
+                         'dropout_prob': p,
+                         'is_test': not training,
+                         'dropout_implementation': mode,
+                     })
+    return out
diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py
index 6cb69bc73ce..66a1c877562 100644
--- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py
+++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py
@@ -1,4 +1,4 @@
-#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,298 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import paddle
-from paddle.fluid import core
-from paddle.fluid.dygraph.layers import Layer
-from .random import get_rng_state_tracker
-from paddle.nn import functional as F
-from paddle import framework
-from ...base import topology as tp
-from paddle.autograd import PyLayer
+from ...layers.mpu.mp_layers import VocabParallelEmbedding  # noqa: F401
+from ...layers.mpu.mp_layers import ColumnParallelLinear  # noqa: F401
+from ...layers.mpu.mp_layers import RowParallelLinear  # noqa: F401
+from ...layers.mpu.mp_layers import ParallelCrossEntropy  # noqa: F401
 
 __all__ = []
-
-# Follow this paper to achieve the file:
-# Shoeybi M, Patwary M, Puri R, et al. Megatron-lm: Training multi-billion parameter
-# language models using model parallelism[J]. arXiv preprint arXiv:1909.08053, 2019. (https://arxiv.org/abs/1909.08053)
-
-
-def is_fused_matmul_bias_supported():
-    if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm():
-        return hasattr(core.ops, 'fused_gemm_epilogue')
-    else:
-        return False
-
-
-class VocabParallelEmbedding(Layer):
-
-    def __init__(self,
-                 num_embeddings,
-                 embedding_dim,
-                 weight_attr=None,
-                 mp_group=None,
-                 name=None):
-        super(VocabParallelEmbedding, self).__init__()
-
-        self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
-        ) if mp_group is None else mp_group
-        self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
-        ) if mp_group is None else mp_group.nranks
-        self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank(
-        ) if mp_group is None else mp_group.rank
-
-        self.origin_num_embeddings = num_embeddings
-        self.is_mp = (self.world_size > 1)
-
-        assert num_embeddings % self.world_size == 0, (
-            "The length of the vocabulary must be divisible by the parallelism degree of MP"
-        )
-
-        per_part_size = num_embeddings // self.world_size
-
-        self.vocab_start_index = self.rank * per_part_size
-        self._dtype = self._helper.get_default_dtype()
-        self._size = [per_part_size, embedding_dim]
-        self._weight_attr = weight_attr
-        self._name = name
-
-        if self.is_mp and paddle.in_dynamic_mode():
-            with get_rng_state_tracker().rng_state():
-                self.weight = self.create_parameter(attr=self._weight_attr,
-                                                    shape=self._size,
-                                                    dtype=self._dtype,
-                                                    is_bias=False)
-        else:
-            self.weight = self.create_parameter(attr=self._weight_attr,
-                                                shape=self._size,
-                                                dtype=self._dtype,
-                                                is_bias=False)
-
-        self.weight.is_distributed = True if self.is_mp else False
-
-    def forward(self, x):
-        if self.is_mp:
-            output_parallel = paddle.distributed.collective._c_lookup_table(
-                self.weight,
-                x,
-                start_index=self.vocab_start_index,
-                name=self._name)
-            output = paddle.distributed.collective._mp_allreduce(
-                output_parallel,
-                group=self.model_parallel_group,
-                use_calc_stream=True,
-                use_model_parallel=True)
-        else:
-            output = F.embedding(x,
-                                 weight=self.weight,
-                                 padding_idx=None,
-                                 sparse=False,
-                                 name=self._name)
-        return output
-
-
-class ColumnParallelLinear(Layer):
-
-    def __init__(self,
-                 in_features,
-                 out_features,
-                 weight_attr=None,
-                 has_bias=None,
-                 gather_output=True,
-                 fuse_matmul_bias=False,
-                 mp_group=None,
-                 name=None):
-        super(ColumnParallelLinear, self).__init__()
-
-        self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
-        ) if mp_group is None else mp_group
-        self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
-        ) if mp_group is None else mp_group.nranks
-        self._name = name
-        self.is_mp = (self.world_size > 1)
-
-        self.gather_output = gather_output
-        assert out_features % self.world_size == 0, (
-            "Number of column of the weight for linear ({}) must be"
-            " divisible by model parallel size ({})".format(
-                out_features, self.world_size))
-        self.output_size_per_partition = out_features // self.world_size
-
-        self._weight_attr = weight_attr
-        self._dtype = self._helper.get_default_dtype()
-
-        if self.is_mp and paddle.in_dynamic_mode():
-            with get_rng_state_tracker().rng_state():
-                self.weight = self.create_parameter(
-                    shape=[in_features, self.output_size_per_partition],
-                    attr=self._weight_attr,
-                    dtype=self._dtype,
-                    is_bias=False)
-        else:
-            self.weight = self.create_parameter(
-                shape=[in_features, self.output_size_per_partition],
-                attr=self._weight_attr,
-                dtype=self._dtype,
-                is_bias=False)
-
-        self.weight.is_distributed = True if self.is_mp else False
-
-        if has_bias:
-            # initialize bias to zero like Megatron
-            self.bias = self.create_parameter(
-                shape=[self.output_size_per_partition],
-                attr=paddle.nn.initializer.Constant(value=0.0),
-                dtype=self._dtype,
-                is_bias=True)
-            self.bias.is_distributed = True if self.is_mp else False
-        else:
-            self.bias = None
-
-        self.linear = F.linear
-
-        if fuse_matmul_bias:
-            if not is_fused_matmul_bias_supported():
-                raise NotImplementedError(
-                    "You set fuse_matmul_bias=True in ColumnParallelLinear, "
-                    "however, the paddle you are using not support this operation. "
-                    "Please set fuse_matmul_bias=False or use paddle compiled "
-                    "with cuda 11.6 or higher.")
-            from paddle.incubate.nn.functional import fused_linear
-            self.linear = fused_linear
-
-    def forward(self, x):
-        # use inner api to process identity
-        if self.is_mp:
-            input_parallel = paddle.distributed.collective._c_identity(
-                x, group=self.model_parallel_group)
-        else:
-            input_parallel = x
-
-        output_parallel = self.linear(input_parallel,
-                                      self.weight,
-                                      self.bias,
-                                      name=self._name)
-
-        if self.gather_output and self.is_mp:
-            output = paddle.distributed.collective._c_concat(
-                output_parallel, group=self.model_parallel_group)
-        else:
-            output = output_parallel
-        return output
-
-
-class RowParallelLinear(Layer):
-
-    def __init__(self,
-                 in_features,
-                 out_features,
-                 weight_attr=None,
-                 has_bias=True,
-                 input_is_parallel=False,
-                 fuse_matmul_bias=False,
-                 mp_group=None,
-                 name=None):
-        super(RowParallelLinear, self).__init__()
-
-        self.in_features = in_features
-        self.out_features = out_features
-        self.input_is_parallel = input_is_parallel
-        self._weight_attr = weight_attr
-        self._dtype = self._helper.get_default_dtype()
-        self._name = name
-
-        self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
-        ) if mp_group is None else mp_group
-        self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
-        ) if mp_group is None else mp_group.nranks
-        self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank(
-        ) if mp_group is None else mp_group.rank
-
-        self.is_mp = (self.world_size > 1)
-        assert in_features % self.world_size == 0, (
-            "Number of row of the weight for linear ({}) must be"
-            " divisible by model parallel size ({})".format(
-                in_features, self.world_size))
-
-        self.input_size_per_partition = in_features // self.world_size
-
-        if self.is_mp and paddle.in_dynamic_mode():
-            with get_rng_state_tracker().rng_state():
-                self.weight = self.create_parameter(
-                    shape=[self.input_size_per_partition, self.out_features],
-                    attr=self._weight_attr,
-                    dtype=self._dtype,
-                    is_bias=False)
-        else:
-            self.weight = self.create_parameter(
-                shape=[self.input_size_per_partition, self.out_features],
-                attr=self._weight_attr,
-                dtype=self._dtype,
-                is_bias=False)
-
-        self.weight.is_distributed = True if self.is_mp else False
-
-        if has_bias:
-            self.bias = self.create_parameter(
-                shape=[self.out_features],
-                attr=paddle.nn.initializer.Constant(value=0.0),
-                dtype=self._dtype,
-                is_bias=True)
-        else:
-            self.bias = None
-
-        self.linear = F.linear
-
-        if fuse_matmul_bias:
-            if not is_fused_matmul_bias_supported():
-                raise NotImplementedError(
-                    "You set fuse_matmul_bias=True in RowParallelLinear, "
-                    "however, the paddle you are using not support this operation. "
-                    "Please set fuse_matmul_bias=False or use paddle compiled "
-                    "with cuda 11.6 or higher.")
-            from paddle.incubate.nn.functional import fused_linear
-            self.linear = fused_linear
-
-    def forward(self, x):
-        if self.input_is_parallel or (not self.is_mp):
-            input_parallel = x
-        else:
-            # split last dim
-            input_parallel = paddle.distributed.collective._c_split(
-                x, group=self.model_parallel_group)
-
-        if self.is_mp:
-            output_parallel = self.linear(input_parallel,
-                                          self.weight,
-                                          name=self._name)
-            output_ = paddle.distributed.collective._mp_allreduce(
-                output_parallel,
-                group=self.model_parallel_group,
-                use_calc_stream=True,
-                use_model_parallel=True)
-            output = output_ + self.bias if self.bias is not None else output_
-        else:
-            output = self.linear(input_parallel,
-                                 self.weight,
-                                 self.bias,
-                                 name=self._name)
-
-        return output
-
-
-class ParallelCrossEntropy(Layer):
-
-    def __init__(self, mp_group=None, name=None):
-        super(ParallelCrossEntropy, self).__init__()
-        self.name = name
-        self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
-        ) if mp_group is None else mp_group
-        self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
-        ) if mp_group is None else mp_group.nranks
-        self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank(
-        ) if mp_group is None else mp_group.rank
-
-    def forward(self, input, label):
-        loss = paddle.distributed.collective._c_softmax_with_cross_entropy(
-            input, label, group=self.model_parallel_group)
-        return loss
diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/random.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/random.py
index 900c0f79798..9deed30db66 100644
--- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/random.py
+++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/random.py
@@ -1,4 +1,4 @@
-#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,232 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import paddle
-import contextlib
-import numpy as np
-from paddle import _C_ops, _legacy_C_ops
-from paddle.fluid import core
-from paddle.fluid.data_feeder import check_variable_and_dtype
-from paddle.fluid.framework import _non_static_mode, default_main_program, Variable
-from paddle.fluid.layer_helper import LayerHelper
+from ...layers.mpu.random import RNGStatesTracker  # noqa: F401
+from ...layers.mpu.random import get_rng_state_tracker  # noqa: F401
+from ...layers.mpu.random import model_parallel_random_seed  # noqa: F401
+from ...layers.mpu.random import determinate_seed  # noqa: F401
+from ...layers.mpu.random import dropout  # noqa: F401
 
 __all__ = []
-
-MODEL_PARALLEL_RNG = 'model_parallel_rng'
-
-# This file is inspired by Megatron to control random states for MP:
-# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/mpu/random.py
-
-
-class RNGStatesTracker:
-    """
-    Tracker the RNG states.
-    """
-
-    def __init__(self):
-        # Map from name to the rng state.
-        self.states_ = {}
-        self.seeds_ = set()
-
-    def reset(self):
-        self.states_ = {}
-        self.seeds_ = set()
-
-    def add(self, name, seed):
-        if seed in self.seeds_:
-            raise ValueError('seed {} already exists'.format(seed))
-        self.seeds_.add(seed)
-        if name in self.states_:
-            raise ValueError('state {} already exists'.format(name))
-        orig_rng_state = paddle.get_cuda_rng_state()
-        paddle.seed(seed)
-        self.states_[name] = paddle.get_cuda_rng_state()
-        paddle.set_cuda_rng_state(orig_rng_state)
-
-    def get_states_tracker(self):
-        states = {}
-        for name in self.states_:
-            states[name] = self.states_[name]
-        return states
-
-    def set_states_tracker(self, states):
-        self.states_ = states
-
-    @contextlib.contextmanager
-    def rng_state(self, name=MODEL_PARALLEL_RNG):
-        if name not in self.states_:
-            raise ValueError('state {} does not exist'.format(name))
-        orig_cuda_rng_state = paddle.get_cuda_rng_state()
-        paddle.set_cuda_rng_state(self.states_[name])
-        try:
-            yield
-        finally:
-            self.states_[name] = paddle.get_cuda_rng_state()
-            paddle.set_cuda_rng_state(orig_cuda_rng_state)
-
-
-RNG_STATE_TRACKER = RNGStatesTracker()
-
-
-def get_rng_state_tracker():
-    return RNG_STATE_TRACKER
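
Below is a minimal sketch of how the relocated tracker keeps an isolated CUDA RNG stream per name; the state name ``'mp_seed'`` and the seed value are illustrative assumptions:

.. code-block:: python

    # required: gpu
    import paddle
    from paddle.distributed.fleet.layers.mpu.random import get_rng_state_tracker

    tracker = get_rng_state_tracker()
    tracker.add('mp_seed', 2022)           # register a named CUDA RNG state

    with tracker.rng_state('mp_seed'):     # temporarily switch to that state
        noise = paddle.rand([2, 3])        # drawn from the 'mp_seed' stream
    # on exit, the default CUDA RNG state is restored
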
-
-
-def model_parallel_random_seed(seed=None):
-    import paddle.distributed.fleet as fleet
-    hcg = fleet.get_hybrid_communicate_group()
-    rank = hcg.get_model_parallel_rank()
-
-    if seed:
-        global_seed = seed
-        local_seed = seed * 1024 + rank * 100
-    else:
-        global_seed = np.random.randint(0, 655350)
-        local_seed = np.random.randint(rank * 10000, (rank + 1) * 10000 - 1)
-
-    RNG_STATE_TRACKER.reset()
-    RNG_STATE_TRACKER.add(MODEL_PARALLEL_RNG, local_seed)
-    paddle.seed(global_seed)
-
-
-def determinate_seed(rng_name):
-    assert rng_name is not None and rng_name != ""
-    helper = LayerHelper('seed', **locals())
-    out = helper.create_variable_for_type_inference(dtype=paddle.int32)
-    # set force_cpu to avoid the CPU->GPU->CPU sync copy and reduce pipeline hangs
-    helper.append_op(type='seed',
-                     outputs={'Out': out},
-                     attrs={
-                         'deterministic': True,
-                         'rng_name': rng_name,
-                         'force_cpu': True
-                     })
-    return out
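
A static-graph sketch of the deterministic seed op above; the program setup and the rng_name string are illustrative assumptions:

.. code-block:: python

    import paddle
    from paddle.distributed.fleet.layers.mpu.random import determinate_seed

    paddle.enable_static()
    main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        # appends a CPU-side `seed` op whose int32 output later feeds dropout
        seed_var = determinate_seed('mp_dropout_rng')
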
-
-
-def dropout(x,
-            p=0.5,
-            axis=None,
-            rng_name=None,
-            training=True,
-            mode="upscale_in_train",
-            name=None):
-    """
-    Dropout is a regularization technique for reducing overfitting by preventing
-    neuron co-adaptation during training. The dropout operator randomly sets the
-    outputs of some units to zero, while upscaling the others according to the given
-    dropout probability.
-
-    Args:
-        x (Tensor): The input tensor. The data type is float16, float32 or float64.
-        p (float|int): Probability of setting units to zero. Default 0.5.
-        axis (int|list|tuple): The axis along which the dropout is performed. Default None.
-        rng_name (str): The name of the random seed generator, which is used to obtain deterministic results.
-        training (bool): A flag indicating whether it is in the training phase or not. Default True.
-        mode (str): ['upscale_in_train' (default) | 'downscale_in_infer'].
-
-                           1. upscale_in_train(default), upscale the output at training time
-
-                              - train: out = input * mask / ( 1.0 - dropout_prob )
-                              - inference: out = input
-
-                           2. downscale_in_infer, downscale the output at inference
-
-                              - train: out = input * mask
-                              - inference: out = input * (1.0 - dropout_prob)
-        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
-
-    Returns:
-        A Tensor representing the dropout result, with the same shape and data type as `x`.
-
-
-    Examples:
-        We use ``p=0.5`` in the following description for simplicity.
-
-        1. When ``axis=None`` , this is the commonly used dropout, which drops each element of x randomly.
-
-        ..  code-block:: text
-
-            Let's see a simple case when x is a 2d tensor with shape 2*3:
-            [[1 2 3]
-             [4 5 6]]
-            we generate mask with the same shape as x, which is 2*3. The value of mask is
-            sampled from a Bernoulli distribution randomly. For example, we may get such mask:
-            [[0 1 0]
-             [1 0 1]]
-            So the output is obtained from elementwise multiply of x and mask:
-            [[0 2 0]
-             [4 0 6]]
-            Using default setting, i.e. ``mode='upscale_in_train'`` ,
-            if in training phase, the final upscale output is:
-            [[0 4 0 ]
-             [8 0 12]]
-            if in test phase, the output is the same as input:
-            [[1 2 3]
-             [4 5 6]]
-            we can also set ``mode='downscale_in_infer'`` , then
-            if in training phase, the final output is:
-            [[0 2 0]
-             [4 0 6]]
-            if in test phase, the scale output is:
-            [[0.5 1.  1.5]
-             [2.  2.5 3. ]]
-
-    """
-    if rng_name is None:
-        return paddle.nn.functional.dropout(x, p, axis, training, mode, name)
-
-    if not isinstance(p, (float, int, Variable)):
-        raise TypeError("p argument should be a number(int|float) or Variable")
-
-    # fast return for p == 0
-    if isinstance(p, (int, float)) and p == 0:
-        return x
-
-    assert 0 <= p <= 1, ValueError("p argument should be between 0 and 1")
-    assert mode in ('downscale_in_infer', 'upscale_in_train'), \
-        ValueError(
-            "mode argument should be 'downscale_in_infer' or 'upscale_in_train'")
-
-    assert axis is None, \
-        TypeError("axis is not supported when using the random seed generator")
-
-    mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode  # map to the legacy attribute name expected by the op
-
-    # dygraph using tracker, doesn't need determinate seed
-    if _non_static_mode():
-        out, mask = _legacy_C_ops.dropout(x, 'dropout_prob', p, 'is_test',
-                                          not training, 'fix_seed', False,
-                                          'seed', 0, 'dropout_implementation',
-                                          mode)
-        return out
-
-    seed = determinate_seed(rng_name)
-
-    if isinstance(p, Variable) and p.shape != [1]:
-        raise TypeError(
-            "Required p.shape == [1] if type(p) is Variable, but received p.shape = {}"
-            .format(p.shape))
-
-    helper = LayerHelper('dropout', **locals())
-    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
-                             'dropout')
-
-    out = helper.create_variable_for_type_inference(dtype=x.dtype)
-    mask = helper.create_variable_for_type_inference(
-        dtype=core.VarDesc.VarType.UINT8, stop_gradient=True)
-
-    helper.append_op(type='dropout',
-                     inputs={
-                         'X': [x],
-                         'Seed': seed
-                     },
-                     outputs={
-                         'Out': [out],
-                         'Mask': [mask]
-                     },
-                     attrs={
-                         'dropout_prob': p,
-                         'is_test': not training,
-                         'dropout_implementation': mode,
-                     })
-    return out
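
Finally, a minimal dygraph sketch of the seeded dropout above; the names and seed are illustrative, and in dygraph the determinism actually comes from the tracker context rather than from ``rng_name`` itself:

.. code-block:: python

    # required: gpu
    import paddle
    from paddle.distributed.fleet.layers.mpu.random import dropout, get_rng_state_tracker

    x = paddle.rand([4, 8])

    # without rng_name this behaves like paddle.nn.functional.dropout
    y = dropout(x, p=0.1)

    # with rng_name in dygraph, randomness follows the active CUDA RNG state,
    # so the call is usually wrapped in the tracker shown earlier
    tracker = get_rng_state_tracker()
    tracker.add('mp_dropout', 2022)        # illustrative name / seed
    with tracker.rng_state('mp_dropout'):
        y_mp = dropout(x, p=0.1, rng_name='mp_dropout', training=True)
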
diff --git a/python/setup.py.in b/python/setup.py.in
index 3d400881de3..04ad7bf0388 100755
--- a/python/setup.py.in
+++ b/python/setup.py.in
@@ -307,6 +307,8 @@ packages=['paddle',
           'paddle.distributed.fleet.metrics',
           'paddle.distributed.fleet.proto',
           'paddle.distributed.fleet.utils',
+          'paddle.distributed.fleet.layers',
+          'paddle.distributed.fleet.layers.mpu',
           'paddle.distributed.fleet.meta_parallel',
           'paddle.distributed.fleet.meta_parallel.pp_utils',
           'paddle.distributed.fleet.meta_parallel.sharding',
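
With the two packages registered above, the new module path is importable from an installed wheel while the old paths keep working through the re-export shims; a quick smoke test (module paths as introduced by this patch):

.. code-block:: python

    # new canonical location
    from paddle.distributed.fleet.layers.mpu import mp_layers, mp_ops, random

    # old locations still resolve via the thin compatibility re-exports
    from paddle.distributed.fleet.meta_parallel.parallel_layers import random as legacy_random
    from paddle.distributed.collective import ReduceOp, split
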
-- 
GitLab