# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.tensor as tensor
from paddle.distributed.fleet import auto
from paddle.fluid import layers
from paddle.nn.layer.transformer import _convert_param_attr_to_list

paddle.enable_static()


def init_global():
    global _global_parallel_strategy
    _global_parallel_strategy = None
    global _global_process_mesh
    global PP_MESH_LIST
    global DPPP_MESH_LIST
    global MPPP_MESH_LIST
    global DPMPPP_MESH_LIST
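
# NOTE: The globals declared above control how the layers below are annotated
# for the auto-parallel pass (the graph is built in static mode, see
# paddle.enable_static() above). They are only declared here; the caller
# (typically the accompanying test script) assigns them after init_global():
#   _global_parallel_strategy: one of None, "dp", "mp", "pp", "dp_mp",
#       "dp_pp", "mp_pp" or "dp_mp_pp", selecting which shard annotations
#       are emitted.
#   _global_process_mesh: the process mesh used by the non-pipeline
#       strategies; for "dp" and "mp" a 1-D mesh whose axis is named "x",
#       for "dp_mp" a 2-D mesh with data-parallel axis "x" and
#       model-parallel axis "y".
#   PP_MESH_LIST / DPPP_MESH_LIST / MPPP_MESH_LIST / DPMPPP_MESH_LIST: one
#       process mesh per pipeline stage, indexed by each layer's mesh_idx.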


class MultiHeadAttention(nn.Layer):
    """
    Attention maps queries and a set of key-value pairs to outputs, and
    Multi-Head Attention performs multiple attention operations in parallel to
    jointly attend to information from different representation subspaces.
    """

    Cache = collections.namedtuple("Cache", ["k", "v"])
    StaticCache = collections.namedtuple("StaticCache", ["k", "v"])

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        kdim=None,
        vdim=None,
        need_weights=False,
        weight_attr=None,
        bias_attr=None,
        fuse=False,
        mesh_idx=None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.need_weights = need_weights
        self.fuse = fuse
        self.mesh_idx = mesh_idx
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        if self.fuse:
            assert self.kdim == embed_dim
            assert self.vdim == embed_dim
            self.qkv_proj = nn.Linear(
                embed_dim, 3 * embed_dim, weight_attr, bias_attr=bias_attr
            )
        else:
            self.q_proj = nn.Linear(
                embed_dim,
                embed_dim,
                weight_attr=weight_attr,
                bias_attr=bias_attr,
            )
            self.k_proj = nn.Linear(
                self.kdim,
                embed_dim,
                weight_attr=weight_attr,
                bias_attr=bias_attr,
            )
            self.v_proj = nn.Linear(
                self.vdim,
                embed_dim,
                weight_attr=weight_attr,
                bias_attr=bias_attr,
            )
        self.out_proj = nn.Linear(
            embed_dim, embed_dim, weight_attr=weight_attr, bias_attr=bias_attr
        )

    def _fuse_prepare_qkv(self, query):
        mix_layer = self.qkv_proj(query)
        mix_layer = paddle.reshape_(
            mix_layer, [0, 0, self.num_heads, 3 * self.head_dim]
        )
        mix_layer = paddle.transpose(mix_layer, [0, 2, 1, 3])
        q, k, v = paddle.split(mix_layer, num_or_sections=3, axis=-1)
        return q, k, v

    def _prepare_qkv(self, query, key, value, use_cache=False, cache=None):
        """
        Prepares linearly projected queries, keys and values for use by the
        subsequent multiple parallel attention. If `cache` is not None, cached
        results are used to reduce redundant calculations.
        """
        q = self.q_proj(query)
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(
                self.q_proj.weight, _global_process_mesh, [None, "x"]
            )
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(
                self.q_proj.weight, _global_process_mesh, [None, "y"]
            )
        elif _global_parallel_strategy == "mp_pp":
            auto.shard_tensor(
                self.q_proj.weight, MPPP_MESH_LIST[self.mesh_idx], [None, "x"]
            )
        elif _global_parallel_strategy == "dp_mp_pp":
            auto.shard_tensor(
                self.q_proj.weight, DPMPPP_MESH_LIST[self.mesh_idx], [None, "y"]
            )

        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q = tensor.transpose(x=q, perm=[0, 2, 1, 3])
        if isinstance(cache, self.StaticCache):
            # for encoder-decoder attention in inference with cached static k, v
            k, v = cache.k, cache.v
        else:
            k, v = self.compute_kv(key, value)
        if isinstance(cache, self.Cache):
            # for decoder self-attention in inference
            k = tensor.concat([cache.k, k], axis=2)
            v = tensor.concat([cache.v, v], axis=2)
        if use_cache is True:
            cache = self.Cache(k, v)
        return (q, k, v) if use_cache is False else (q, k, v, cache)

    def compute_kv(self, key, value):
        """
        Applies linear projection on input keys and values, then splits heads
        (reshape and transpose) to get keys and values from different representation
        subspaces. The results are used as key-value pairs for the subsequent
        multiple parallel attention.
        It is part of the calculation in multi-head attention, and is provided
        as a method to pre-compute and prefetch these results so that they can
        be used to construct the cache for inference.
        """
        k = self.k_proj(key)
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(
                self.k_proj.weight, _global_process_mesh, [None, "x"]
            )
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(
                self.k_proj.weight, _global_process_mesh, [None, "y"]
            )
        elif _global_parallel_strategy == "mp_pp":
            auto.shard_tensor(
                self.k_proj.weight, MPPP_MESH_LIST[self.mesh_idx], [None, "x"]
            )
        elif _global_parallel_strategy == "dp_mp_pp":
            auto.shard_tensor(
                self.k_proj.weight, DPMPPP_MESH_LIST[self.mesh_idx], [None, "y"]
            )
        v = self.v_proj(value)
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(
                self.v_proj.weight, _global_process_mesh, [None, "x"]
            )
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(
                self.v_proj.weight, _global_process_mesh, [None, "y"]
            )
        elif _global_parallel_strategy == "mp_pp":
            auto.shard_tensor(
                self.v_proj.weight, MPPP_MESH_LIST[self.mesh_idx], [None, "x"]
            )
        elif _global_parallel_strategy == "dp_mp_pp":
            auto.shard_tensor(
                self.v_proj.weight, DPMPPP_MESH_LIST[self.mesh_idx], [None, "y"]
            )
        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v = tensor.transpose(x=v, perm=[0, 2, 1, 3])
        return k, v

    def gen_cache(self, key, value=None, type=Cache):
        """
        Generates cache for `forward` usage in inference according to the arguments.
        The generated cache is an instance of `MultiHeadAttention.Cache` or an
        instance of `MultiHeadAttention.StaticCache`.
        """
        if type == MultiHeadAttention.StaticCache:  # static_kv
            k, v = self.compute_kv(key, value)
            return self.StaticCache(k, v)
        elif value is None:  # incremental_state
            k = layers.fill_constant_batch_size_like(
                input=key,
                shape=[-1, self.num_heads, 0, self.head_dim],
                dtype=key.dtype,
                value=0,
            )
            v = layers.fill_constant_batch_size_like(
                input=key,
                shape=[-1, self.num_heads, 0, self.head_dim],
                dtype=key.dtype,
                value=0,
            )
            return self.Cache(k, v)
        else:
            # incremental_state with initial value, mainly for usage like UniLM
            return self.Cache(key, value)

    def forward(
        self, query, key, value, attn_mask=None, use_cache=False, cache=None
    ):
        """
        Applies multi-head attention to map queries and a set of key-value pairs
        to outputs.
        """
        key = query if key is None else key
        value = query if value is None else value
        # compute q, k, v
        if use_cache is False:
            if self.fuse:
                q, k, v = self._fuse_prepare_qkv(query)
            else:
                q, k, v = self._prepare_qkv(query, key, value, use_cache, cache)
        else:
            q, k, v, cache = self._prepare_qkv(
                query, key, value, use_cache, cache
            )
        product = layers.matmul(
            x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5
        )
        if attn_mask is not None:
            product = product + attn_mask
        weights = F.softmax(product)
        if self.dropout:
            weights = F.dropout(
                weights,
                self.dropout,
                training=self.training,
                mode="upscale_in_train",
            )
        out = tensor.matmul(weights, v)
        # combine heads
        out = tensor.transpose(out, perm=[0, 2, 1, 3])
        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
        # project to output
        out = self.out_proj(out)
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(
                self.out_proj.weight, _global_process_mesh, ["x", None]
            )
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(
                self.out_proj.weight, _global_process_mesh, ["y", None]
            )
        elif _global_parallel_strategy == "mp_pp":
            auto.shard_tensor(
                self.out_proj.weight, MPPP_MESH_LIST[self.mesh_idx], ["x", None]
            )
        elif _global_parallel_strategy == "dp_mp_pp":
            auto.shard_tensor(
                self.out_proj.weight,
                DPMPPP_MESH_LIST[self.mesh_idx],
                ["y", None],
            )

        outs = [out]
        if self.need_weights:
            outs.append(weights)
        if use_cache:
            outs.append(cache)
        return out if len(outs) == 1 else tuple(outs)
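
# Rough usage sketch for MultiHeadAttention (illustrative only; the variable
# names and sizes below are assumptions, not part of this test model):
#
#     attn = MultiHeadAttention(embed_dim=1024, num_heads=16, dropout=0.1)
#     # query: [batch_size, seq_len, embed_dim]
#     out = attn(query, query, query, attn_mask=mask)
#
# Internally q/k/v are reshaped to [batch_size, num_heads, seq_len, head_dim],
# the attention weights are softmax(q @ k^T * head_dim**-0.5 + attn_mask), and
# the heads are concatenated back before out_proj maps to [.., .., embed_dim].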


class TransformerDecoder(nn.Layer):
    """
    TransformerDecoder is a stack of N decoder layers.
    """

    def __init__(self, decoder_layers, num_layers, norm=None, hidden_size=None):
        super().__init__()

        self.num_layers = num_layers
        self.layers = decoder_layers
        self.norm = norm
        if norm == "LayerNorm":
            self.norm = nn.LayerNorm(hidden_size)
        elif norm is not None:
            raise ValueError("Only support LayerNorm")
        self.checkpoints = []

    def forward(
        self,
        tgt,
        memory,
        tgt_mask=None,
        memory_mask=None,
        use_cache=False,
        cache=None,
    ):
        """
        Applies a stack of N Transformer decoder layers on inputs. If `norm` is
        provided, also applies layer normalization on the output of last decoder
        layer.
        """
        output = tgt
        new_caches = []
        self.checkpoints = []
        if _global_parallel_strategy == "pp":
            auto.shard_tensor(
                output,
                PP_MESH_LIST[0],
                [None for i in range(len(output.shape))],
            )
        if _global_parallel_strategy == "dp_pp":
            auto.shard_tensor(
                output,
                DPPP_MESH_LIST[0],
                ["x"] + [None for i in range(len(output.shape) - 1)],
            )
        if _global_parallel_strategy == "mp_pp":
            auto.shard_tensor(
                output,
                MPPP_MESH_LIST[0],
                [None for i in range(len(output.shape))],
            )
        if _global_parallel_strategy == "dp_mp_pp":
            auto.shard_tensor(
                output,
                DPMPPP_MESH_LIST[0],
                ["x"] + [None for i in range(len(output.shape) - 1)],
            )
        for i, mod in enumerate(self.layers):
            if cache is None:
                if use_cache:
                    if _global_parallel_strategy == "pp":
                        output, new_cache = auto.shard_op(
                            mod, PP_MESH_LIST[mod.mesh_idx]
                        )(output, memory, tgt_mask, use_cache, cache)
                        auto.shard_tensor(
                            output,
                            PP_MESH_LIST[mod.mesh_idx],
                            [None for i in range(len(output.shape))],
                        )
                    elif _global_parallel_strategy == "dp_pp":
                        output, new_cache = auto.shard_op(
                            mod, DPPP_MESH_LIST[mod.mesh_idx]
                        )(output, memory, tgt_mask, use_cache, cache)
                        auto.shard_tensor(
                            output,
                            DPPP_MESH_LIST[mod.mesh_idx],
                            ["x"]
                            + [None for i in range(len(output.shape) - 1)],
                        )
                    elif _global_parallel_strategy == "mp_pp":
                        output, new_cache = auto.shard_op(
                            mod, MPPP_MESH_LIST[mod.mesh_idx]
                        )(output, memory, tgt_mask, use_cache, cache)
                        auto.shard_tensor(
                            output,
                            MPPP_MESH_LIST[mod.mesh_idx],
                            [None for i in range(len(output.shape))],
                        )
                    elif _global_parallel_strategy == "dp_mp_pp":
                        output, new_cache = auto.shard_op(
                            mod, DPMPPP_MESH_LIST[mod.mesh_idx]
                        )(output, memory, tgt_mask, use_cache, cache)
                        auto.shard_tensor(
                            output,
                            DPMPPP_MESH_LIST[mod.mesh_idx],
                            [None for i in range(len(output.shape))],
                        )
                    else:
                        output, new_cache = mod(
                            output,
                            memory,
                            tgt_mask=tgt_mask,
                            use_cache=use_cache,
                            cache=cache,
                        )
                    new_caches.append(new_cache)
                else:
                    if _global_parallel_strategy == "pp":
                        output = auto.shard_op(mod, PP_MESH_LIST[mod.mesh_idx])(
                            output, memory, tgt_mask, use_cache, cache
                        )
                        auto.shard_tensor(
                            output,
                            PP_MESH_LIST[mod.mesh_idx],
                            [None for i in range(len(output.shape))],
                        )
                    elif _global_parallel_strategy == "dp_pp":
                        output = auto.shard_op(
                            mod, DPPP_MESH_LIST[mod.mesh_idx]
                        )(output, memory, tgt_mask, use_cache, cache)
                        auto.shard_tensor(
                            output,
                            DPPP_MESH_LIST[mod.mesh_idx],
                            ["x"]
                            + [None for i in range(len(output.shape) - 1)],
                        )
                    elif _global_parallel_strategy == "mp_pp":
                        output = auto.shard_op(
                            mod, MPPP_MESH_LIST[mod.mesh_idx]
                        )(output, memory, tgt_mask, use_cache, cache)
                        auto.shard_tensor(
                            output,
                            MPPP_MESH_LIST[mod.mesh_idx],
                            [None for i in range(len(output.shape))],
                        )
                    elif _global_parallel_strategy == "dp_mp_pp":
                        output = auto.shard_op(
                            mod, DPMPPP_MESH_LIST[mod.mesh_idx]
                        )(output, memory, tgt_mask, use_cache, cache)
                        auto.shard_tensor(
                            output,
                            DPMPPP_MESH_LIST[mod.mesh_idx],
                            ["x"]
                            + [None for i in range(len(output.shape) - 1)],
                        )
                    else:
                        output = mod(
                            output,
                            memory,
                            tgt_mask=tgt_mask,
                            use_cache=use_cache,
                            cache=cache,
                        )
            else:
                if _global_parallel_strategy == "pp":
                    output, new_cache = auto.shard_op(
                        mod, PP_MESH_LIST[mod.mesh_idx]
                    )(output, memory, tgt_mask, use_cache, cache)
                    auto.shard_tensor(
                        output,
                        PP_MESH_LIST[mod.mesh_idx],
                        [None for i in range(len(output.shape))],
                    )
                elif _global_parallel_strategy == "dp_pp":
                    output, new_cache = auto.shard_op(
                        mod, DPPP_MESH_LIST[mod.mesh_idx]
                    )(output, memory, tgt_mask, use_cache, cache)
                    auto.shard_tensor(
                        output,
                        DPPP_MESH_LIST[mod.mesh_idx],
                        ["x"] + [None for i in range(len(output.shape) - 1)],
                    )
                elif _global_parallel_strategy == "mp_pp":
                    output, new_cache = auto.shard_op(
                        mod, MPPP_MESH_LIST[mod.mesh_idx]
                    )(output, memory, tgt_mask, use_cache, cache)
                    auto.shard_tensor(
                        output,
                        MPPP_MESH_LIST[mod.mesh_idx],
                        [None for i in range(len(output.shape))],
                    )
                elif _global_parallel_strategy == "dp_mp_pp":
                    output, new_cache = auto.shard_op(
                        mod, DPMPPP_MESH_LIST[mod.mesh_idx]
                    )(output, memory, tgt_mask, use_cache, cache)
                    auto.shard_tensor(
                        output,
                        DPMPPP_MESH_LIST[mod.mesh_idx],
                        ["x"] + [None for i in range(len(output.shape) - 1)],
                    )
                else:
                    output, new_cache = mod(
                        output,
                        memory,
                        tgt_mask=tgt_mask,
                        use_cache=use_cache,
                        cache=cache[i],
                    )
                new_caches.append(new_cache)
            self.checkpoints.append(output.name)
        if self.norm is not None:
            output = self.norm(output)
        return output if use_cache is False else (output, new_caches)

    def gen_cache(self, memory, do_zip=False):
        """
        Generates cache for `forward` usage. The generated cache is a list, and
        each element in it is a tuple( :code:`(incremental_cache, static_cache)` )
        produced by `TransformerDecoderLayer.gen_cache`. See `TransformerDecoderLayer.gen_cache`
        for more details. If `do_zip` is True, apply `zip` on these tuples to get
        a list with two elements.
        """
        cache = [layer.gen_cache(memory) for layer in self.layers]
        if do_zip:
            cache = list(zip(*cache))
        return cache
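
# Illustrative incremental-decoding sketch (a hedged example; `decoder`,
# `memory`, `tgt` and `mask` are assumed to be created by the caller):
#
#     cache = decoder.gen_cache(memory)              # one entry per layer
#     out, cache = decoder(tgt, memory, tgt_mask=mask,
#                          use_cache=True, cache=cache)
#
# With use_cache=True and a cache, each layer's attention concatenates the new
# keys/values onto the cached ones (along axis 2), so `tgt` only needs to
# carry the newly generated positions.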


class TransformerDecoderLayer(nn.Layer):
    """
    The transformer decoder layer.
    It contains multi-head attention and some linear layers.
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward,
        dropout=0.1,
        activation="gelu",
        attn_dropout=None,
        act_dropout=None,
        normalize_before=True,
        weight_attr=None,
        bias_attr=None,
        mesh_idx=None,
    ):
        self._config = locals()
        self._config.pop("self")
        self._config.pop("__class__", None)  # py3
        self.mesh_idx = mesh_idx
        super().__init__()
        attn_dropout = dropout if attn_dropout is None else attn_dropout
        act_dropout = dropout if act_dropout is None else act_dropout
        self.normalize_before = normalize_before
        weight_attrs = _convert_param_attr_to_list(weight_attr, 3)
        bias_attrs = _convert_param_attr_to_list(bias_attr, 3)
        self.self_attn = MultiHeadAttention(
            d_model,
            nhead,
            dropout=attn_dropout,
            weight_attr=weight_attrs[0],
            bias_attr=bias_attrs[0],
            mesh_idx=self.mesh_idx,
        )
        self.linear1 = nn.Linear(
            d_model, dim_feedforward, weight_attrs[2], bias_attr=bias_attrs[2]
        )
        self.linear2 = nn.Linear(
            dim_feedforward, d_model, weight_attrs[2], bias_attr=bias_attrs[2]
        )
        self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5)
        self.norm2 = nn.LayerNorm(d_model, epsilon=1e-5)
        self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train")
        self.dropout2 = nn.Dropout(act_dropout, mode="upscale_in_train")
        self.activation = getattr(F, activation)

    def forward(self, tgt, memory, tgt_mask=None, use_cache=False, cache=None):
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)
        if use_cache is False:
            tgt = self.self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache)
        else:
            tgt, incremental_cache = self.self_attn(
                tgt, tgt, tgt, tgt_mask, use_cache, cache
            )
        tgt = residual + self.dropout1(tgt)
        if not self.normalize_before:
            tgt = self.norm1(tgt)
        residual = tgt
        if self.normalize_before:
            tgt = self.norm2(tgt)
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(
                self.linear1.weight, _global_process_mesh, [None, "x"]
            )
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(
                self.linear1.weight, _global_process_mesh, [None, "y"]
            )
        elif _global_parallel_strategy == "mp_pp":
            auto.shard_tensor(
                self.linear1.weight, MPPP_MESH_LIST[self.mesh_idx], [None, "x"]
            )
        if _global_parallel_strategy == "dp_mp_pp":
            auto.shard_tensor(
                self.linear1.weight,
                DPMPPP_MESH_LIST[self.mesh_idx],
                [None, "y"],
            )

        if _global_parallel_strategy == "mp":
            auto.shard_tensor(
                self.linear2.weight, _global_process_mesh, ["x", None]
            )
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(
                self.linear2.weight, _global_process_mesh, ["y", None]
            )
        elif _global_parallel_strategy == "mp_pp":
            auto.shard_tensor(
                self.linear2.weight, MPPP_MESH_LIST[self.mesh_idx], ["x", None]
            )
        elif _global_parallel_strategy == "dp_mp_pp":
            auto.shard_tensor(
                self.linear2.weight,
                DPMPPP_MESH_LIST[self.mesh_idx],
                ["y", None],
            )
        tgt = self.dropout2(
            self.linear2(F.gelu(self.linear1(tgt), approximate=True))
        )
        tgt = residual + tgt
        if not self.normalize_before:
            tgt = self.norm2(tgt)
        return tgt if use_cache is False else (tgt, incremental_cache)

    def gen_cache(self, memory):
        incremental_cache = self.self_attn.gen_cache(
            memory, type=self.self_attn.Cache
        )
        return incremental_cache


class GPTEmbeddings(nn.Layer):
    """
    Includes word embeddings and position embeddings.
    """

    def __init__(
        self,
        vocab_size,
        hidden_size=768,
        hidden_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        initializer_range=0.02,
    ):
        super().__init__()
        self.word_embeddings = nn.Embedding(
            vocab_size,
            hidden_size,
            weight_attr=paddle.ParamAttr(
                name="word_embeddings",
                initializer=nn.initializer.Normal(
                    mean=0.0, std=initializer_range
                ),
            ),
        )
        self.position_embeddings = nn.Embedding(
            max_position_embeddings,
            hidden_size,
            weight_attr=paddle.ParamAttr(
                name="pos_embeddings",
                initializer=nn.initializer.Normal(
                    mean=0.0, std=initializer_range
                ),
            ),
        )
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(self, input_ids, position_ids=None):
        if position_ids is None:
            ones = paddle.ones_like(input_ids, dtype="int64")
            seq_length = paddle.cumsum(ones, axis=-1)
            position_ids = seq_length - ones
        input_embedings = self.word_embeddings(input_ids)
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(
                self.word_embeddings.weight, _global_process_mesh, ["x", None]
            )
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(
                self.word_embeddings.weight, _global_process_mesh, ["y", None]
            )
        elif _global_parallel_strategy == "mp_pp":
            auto.shard_tensor(
                self.word_embeddings.weight, MPPP_MESH_LIST[0], ["x", None]
            )
        elif _global_parallel_strategy == "dp_mp_pp":
            auto.shard_tensor(
                self.word_embeddings.weight, DPMPPP_MESH_LIST[0], ["y", None]
            )

        position_embeddings = self.position_embeddings(position_ids)
        embeddings = input_embedings + position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings


class GPTModel(nn.Layer):
    """
    The base model of GPT.
    """

    def __init__(
        self,
        vocab_size=50304,
        hidden_size=1024,
        num_hidden_layers=24,
        num_attention_heads=16,
        intermediate_size=4096,
        hidden_act="gelu",
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        max_position_embeddings=512,
        type_vocab_size=16,
        initializer_range=0.02,
        pad_token_id=0,
        eos_token_id=7,
        bos_token_id=0,
        eol_token_id=3,
        pp_degree=None,
    ):
        super().__init__()
        self.pad_token_id = pad_token_id
        self.initializer_range = initializer_range
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.layer_per_stage = None
        self.pipline_mode = pp_degree is not None and pp_degree > 1
        if self.pipline_mode:
            self.layer_per_stage = num_hidden_layers // pp_degree
        self.embeddings = GPTEmbeddings(
            vocab_size,
            hidden_size,
            hidden_dropout_prob,
            max_position_embeddings,
            type_vocab_size,
            self.initializer_range,
        )
        decoder_layers = nn.LayerList()
        for i in range(num_hidden_layers):
            mesh_index = None
            DecoderLayer = TransformerDecoderLayer
            if self.layer_per_stage is not None:
                mesh_index = i // self.layer_per_stage
            decoder_layers.append(
                DecoderLayer(
                    d_model=hidden_size,
                    nhead=num_attention_heads,
                    dim_feedforward=intermediate_size,
                    dropout=hidden_dropout_prob,
                    activation=hidden_act,
                    attn_dropout=attention_probs_dropout_prob,
                    act_dropout=hidden_dropout_prob,
                    weight_attr=paddle.ParamAttr(
                        initializer=nn.initializer.Normal(
                            mean=0.0, std=self.initializer_range
                        )
                    ),
                    bias_attr=None,
                    mesh_idx=mesh_index,
                )
            )
        Decoder = TransformerDecoder
        self.decoder = Decoder(
            decoder_layers,
            num_hidden_layers,
            norm="LayerNorm",
            hidden_size=hidden_size,
        )
        self.checkpoints = []

    def forward(
        self,
        input_ids,
        position_ids=None,
        attention_mask=None,
        use_cache=False,
        cache=None,
    ):
        self.checkpoints = []
        if position_ids is None:
            past_length = 0
            if cache is not None:
                past_length = paddle.shape(cache[0].k)[-2]
            position_ids = paddle.arange(
                past_length,
                paddle.shape(input_ids)[-1] + past_length,
                dtype='int64',
            )
            position_ids = position_ids.unsqueeze(0)
            position_ids = paddle.fluid.layers.expand_as(
                position_ids, input_ids
            )
        embedding_output = self.embeddings(
            input_ids=input_ids, position_ids=position_ids
        )
        if _global_parallel_strategy == "pp":
            auto.shard_tensor(
                input_ids,
                PP_MESH_LIST[0],
                [None for i in range(len(input_ids.shape))],
            )
        if _global_parallel_strategy == "dp_pp":
            auto.shard_tensor(
                input_ids,
                DPPP_MESH_LIST[0],
                ["x"] + [None for i in range(len(input_ids.shape) - 1)],
            )
        if _global_parallel_strategy == "dp_mp_pp":
            auto.shard_tensor(
                input_ids,
                DPMPPP_MESH_LIST[0],
                ["x"] + [None for i in range(len(input_ids.shape) - 1)],
            )
        encoder_outputs = self.decoder(
            embedding_output,
            memory=None,
            tgt_mask=attention_mask,
            use_cache=use_cache,
            cache=cache,
        )
        self.checkpoints.extend(self.decoder.checkpoints)
        return encoder_outputs
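
# Pipeline-stage assignment (as implemented in GPTModel.__init__ above): when
# pp_degree is set, layer_per_stage = num_hidden_layers // pp_degree and layer
# i receives mesh_idx = i // layer_per_stage. For example, num_hidden_layers=4
# with pp_degree=2 puts layers 0-1 on stage 0 and layers 2-3 on stage 1, and
# mesh_idx then indexes PP_MESH_LIST / DPPP_MESH_LIST / MPPP_MESH_LIST /
# DPMPPP_MESH_LIST inside the decoder layers.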


class GPTForPretraining(nn.Layer):
    """
    The pretraining model of GPT.
    It returns some logits and cached_kvs.
    """

    def __init__(
        self,
        gpt,
        vocab_size=50304,
        hidden_size=768,
        initializer_range=0.02,
    ):
        super().__init__()
        self.gpt = gpt

    def forward(
        self,
        input_ids,
        position_ids=None,
        attention_mask=None,
        masked_positions=None,
        use_cache=False,
        cache=None,
    ):
        input_ids.stop_gradient = True
        position_ids.stop_gradient = True
        attention_mask.stop_gradient = True

        outputs = self.gpt(
            input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            use_cache=use_cache,
            cache=cache,
        )
        if use_cache:
            encoder_outputs, cached_kvs = outputs[:2]
        else:
            encoder_outputs = outputs

        x = encoder_outputs
        w = self.gpt.embeddings.word_embeddings.weight

        mesh = None
        if _global_parallel_strategy == "pp":
            mesh = PP_MESH_LIST[-1]
            x_dims_mapping = [None for i in range(len(x.shape))]
            w_dims_mapping = [None for i in range(len(w.shape))]
        elif _global_parallel_strategy == "dp":
            mesh = _global_process_mesh
            x_dims_mapping = ["x"] + [None for i in range(len(x.shape) - 1)]
            w_dims_mapping = [None for i in range(len(w.shape))]
        elif _global_parallel_strategy == "mp":
            mesh = _global_process_mesh
            x_dims_mapping = [None for i in range(len(x.shape))]
            w_dims_mapping = ["x"] + [None for i in range(len(w.shape) - 1)]
        elif _global_parallel_strategy == "dp_mp":
            mesh = _global_process_mesh
            x_dims_mapping = ["x"] + [None for i in range(len(x.shape) - 1)]
            w_dims_mapping = ["y"] + [None for i in range(len(w.shape) - 1)]
        elif _global_parallel_strategy == "dp_pp":
            mesh = DPPP_MESH_LIST[-1]
            x_dims_mapping = ["x"] + [None for i in range(len(x.shape) - 1)]
            w_dims_mapping = [None for i in range(len(w.shape))]
        elif _global_parallel_strategy == "mp_pp":
            mesh = MPPP_MESH_LIST[-1]
            x_dims_mapping = [None for i in range(len(x.shape))]
            w_dims_mapping = ["x"] + [-1 for i in range(len(w.shape) - 1)]
        elif _global_parallel_strategy == "dp_mp_pp":
            mesh = DPMPPP_MESH_LIST[-1]
            x_dims_mapping = ["x"] + [None for i in range(len(x.shape) - 1)]
            w_dims_mapping = ["y"] + [None for i in range(len(w.shape) - 1)]

        if mesh:
            matmul = auto.shard_op(
                paddle.matmul, mesh, [x_dims_mapping, w_dims_mapping, None]
            )
            logits = matmul(x, w, transpose_y=True)
        else:
            logits = paddle.matmul(x, w, transpose_y=True)

        if use_cache:
            return logits, cached_kvs
        else:
            return logits
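
# NOTE: GPTForPretraining ties the output projection to the input embedding:
# logits = encoder_outputs @ word_embeddings.weight^T. Under a parallel
# strategy the matmul is wrapped in auto.shard_op with the x/w dims mappings
# chosen above, so the logits computation is annotated consistently with the
# sharded embedding table.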


class GPTPretrainingCriterion(nn.Layer):
    """
    Criterion for GPT.
    It calculates the final loss.
    """

    def __init__(self):
        super().__init__()
        self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none")

    def forward(self, prediction_scores, masked_lm_labels, loss_mask):
        masked_lm_labels.stop_gradient = True
        loss_mask.stop_gradient = True

        mesh = None
        if _global_parallel_strategy == "dp":
            mesh = _global_process_mesh
            dims_mapping = ["x"] + [
                None for i in range(len(loss_mask.shape) - 1)
            ]
        elif _global_parallel_strategy == "dp_mp":
            mesh = _global_process_mesh
            dims_mapping = ["x"] + [
                None for i in range(len(loss_mask.shape) - 1)
            ]
        elif _global_parallel_strategy == "dp_pp":
            mesh = DPPP_MESH_LIST[-1]
            dims_mapping = ["x"] + [
                None for i in range(len(loss_mask.shape) - 1)
            ]
        elif _global_parallel_strategy == "dp_mp_pp":
            mesh = DPMPPP_MESH_LIST[-1]
            dims_mapping = ["x"] + [
                None for i in range(len(loss_mask.shape) - 1)
            ]

        if mesh:
            auto.shard_tensor(loss_mask, mesh, dims_mapping)

        masked_lm_loss = self.loss_func(
            prediction_scores, masked_lm_labels.unsqueeze(2)
        )
        loss_mask = loss_mask.reshape([-1])
        masked_lm_loss = paddle.sum(masked_lm_loss.reshape([-1]) * loss_mask)
        total_loss = masked_lm_loss / loss_mask.sum()
        return total_loss
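

# Minimal wiring sketch (a hedged example; the tensors, sizes and mesh below
# are assumptions, and the accompanying test script normally drives this
# module instead):
#
#     init_global()
#     _global_parallel_strategy = "mp"
#     _global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"])
#
#     gpt = GPTModel(vocab_size=1000, hidden_size=64, num_hidden_layers=2,
#                    num_attention_heads=8, intermediate_size=256)
#     model = GPTForPretraining(gpt)
#     criterion = GPTPretrainingCriterion()
#     # input_ids, position_ids, attention_mask, labels and loss_mask are
#     # static-graph data tensors created by the caller.
#     preds = model(input_ids, position_ids, attention_mask)
#     loss = criterion(preds, labels, loss_mask)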