moe_utils.py 12.0 KB
Newer Older
R
Roc 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle import _legacy_C_ops
16 17 18
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.fluid.framework import _non_static_mode
from paddle.fluid.layer_helper import LayerHelper
R
Roc 已提交
19 20


21 22 23
def global_scatter(
    x, local_count, global_count, group=None, use_calc_stream=True
):
R
Roc 已提交
24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
    """
    The global_scatter operator distributes the data of x to n_expert * world_size experts according to local_count,
    and then receives data according to global_count. The expert refers to a user-defined expert network,
    n_expert refers to the number of expert networks owned by each card, and world_size refers to the number of graphics cards running the network.

    As shown below, the value of the world size is 2, n_expert 2, the batch size of the x 4 and local_count is [2, 0, 2, 0].
    The global_count of the rank 0 is [2, 0, , ], rank 1 is [2, 0, ,](Due to the limited space, only the data calculated on rank 0 is shown here).
    In the global_scatter operator, local_count[i] represents sending local_count[i] data to the (i % n_expert)th expert of the (i // n_expert)th card,
    global_count[i] represents receiving global_count[i] data from the (i // n_expert)th card to the (i % n_expert)th expert of this card. The rank in the
    figure respresent the rank of the current card in all cards.

    The process of global_scatter sending data is as follows:

    local_count[0] represents taking out 2 batches from x and sending 2 batches to the 0th expert of the 0th card;

    local_count[1] represents taking out 0 batches from x and sending 0 batches to the 1th expert of the 0th card;

    local_count[2] represents taking out 2 batches from x and sending 2 batches to the 0th expert of the 1th card;

    local_count[3] represents taking out 0 batches from x and sending 0 batches to the 1th expert of the 1th card;

    Therefore, the global_count[0] of the 0th card is equal to 2, which means that 2 batches of data are received from the 0th card to the 0th expert;

    the global_count[1] of the 0th card is equal to 0, which means that 0 batches of data are received from the 0th card to the 1th expert;

    the global_count[0] of the 1th card is equal to 2, which means that 2 batches of data are received from the 0th card to the 0th expert;

    the global_count[1] of the 1th card is equal to 0, which means that 0 batches of data are received from the 0th card to the 1th expert.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/global_scatter_gather.png
        :width: 800
        :alt: global_scatter_gather
        :align: center

    Args:
        x (Tensor): Tensor. The tensor data type should be float16, float32, float64, int32 or int64.
        local_count (Tensor): Tensor which have n_expert * world_size elements that indicates
            how many data needed to be sent. The tensor data type should be int64.
        global_count (Tensor): Tensor which have n_expert * world_size elements that indicates
            how many data needed to be received. The tensor data type should be int64.
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
        use_calc_stream (bool, optional): Wether to use calculation stream (True) or communication stream. Default: True.

    Returns:
        out (Tensor): The data received from all experts.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            from paddle.distributed import init_parallel_env
            init_parallel_env()
            n_expert = 2
            world_size = 2
            d_model = 2
            in_feat = d_model
81 82
            local_input_buf = paddle.to_tensor([[1, 2],[3, 4],[5, 6],[7, 8],[9, 10]], \
                                            dtype='float32', stop_gradient=False)
R
Roc 已提交
83
            if paddle.distributed.ParallelEnv().local_rank == 0:
84 85
                local_count = paddle.to_tensor([2, 1, 1, 1], dtype="int64")
                global_count = paddle.to_tensor([2, 1, 1, 1], dtype="int64")
R
Roc 已提交
86
            else:
87 88
                local_count = paddle.to_tensor([1, 1, 2, 1], dtype="int64")
                global_count = paddle.to_tensor([1, 1, 2, 1], dtype="int64")
R
Roc 已提交
89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106
            a = paddle.distributed.utils.global_scatter(local_input_buf, \
            local_count, global_count)
            a.stop_gradient = False
            print(a)
            # out for rank 0: [[1, 2], [3, 4], [1, 2], [5, 6], [3, 4]]
            # out for rank 1: [[7, 8], [5, 6], [7, 8], [9, 10], [9, 10]]
            # backward test
            c = a * a
            c.backward()
            print("local_input_buf.grad: ", local_input_buf.grad)
            # out for rank 0: [[2, 4], [6, 8], [10, 12], [14, 16], [18, 20]]
            # out for rank 1: [[2, 4], [6, 8], [10, 12], [14, 16], [18, 20]]
    """
    if group is not None and not group.is_member():
        return

    ring_id = 0 if group is None else group.id
    if _non_static_mode():
107 108 109 110 111 112 113 114 115
        return _legacy_C_ops.global_scatter(
            x,
            local_count,
            global_count,
            'use_calc_stream',
            use_calc_stream,
            'ring_id',
            ring_id,
        )
R
Roc 已提交
116 117 118
    else:
        op_type = 'global_scatter'
        check_variable_and_dtype(
119 120 121 122 123 124 125 126 127 128 129
            x,
            'x',
            ['float16', 'float32', 'float64', 'int32', 'int64'],
            'global_scatter',
        )
        check_variable_and_dtype(
            local_count, 'local_count', ['int64'], 'global_scatter'
        )
        check_variable_and_dtype(
            global_count, 'global_count', ['int64'], 'global_scatter'
        )
R
Roc 已提交
130 131 132 133

        helper = LayerHelper(op_type, **locals())
        out = helper.create_variable_for_type_inference(dtype=x.dtype)

134 135 136 137 138 139 140 141 142 143
        helper.append_op(
            type=op_type,
            inputs={
                'X': [x],
                'local_count': [local_count],
                'global_count': [global_count],
            },
            outputs={'Out': [out]},
            attrs={'ring_id': ring_id, 'use_calc_stream': use_calc_stream},
        )
R
Roc 已提交
144 145 146
        return out


147 148 149
def global_gather(
    x, local_count, global_count, group=None, use_calc_stream=True
):
R
Roc 已提交
150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198
    """
    The global_gather operator gathers the data of x into n_expert * world_size experts according to global_count, and then receives data according to local_count.
    The expert refers to a user-defined expert network, n_expert refers to the number of expert networks owned by each card, and world_size refers to the number of graphics cards running the network.

    As shown below, the value of the world size is 2, n_expert 2, the batch size of the x 4 and local_count is [2, 0, 2, 0].
    The global_count of the rank 0 is [2, 0, , ], rank 1 is [2, 0, ,](Due to the limited space, only the data calculated on rank 0 is shown here).
    In the global_gather operator, the meaning of the global_count and local_count is opposed to global_scatter, global_count[i] represents sending global_count[i] data to the (i % n_expert)th expert of the (i // n_expert)th card,
    local_count[i] represents receiving local_count[i] data from the (i // n_expert)th card to the (i % n_expert)th expert of this card. The data sent will be arranged according to the experts of each card.
    The rank in the figure respresent the rank of the current card in all cards.

    The process of global_gather sending data is as follows:

    The global_count[0] of the 0th card represents sending 2 data to the 0th expert of the 0th card;

    The global_count[1] of the 0th card represents sending 0 data to the 1th expert of the 0th card;

    The global_count[0] of the 1th card represents sending 2 data to the 0th expert of the 0th card;

    The global_count[1] of the 1th card represents sending 0 data to the 1th expert of the 0th card.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/global_scatter_gather.png
        :width: 800
        :alt: global_scatter_gather
        :align: center


    Args:
        x (Tensor): Tensor. Tensor whose data type should be float16, float32, float64, int32 or int64.
        local_count (Tensor): Tensor which have n_expert * world_size elements that indicates
            how many data needed to be received. Tensor data type should be int64.
        global_count (Tensor): Tensor which have n_expert * world_size elements that indicates
            how many data needed to be sent. Tensor data type should be int64.
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
        use_calc_stream (bool, optional): Wether to use calculation stream (True) or communication stream. Default: True.

    Returns:
        out (Tensor): The data received from all experts.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            from paddle.distributed import init_parallel_env
            init_parallel_env()
            n_expert = 2
            world_size = 2
            d_model = 2
            in_feat = d_model
199 200
            local_input_buf = paddle._to_tensor([[1, 2],[3, 4],[5, 6],[7, 8],[9, 10]],\
                                        dtype='float32', stop_gradient=False)
R
Roc 已提交
201
            if paddle.distributed.ParallelEnv().local_rank == 0:
202 203
                local_count = paddle.to_tensor([2, 1, 1, 1], dtype="int64")
                global_count = paddle.to_tensor([2, 1, 1, 1], dtype="int64")
R
Roc 已提交
204
            else:
205 206 207
                local_count = paddle.to_tensor([1, 1, 2, 1], dtype="int64")
                global_count = paddle.to_tensor([1, 1, 2, 1], dtype="int64")

R
Roc 已提交
208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223
            a = paddle.distributed.utils.global_gather(local_input_buf, local_count, global_count)
            print(a)
            # out for rank 0: [[1, 2], [3, 4], [7, 8], [1, 2], [7, 8]]
            # out for rank 1: [[5, 6], [9, 10], [3, 4], [5, 6], [9, 10]]
            a.stop_gradient = False
            c = a * a
            c.backward()
            print("local_input_buf.grad", local_input_buf.grad)
            # out for rank 0: [[2, 4], [6, 8], [10, 12], [14, 16], [18, 20]]
            # out for rank 1: [[2, 4], [6, 8], [10, 12], [14, 16], [18, 20]]
    """
    if group is not None and not group.is_member():
        return

    ring_id = 0 if group is None else group.id
    if _non_static_mode():
224 225 226 227 228 229 230 231 232
        return _legacy_C_ops.global_gather(
            x,
            local_count,
            global_count,
            'use_calc_stream',
            use_calc_stream,
            'ring_id',
            ring_id,
        )
R
Roc 已提交
233 234 235
    else:
        op_type = 'global_gather'
        check_variable_and_dtype(
236 237 238 239 240
            x,
            'x',
            ['float16', 'float32', 'float64', 'int32', 'int64'],
            'global_gather',
        )
R
Roc 已提交
241

242 243 244
        check_variable_and_dtype(
            local_count, 'local_count', ['int64'], 'global_gather'
        )
R
Roc 已提交
245

246 247 248
        check_variable_and_dtype(
            global_count, 'global_count', ['int64'], 'global_gather'
        )
R
Roc 已提交
249 250 251
        helper = LayerHelper(op_type, **locals())
        out = helper.create_variable_for_type_inference(dtype=x.dtype)

252 253 254 255 256 257 258 259 260 261 262 263 264
        helper.append_op(
            type=op_type,
            inputs={
                'X': [x],
                'local_count': [local_count],
                'global_count': [global_count],
            },
            outputs={'Out': [out]},
            attrs={
                'ring_id': group,
                'use_calc_stream': use_calc_stream,
            },
        )
R
Roc 已提交
265
        return out