# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# The file has been adapted from the file:
#     https://github.com/laekov/fastmoe/blob/master/fmoe/layers.py
#     Git commit hash: 295a615aacce7e54a37e7935274ba15e901c78e4
# We retain the following license from the original files:
#     Copyright 2021, Jiaao He. All rights reserved.
#   Licensed under the Apache License, Version 2.0 (the "License").

import numpy as np

import paddle
from paddle import nn
from paddle.autograd import PyLayer
from paddle.distributed.utils.moe_utils import global_gather, global_scatter
from paddle.distributed.utils.nccl_utils import check_nccl_version_for_p2p
from paddle.framework import in_dygraph_mode
from paddle.incubate.distributed.fleet import recompute_hybrid

from .gate import BaseGate, GShardGate, NaiveGate, SwitchGate
from .utils import count_by_gate


def _local_scatter(inp, pos):
    if pos.shape != [0]:
        inp_buf = paddle.index_select(inp, pos, 0)
    else:
        inp_buf = paddle.empty([0, inp.shape[1]], dtype=inp.dtype)
    return inp_buf


def _local_gather(inp, pos, out_batch_size, maybe_overlap=True):
    if pos.shape != [0]:
        origin_dtype = inp.dtype
        inp = paddle.cast(inp, dtype="float32")
        inp_buf = paddle.scatter(
            paddle.zeros(
                shape=[out_batch_size, inp.shape[-1]], dtype="float32"
            ),
            pos,
            inp,
            overwrite=True,
        )
        inp_buf = paddle.cast(inp_buf, dtype=origin_dtype)
    else:
        inp_buf = paddle.zeros([out_batch_size, inp.shape[-1]], dtype=inp.dtype)
    return inp_buf
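
# A minimal sketch of how the two helpers above invert each other
# (illustrative values only, not taken from the original file):
#
#     inp = paddle.to_tensor([[1., 1.], [2., 2.], [3., 3.]])
#     pos = paddle.to_tensor([2, 0])                   # rows routed to an expert
#     buf = _local_scatter(inp, pos)                   # [[3., 3.], [1., 1.]]
#     out = _local_gather(buf, pos, out_batch_size=3)  # [[1., 1.], [0., 0.], [3., 3.]]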


def _all_gather(tensor, group=None, use_calc_stream=True):
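    # Dygraph mode uses the process-group all_gather API; other modes fall
    # back to the legacy c_allgather op.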
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        group = (
            paddle.distributed.collective._get_default_group()
            if group is None
            else group
        )
        tensor_shape = list(tensor.shape)
        tensor_shape[0] *= group.nranks
        out = paddle.empty(tensor_shape, tensor.dtype)

        task = group.process_group.all_gather(tensor, out)
        task.wait()
        return out
    else:
        ring_id = 0 if group is None else group.id
        nranks = (
            paddle.distributed.collective._get_global_group().nranks
            if group is None
            else group.nranks
        )
        return paddle._legacy_C_ops.c_allgather(
            tensor,
            'use_calc_stream',
            use_calc_stream,
            'ring_id',
            ring_id,
            'nranks',
            nranks,
        )


class MoEScatter(PyLayer):
    r"""
    Scatter input samples from [batch x sequences] to be contiguous along experts.
    If `world_size` is greater than 1, the samples will first be locally
    scattered, and then exchanged across workers.
    """

    @staticmethod
    def forward(
        ctx,
        inp,
        pos,
        local_expert_count,
        global_expert_count,
        fwd_batch_size,
        world_size,
        group=None,
    ):
        local_input_buf = _local_scatter(inp, pos)
        if world_size > 1:
            global_input_buf = global_scatter(
                local_input_buf,
                local_expert_count,
                global_expert_count,
                group=group,
            )
        else:
            global_input_buf = local_input_buf

        ctx.moe_args = inp.shape[0], world_size, group

        variables = (pos, local_expert_count, global_expert_count)
        ctx.save_for_backward(*variables)
        return global_input_buf

    @staticmethod
    def backward(ctx, grad):
        (pos, local_expert_count, global_expert_count) = ctx.saved_tensor()
        (inp_batch_size, world_size, group) = ctx.moe_args

        if world_size > 1:
            local_grad_in = global_gather(
                grad, local_expert_count, global_expert_count, group=group
            )
        else:
            local_grad_in = grad
        grad_in = _local_gather(local_grad_in, pos, inp_batch_size)
        return grad_in, None, None, None


class MoEGather(PyLayer):
    r"""
    Gather output samples from contiguous alone experts back to [batch x
R
Roc 已提交
150
    sequences]. Works symmetrically with MoEScatter.
R
Roc 已提交
151 152 153
    """

    @staticmethod
    def forward(
        ctx,
        global_output_buf,
        pos,
        local_expert_count,
        global_expert_count,
        local_batch_size,
        world_size,
        group=None,
    ):
        if world_size > 1:
            local_output_buf = global_gather(
                global_output_buf,
                local_expert_count,
                global_expert_count,
                group=group,
            )
        else:
            local_output_buf = global_output_buf
        output = _local_gather(
            local_output_buf, pos, local_batch_size, maybe_overlap=False
        )

        ctx.moe_args = (global_output_buf.shape[0], world_size, group)
        variables = (pos, local_expert_count, global_expert_count)
        ctx.save_for_backward(*variables)
        return output

    @staticmethod
    def backward(ctx, grad_out):
        pos, local_expert_count, global_expert_count = ctx.saved_tensor()
        fwd_batch_size, world_size, group = ctx.moe_args
        grad_out_buf = _local_scatter(grad_out, pos)
        if world_size > 1:
            global_grad_out_buf = global_scatter(
                grad_out_buf,
                local_expert_count,
                global_expert_count,
                group=group,
            )
        else:
            global_grad_out_buf = grad_out_buf
        return global_grad_out_buf, None, None, None


class AllGather(PyLayer):
    r"""
    A wrapper for the All-Gather function to support auto-differentiation.
    """

    @staticmethod
    def forward(ctx, inp, rank, world_size, group):
        tensor_list = []
        paddle.distributed.all_gather(tensor_list, inp, group=group)
        output = paddle.concat(tensor_list, axis=0)
        ctx.args = rank, inp.shape[0]
        return output

    @staticmethod
    def backward(ctx, grad_out):
        rank, dim0 = ctx.args
        return paddle.slice(
            grad_out, axes=[0], starts=[rank * dim0], ends=[(rank + 1) * dim0]
        )


class Slice(PyLayer):
    r"""
    A wrapper for the Slice function to support auto-differentiation.
    """

    @staticmethod
    def forward(ctx, inp, rank, world_size, group):
        B = inp.shape[0]
        local_batch_size = B // world_size
        batch_start = local_batch_size * rank
        batch_end = min(batch_start + local_batch_size, B)
        inp = paddle.slice(
            inp, axes=[0], starts=[batch_start], ends=[batch_end]
        )
        ctx.args = world_size, group
        return inp

    @staticmethod
    def backward(ctx, grad_out):
        world_size, group = ctx.args
        return _all_gather(grad_out, group=group)


def prepare_forward(gate, num_expert, world_size, moe_group):
    pos, local_expert_count, global_expert_count = count_by_gate(
        gate, num_expert, world_size, group=moe_group
    )
    with paddle.no_grad():
        fwd_expert_count = global_expert_count.reshape_(
            [world_size, num_expert]
        ).sum(axis=0)
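        # Illustrative example (hypothetical counts, world_size=2, num_expert=2):
        #   global_expert_count = [3, 1, 2, 4]  -> reshaped to [[3, 1], [2, 4]]
        #   fwd_expert_count    = [5, 5]        (rows per local expert, summed over ranks)
        #   fwd_batch_size      = 10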
        fwd_batch_size = int(fwd_expert_count.sum().item())
    return (
        pos,
        local_expert_count,
        global_expert_count,
        fwd_expert_count,
        fwd_batch_size,
    )


class MoELayer(nn.Layer):
    """MoE Layer
    Args:
        d_model: (int) model dimension
        experts: (nn.LayerList) expert networks list
        gate: (dict|NaiveGate|SwitchGate|GShardGate):
                if gate is a dict:
                    gate is a gate network config, containing 2 keys:
                    `type`(str): one of "naive", "gshard", "switch" or None, default is "gshard"
                    `top_k`(int): default value is 2
                else gate is an instance of NaiveGate, GShardGate or SwitchGate.
        moe_group: moe group for experts communication
        mp_group: mp group for mp communication
        recompute_interval(int, optional): whether to use recompute; the default value 0 disables recompute.
        recompute_ctx(dict, optional): the context for recompute; if recompute_interval > 0, recompute_ctx must be given.
    Examples:
        .. code-block:: python

        from paddle import nn
        from paddle.nn import Layer, LayerList
        from paddle.distributed.moe import MoELayer
        from paddle.distributed.collective import Group
        from paddle.distributed import fleet

        moe_group = Group(fleet.worker_index(),
                          0,
                          list(range(fleet.worker_num())))
        mp_group = None

        num_experts=8
        dim_feedforward=512
        d_model=8
        top_k=2

        class ExpertLayer(Layer):
            def __init__(self, d_model, d_hidden, name=None, rank=0, windex=0, num_expert=1):
                super().__init__()
                self.htoh4 = nn.Linear(d_model, d_hidden)
                self.h4toh = nn.Linear(d_hidden, d_model)

            def forward(self, x):
                x = self.htoh4(x)
                x = self.h4toh(x)
                return x

        gate_config = {
                "type": "gshard",
                "top_k": top_k,
        }

        experts_list = LayerList()
        for expi in range(num_experts):
            exp_layer = ExpertLayer(d_model, dim_feedforward // top_k, windex=expi, num_expert=num_experts)
            experts_list.append(exp_layer)

        moeLayer = MoELayer(d_model=d_model,
                            experts=experts_list,
                            gate=gate_config,
                            moe_group=moe_group,
                            mp_group=mp_group,
                            recompute_interval=0)

    """

    def __init__(
        self,
        d_model,
        experts,
        gate=None,
        moe_group=None,
        mp_group=None,
        recompute_interval=0,
        recompute_ctx=None,
    ):
        super().__init__()

        self.recompute_ctx = recompute_ctx

        if gate is None:
            gate = {}

        assert isinstance(
            gate, (dict, BaseGate)
        ), "gate's type must be dict or an instance of BaseGate"
        # only support mp/dp
        self.group = moe_group

        self.world_size = 1
        if self.group is not None:
            self.world_size = self.group.nranks
        self.num_expert = len(experts)
        self.recompute_interval = recompute_interval
        assert experts is not None
        self.experts = experts

        if self.world_size > 1:
            check_nccl_version_for_p2p()

        self.mp_group = mp_group
        self.d_model = d_model
        if isinstance(gate, dict):
            self.top_k = gate.get("top_k", 2)
            gate = gate.get("type", "gshard")
            if gate == "naive" or gate is None:
                gate = NaiveGate(
                    self.d_model,
                    num_expert=len(experts),
                    world_size=self.world_size,
                    topk=self.top_k,
                )
            elif gate == "gshard":
                gate = GShardGate(
                    self.d_model,
                    num_expert=len(experts),
                    world_size=self.world_size,
                    topk=self.top_k,
                    group=self.group,
                )
            elif gate == "switch":
                gate = SwitchGate(
                    self.d_model,
                    num_expert=len(experts),
                    world_size=self.world_size,
                    topk=self.top_k,
                    group=self.group,
                )
            else:
                raise AssertionError(
                    "We only support naive gate, gshard gate and switch "
                    "gate, but you choose {} gate.".format(str(gate))
                )
        elif isinstance(gate, NaiveGate):
            self.top_k = gate.top_k
        elif isinstance(gate, BaseGate):
            raise TypeError("Unimplemented gate type: ", type(gate))
        else:
            raise TypeError("gate's type must be either dict or moe.BaseGate")
        self.gate = gate

    def forward(self, inp):
        # inp shape: b * s * m
        assert len(inp.shape) == 3
        origin_shape = inp.shape
        inp = inp.reshape_([-1, origin_shape[2]])

        mp_rank = 0
        mp_size = 1
        if self.mp_group is not None:
            mp_rank = self.mp_group.rank
            mp_size = self.mp_group.nranks
        if mp_size > 1:
            inp = Slice.apply(inp, mp_rank, mp_size, self.mp_group)
        value, gate = self.gate(inp)

        (
            pos,
            local_expert_count,
            global_expert_count,
            fwd_expert_count,
            fwd_batch_size,
        ) = prepare_forward(gate, self.num_expert, self.world_size, self.group)

        topk = 1
        if len(gate.shape) == 2:
            topk = gate.shape[1]

        if pos.shape != [0]:
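            # `pos` indexes the flattened [token, top_k] routing slots produced
            # by the gate; integer division by topk maps each slot back to the
            # row of `inp` it came from.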
            temp_pos = pos // topk
        else:
            temp_pos = pos
        assert topk == self.top_k

        x = MoEScatter.apply(
            inp,
            temp_pos,
            local_expert_count,
            global_expert_count,
            fwd_batch_size,
            self.world_size,
            self.group,
        )

        d_model = self.d_model

        def experts_fwd(x, fwd_expert_count, experts):
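            # Run each local expert on its contiguous slice of the dispatched
            # rows: experts[i] consumes the next fwd_expert_count[i] rows.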

            if x.shape[0] == 0:
                return x
            y = []
            last_index = 0
            assert isinstance(fwd_expert_count, np.ndarray)
            assert len(experts) == len(fwd_expert_count)
            for idx, expert_count in enumerate(fwd_expert_count):
                if expert_count <= 0:
                    continue
                y.append(
                    experts[idx](x[last_index : expert_count + last_index])
                )
                last_index = expert_count + last_index
            return paddle.concat(y, axis=0)

        if self.recompute_interval <= 0 or x.shape[0] == 0:
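            # Recompute disabled (or no rows routed to this rank): run the
            # experts directly.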
            x = experts_fwd(x, fwd_expert_count.numpy(), self.experts)
        else:
            x = recompute_hybrid(
                self.recompute_ctx,
                experts_fwd,
                x,
                fwd_expert_count.numpy(),
                self.experts,
            )

        out_batch_size = inp.shape[0]
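        # Every token was routed to top_k experts, so the gather target holds
        # batch_size * top_k rows before the weighted combine below.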
        if len(gate.shape) == 2:
            out_batch_size *= gate.shape[1]

        x = MoEGather.apply(
            x,
            pos,
            local_expert_count,
            global_expert_count,
            out_batch_size,
            self.world_size,
            self.group,
        )
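        # Combine the top_k expert outputs per token using the gate scores:
        # value is reshaped to [tokens, 1, top_k] and x to [tokens, top_k, d_model],
        # so the batched matmul below yields [tokens, 1, d_model].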

        x = x.reshape([-1, self.top_k, d_model])
        value = value.reshape([x.shape[0], 1, self.top_k])
        x = paddle.bmm(value, x).reshape([-1, d_model])

        if mp_size > 1:
            x = AllGather.apply(x, mp_rank, mp_size, self.mp_group)

        x = paddle.reshape_(x, origin_shape)

        return x