# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.tensor as tensor
from paddle.distributed.fleet import auto
from paddle.fluid import layers
from paddle.nn.layer.transformer import _convert_param_attr_to_list

paddle.enable_static()


def init_global():
    global _global_parallel_strategy
    _global_parallel_strategy = None
    global _global_process_mesh
    global PP_MESH_LIST
    global DPPP_MESH_LIST
    global MPPP_MESH_LIST
    global DPMPPP_MESH_LIST
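# NOTE: besides `_global_parallel_strategy`, the process-mesh globals declared
# above (`_global_process_mesh`, `PP_MESH_LIST`, `DPPP_MESH_LIST`,
# `MPPP_MESH_LIST`, `DPMPPP_MESH_LIST`) are expected to be assigned by the
# calling test after `init_global()`, for example (illustrative values only):
#
#     _global_parallel_strategy = "dp_mp"
#     _global_process_mesh = auto.ProcessMesh(
#         mesh=[[0, 1], [2, 3]], dim_names=["x", "y"]
#     )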


class MultiHeadAttention(nn.Layer):
    """
    Attention maps queries and a set of key-value pairs to outputs, and
    Multi-Head Attention performs multiple parallel attention to jointly attend
    to information from different representation subspaces.
    """

    Cache = collections.namedtuple("Cache", ["k", "v"])
    StaticCache = collections.namedtuple("StaticCache", ["k", "v"])
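    # `Cache` holds the growing k/v of decoder self-attention during
    # incremental decoding; `StaticCache` holds k/v computed once from a
    # static source (e.g. encoder output) and reused on every step.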

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        kdim=None,
        vdim=None,
        need_weights=False,
        weight_attr=None,
        bias_attr=None,
        fuse=False,
        mesh_idx=None,
        use_new_recompute=False,
        recompute_granularity="full",
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.need_weights = need_weights
        self.fuse = fuse
        self.mesh_idx = mesh_idx
        self.use_new_recompute = use_new_recompute
        self.recompute_granularity = recompute_granularity

        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        if self.fuse:
            assert self.kdim == embed_dim
            assert self.vdim == embed_dim
            self.qkv_proj = nn.Linear(
                embed_dim, 3 * embed_dim, weight_attr, bias_attr=bias_attr
            )
        else:
            self.q_proj = nn.Linear(
                embed_dim,
                embed_dim,
                weight_attr=weight_attr,
                bias_attr=bias_attr,
            )
            self.k_proj = nn.Linear(
                self.kdim,
                embed_dim,
                weight_attr=weight_attr,
                bias_attr=bias_attr,
            )
            self.v_proj = nn.Linear(
                self.vdim,
                embed_dim,
                weight_attr=weight_attr,
                bias_attr=bias_attr,
            )
        self.out_proj = nn.Linear(
            embed_dim, embed_dim, weight_attr=weight_attr, bias_attr=bias_attr
        )

    def _fuse_prepare_qkv(self, query):
        mix_layer = self.qkv_proj(query)
        mix_layer = paddle.reshape_(
            mix_layer, [0, 0, self.num_heads, 3 * self.head_dim]
        )
        mix_layer = paddle.transpose(mix_layer, [0, 2, 1, 3])
        q, k, v = paddle.split(mix_layer, num_or_sections=3, axis=-1)
        return q, k, v

    def _prepare_qkv(self, query, key, value, use_cache=False, cache=None):
        """
        Prepares linearly projected queries, keys and values for the subsequent
        multi-head attention. If `cache` is not None, cached results are reused
        to reduce redundant calculations.
        """
        q = self.q_proj(query)
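        # Annotate q_proj.weight as column-parallel: [None, "x"] (or [None, "y"])
        # replicates the input dimension and shards the output dimension over the
        # tensor-parallel mesh axis; the *_pp variants use this layer's stage mesh.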
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(
                self.q_proj.weight, _global_process_mesh, [None, "x"]
            )
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(
                self.q_proj.weight, _global_process_mesh, [None, "y"]
            )
        elif _global_parallel_strategy == "mp_pp":
            auto.shard_tensor(
                self.q_proj.weight, MPPP_MESH_LIST[self.mesh_idx], [None, "x"]
            )
        elif _global_parallel_strategy == "dp_mp_pp":
            auto.shard_tensor(
                self.q_proj.weight, DPMPPP_MESH_LIST[self.mesh_idx], [None, "y"]
            )

        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q = tensor.transpose(x=q, perm=[0, 2, 1, 3])
        if isinstance(cache, self.StaticCache):
            # for encoder-decoder attention in inference with cached static k/v
            k, v = cache.k, cache.v
        else:
            k, v = self.compute_kv(key, value)
        if isinstance(cache, self.Cache):
            # for decoder self-attention in inference
            k = tensor.concat([cache.k, k], axis=2)
            v = tensor.concat([cache.v, v], axis=2)
        if use_cache is True:
            cache = self.Cache(k, v)
        return (q, k, v) if use_cache is False else (q, k, v, cache)

    def compute_kv(self, key, value):
        """
        Applies linear projection on input keys and values, then splits heads
        (reshape and transpose) to get keys and values from different representation
        subspaces. The results are used as key-value pairs for the subsequent
        multi-head attention.
        It is part of the calculations in multi-head attention, and is provided
        as a method to pre-compute and prefetch these results, so that they can
        be used to construct the cache for inference.
        """
        k = self.k_proj(key)
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(
                self.k_proj.weight, _global_process_mesh, [None, "x"]
            )
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(
                self.k_proj.weight, _global_process_mesh, [None, "y"]
            )
        elif _global_parallel_strategy == "mp_pp":
            auto.shard_tensor(
                self.k_proj.weight, MPPP_MESH_LIST[self.mesh_idx], [None, "x"]
            )
        elif _global_parallel_strategy == "dp_mp_pp":
            auto.shard_tensor(
                self.k_proj.weight, DPMPPP_MESH_LIST[self.mesh_idx], [None, "y"]
            )
        v = self.v_proj(value)
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(
                self.v_proj.weight, _global_process_mesh, [None, "x"]
            )
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(
                self.v_proj.weight, _global_process_mesh, [None, "y"]
            )
        elif _global_parallel_strategy == "mp_pp":
            auto.shard_tensor(
                self.v_proj.weight, MPPP_MESH_LIST[self.mesh_idx], [None, "x"]
            )
        elif _global_parallel_strategy == "dp_mp_pp":
            auto.shard_tensor(
                self.v_proj.weight, DPMPPP_MESH_LIST[self.mesh_idx], [None, "y"]
            )
        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v = tensor.transpose(x=v, perm=[0, 2, 1, 3])
        return k, v

    def gen_cache(self, key, value=None, type=Cache):
        """
        Generates cache for `forward` usage in inference according to arguments.
        The generated cache is an instance of `MultiHeadAttention.Cache` or an
        instance of `MultiHeadAttention.StaticCache`.
        """
        if type == MultiHeadAttention.StaticCache:  # static_kv
            k, v = self.compute_kv(key, value)
            return self.StaticCache(k, v)
        elif value is None:  # incremental_state
            k = layers.fill_constant_batch_size_like(
                input=key,
                shape=[-1, self.num_heads, 0, self.head_dim],
                dtype=key.dtype,
                value=0,
            )
            v = layers.fill_constant_batch_size_like(
                input=key,
                shape=[-1, self.num_heads, 0, self.head_dim],
                dtype=key.dtype,
                value=0,
            )
            return self.Cache(k, v)
        else:
            # incremental_state with initial value, mainly for usage like UniLM
            return self.Cache(key, value)

    def core_attn(self, q, k, v, attn_mask):
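        # Scaled dot-product attention:
        # softmax(q @ k^T * head_dim**-0.5 + attn_mask) @ v, then merge heads.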
        product = paddle.matmul(x=q, y=k, transpose_y=True)
        product = paddle.multiply(
            product,
            paddle.to_tensor(self.head_dim**-0.5, dtype=product.dtype),
        )
        if attn_mask is not None:
            product = product + attn_mask
        weights = F.softmax(product)
        if self.dropout:
            weights = F.dropout(
                weights,
                self.dropout,
                training=self.training,
                mode="upscale_in_train",
            )
        out = tensor.matmul(weights, v)
        # combine heads
        out = tensor.transpose(out, perm=[0, 2, 1, 3])
        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        return out, weights

    def forward(
        self, query, key, value, attn_mask=None, use_cache=False, cache=None
    ):
        """
        Applies multi-head attention to map queries and a set of key-value pairs
        to outputs.
        """
        key = query if key is None else key
        value = query if value is None else value
        # compute q, k, v
        if use_cache is False:
            if self.fuse:
                q, k, v = self._fuse_prepare_qkv(query)
            else:
                q, k, v = self._prepare_qkv(query, key, value, use_cache, cache)
        else:
            q, k, v, cache = self._prepare_qkv(
                query, key, value, use_cache, cache
            )

        if self.use_new_recompute and self.recompute_granularity == "core_attn":
            out, weights = auto.recompute(self.core_attn)(q, k, v, attn_mask)
        else:
            out, weights = self.core_attn(q, k, v, attn_mask)

        # project to output
        out = self.out_proj(out)
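        # out_proj.weight is annotated row-parallel (["x", None] / ["y", None]),
        # complementing the column-parallel q/k/v projections above.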
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(
                self.out_proj.weight, _global_process_mesh, ["x", None]
            )
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(
                self.out_proj.weight, _global_process_mesh, ["y", None]
            )
        elif _global_parallel_strategy == "mp_pp":
            auto.shard_tensor(
                self.out_proj.weight, MPPP_MESH_LIST[self.mesh_idx], ["x", None]
            )
        elif _global_parallel_strategy == "dp_mp_pp":
            auto.shard_tensor(
                self.out_proj.weight,
                DPMPPP_MESH_LIST[self.mesh_idx],
                ["y", None],
            )

        outs = [out]
        if self.need_weights:
            outs.append(weights)
        if use_cache:
            outs.append(cache)
        return out if len(outs) == 1 else tuple(outs)


class TransformerDecoder(nn.Layer):
    """
    TransformerDecoder is a stack of N decoder layers.
    """

    def __init__(
        self,
        decoder_layers,
        num_layers,
        norm=None,
        hidden_size=None,
        use_new_recompute=False,
        recompute_granularity="full",
    ):
        super().__init__()

        self.num_layers = num_layers
        self.layers = decoder_layers
        self.norm = norm
        self.use_new_recompute = use_new_recompute
        self.recompute_granularity = recompute_granularity
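        # Recompute granularity (when use_new_recompute is True):
        #   "full"      - recompute each whole decoder layer (handled here),
        #   "full_attn" - recompute the attention sublayer (TransformerDecoderLayer),
        #   "core_attn" - recompute only the softmax(QK^T)V core (MultiHeadAttention).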
        if norm == "LayerNorm":
            self.norm = nn.LayerNorm(hidden_size)
        elif norm is not None:
            raise ValueError("Only LayerNorm is supported")
        self.checkpoints = []

    def forward(
        self,
        tgt,
        memory,
        tgt_mask=None,
        memory_mask=None,
        use_cache=False,
        cache=None,
    ):
        """
        Applies a stack of N Transformer decoder layers on inputs. If `norm` is
        provided, also applies layer normalization on the output of the last
        decoder layer.
        """
        output = tgt
        new_caches = []
        self.checkpoints = []
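        # Pin the decoder input onto the first pipeline stage's mesh; with a
        # data-parallel axis ("dp_pp"/"dp_mp_pp") only the batch dim is sharded.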
        if _global_parallel_strategy == "pp":
            auto.shard_tensor(
                output,
                PP_MESH_LIST[0],
                [None for i in range(len(output.shape))],
            )
        if _global_parallel_strategy == "dp_pp":
            auto.shard_tensor(
                output,
                DPPP_MESH_LIST[0],
                ["x"] + [None for i in range(len(output.shape) - 1)],
            )
        if _global_parallel_strategy == "mp_pp":
            auto.shard_tensor(
                output,
                MPPP_MESH_LIST[0],
                [None for i in range(len(output.shape))],
            )
        if _global_parallel_strategy == "dp_mp_pp":
            auto.shard_tensor(
                output,
                DPMPPP_MESH_LIST[0],
                ["x"] + [None for i in range(len(output.shape) - 1)],
            )

        for i, mod in enumerate(self.layers):
            if self.use_new_recompute and self.recompute_granularity == "full":
                mod = auto.recompute(mod)

            if cache is None:
                if use_cache:
                    output, new_cache = mod(
                        output,
                        memory,
                        tgt_mask=tgt_mask,
                        use_cache=use_cache,
                        cache=cache,
                    )
                    new_caches.append(new_cache)
                else:
                    output = mod(output, memory, tgt_mask, use_cache, cache)
            else:
                output, new_cache = mod(
                    output,
                    memory,
                    tgt_mask=tgt_mask,
                    use_cache=use_cache,
                    cache=cache[i],
                )
                new_caches.append(new_cache)

            if not self.use_new_recompute:
                self.checkpoints.append(output.name)

        if self.norm is not None:
            output = self.norm(output)
        return output if use_cache is False else (output, new_caches)

    def gen_cache(self, memory, do_zip=False):
        """
        Generates cache for `forward` usage. The generated cache is a list, and
        each element in it is a tuple( :code:`(incremental_cache, static_cache)` )
        produced by `TransformerDecoderLayer.gen_cache`. See `TransformerDecoderLayer.gen_cache`
        for more details. If `do_zip` is True, apply `zip` on these tuples to get
        a list with two elements.
        """
        cache = [layer.gen_cache(memory) for layer in self.layers]
        if do_zip:
            cache = list(zip(*cache))
        return cache


class TransformerDecoderLayer(nn.Layer):
    """
    The transformer decoder layer.
    It contains multi-head attention and some linear layers.
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward,
        dropout=0.1,
        activation="gelu",
        attn_dropout=None,
        act_dropout=None,
        normalize_before=True,
        weight_attr=None,
        bias_attr=None,
        mesh_idx=None,
        use_new_recompute=False,
        recompute_granularity="full",
    ):
        self._config = locals()
        self._config.pop("self")
        self._config.pop("__class__", None)  # py3
        self.mesh_idx = mesh_idx
        super().__init__()
        attn_dropout = dropout if attn_dropout is None else attn_dropout
        act_dropout = dropout if act_dropout is None else act_dropout
        self.normalize_before = normalize_before
        self.use_new_recompute = use_new_recompute
        self.recompute_granularity = recompute_granularity

        weight_attrs = _convert_param_attr_to_list(weight_attr, 3)
        bias_attrs = _convert_param_attr_to_list(bias_attr, 3)

        self.self_attn = MultiHeadAttention(
            d_model,
            nhead,
            dropout=attn_dropout,
            weight_attr=weight_attrs[0],
            bias_attr=bias_attrs[0],
            mesh_idx=self.mesh_idx,
            use_new_recompute=self.use_new_recompute,
            recompute_granularity=self.recompute_granularity,
        )
        self.linear1 = nn.Linear(
            d_model, dim_feedforward, weight_attrs[2], bias_attr=bias_attrs[2]
        )
        self.linear2 = nn.Linear(
            dim_feedforward, d_model, weight_attrs[2], bias_attr=bias_attrs[2]
        )
        self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5)
        self.norm2 = nn.LayerNorm(d_model, epsilon=1e-5)
        self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train")
        self.dropout2 = nn.Dropout(act_dropout, mode="upscale_in_train")
        self.activation = getattr(F, activation)

    def forward(self, tgt, memory, tgt_mask=None, use_cache=False, cache=None):
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)

        if self.use_new_recompute and self.recompute_granularity == "full_attn":
            self_attn = auto.recompute(self.self_attn)
        else:
            self_attn = self.self_attn

        if use_cache is False:
            tgt = self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache)
        else:
            tgt, incremental_cache = self_attn(
                tgt, tgt, tgt, tgt_mask, use_cache, cache
            )

        tgt = residual + self.dropout1(tgt)
        if not self.normalize_before:
            tgt = self.norm1(tgt)
        residual = tgt
        if self.normalize_before:
            tgt = self.norm2(tgt)
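        # Feed-forward block: linear1 is annotated column-parallel ([None, "x"])
        # and linear2 row-parallel (["x", None]) under the mp-style strategies.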
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(
                self.linear1.weight, _global_process_mesh, [None, "x"]
            )
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(
                self.linear1.weight, _global_process_mesh, [None, "y"]
            )
        elif _global_parallel_strategy == "mp_pp":
            auto.shard_tensor(
                self.linear1.weight, MPPP_MESH_LIST[self.mesh_idx], [None, "x"]
            )
        elif _global_parallel_strategy == "dp_mp_pp":
            auto.shard_tensor(
                self.linear1.weight,
                DPMPPP_MESH_LIST[self.mesh_idx],
                [None, "y"],
            )

        if _global_parallel_strategy == "mp":
            auto.shard_tensor(
                self.linear2.weight, _global_process_mesh, ["x", None]
            )
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(
                self.linear2.weight, _global_process_mesh, ["y", None]
            )
        elif _global_parallel_strategy == "mp_pp":
            auto.shard_tensor(
                self.linear2.weight, MPPP_MESH_LIST[self.mesh_idx], ["x", None]
            )
        elif _global_parallel_strategy == "dp_mp_pp":
            auto.shard_tensor(
                self.linear2.weight,
                DPMPPP_MESH_LIST[self.mesh_idx],
                ["y", None],
            )
        tgt = self.dropout2(
            self.linear2(F.gelu(self.linear1(tgt), approximate=True))
        )
        tgt = residual + tgt
        if not self.normalize_before:
            tgt = self.norm2(tgt)
        return tgt if use_cache is False else (tgt, incremental_cache)

    def gen_cache(self, memory):
        incremental_cache = self.self_attn.gen_cache(
            memory, type=self.self_attn.Cache
        )
        return incremental_cache


class GPTEmbeddings(nn.Layer):
    """
    Include word embeddings and position embeddings.
    """

    def __init__(
        self,
        vocab_size,
        hidden_size=768,
        hidden_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        initializer_range=0.02,
    ):
        super().__init__()
        self.word_embeddings = nn.Embedding(
            vocab_size,
            hidden_size,
577 578 579 580 581 582 583
            weight_attr=paddle.ParamAttr(
                name="word_embeddings",
                initializer=nn.initializer.Normal(
                    mean=0.0, std=initializer_range
                ),
            ),
        )
        self.position_embeddings = nn.Embedding(
            max_position_embeddings,
            hidden_size,
587 588 589 590 591 592 593
            weight_attr=paddle.ParamAttr(
                name="pos_embeddings",
                initializer=nn.initializer.Normal(
                    mean=0.0, std=initializer_range
                ),
            ),
        )
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(self, input_ids, position_ids=None):
        if position_ids is None:
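            # Default position ids: 0 .. seq_len-1 for every sample in the batch.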
            ones = paddle.ones_like(input_ids, dtype="int64")
            seq_length = paddle.cumsum(ones, axis=-1)
            position_ids = seq_length - ones
        input_embeddings = self.word_embeddings(input_ids)
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(
                self.word_embeddings.weight, _global_process_mesh, ["x", None]
            )
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(
                self.word_embeddings.weight, _global_process_mesh, ["y", None]
            )
        elif _global_parallel_strategy == "mp_pp":
            auto.shard_tensor(
                self.word_embeddings.weight, MPPP_MESH_LIST[0], ["x", None]
            )
        elif _global_parallel_strategy == "dp_mp_pp":
            auto.shard_tensor(
                self.word_embeddings.weight, DPMPPP_MESH_LIST[0], ["y", None]
            )

        position_embeddings = self.position_embeddings(position_ids)
        embeddings = input_embeddings + position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings


class GPTModel(nn.Layer):
    """
    The base model of GPT.
    """

    def __init__(
        self,
        vocab_size=50304,
        hidden_size=1024,
        num_hidden_layers=24,
        num_attention_heads=16,
        intermediate_size=4096,
        hidden_act="gelu",
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        max_position_embeddings=512,
        type_vocab_size=16,
        initializer_range=0.02,
        pad_token_id=0,
        eos_token_id=7,
        bos_token_id=0,
        eol_token_id=3,
        pp_degree=None,
        use_new_recompute=False,
        recompute_granularity="full",
    ):
        super().__init__()
        self.pad_token_id = pad_token_id
        self.initializer_range = initializer_range
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.use_new_recompute = use_new_recompute
        self.recompute_granularity = recompute_granularity

        self.layer_per_stage = None
        self.pipeline_mode = pp_degree is not None and pp_degree > 1
        if self.pipeline_mode:
            self.layer_per_stage = num_hidden_layers // pp_degree
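        # With pipeline parallelism, decoder layer i is assigned to stage
        # i // layer_per_stage through its mesh_idx.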
        self.embeddings = GPTEmbeddings(
            vocab_size,
            hidden_size,
            hidden_dropout_prob,
            max_position_embeddings,
            type_vocab_size,
            self.initializer_range,
        )

        decoder_layers = nn.LayerList()
        for i in range(num_hidden_layers):
            mesh_index = None
            DecoderLayer = TransformerDecoderLayer
            if self.layer_per_stage is not None:
                mesh_index = i // self.layer_per_stage
            decoder_layers.append(
                DecoderLayer(
                    d_model=hidden_size,
                    nhead=num_attention_heads,
                    dim_feedforward=intermediate_size,
                    dropout=hidden_dropout_prob,
                    activation=hidden_act,
                    attn_dropout=attention_probs_dropout_prob,
                    act_dropout=hidden_dropout_prob,
                    weight_attr=paddle.ParamAttr(
                        initializer=nn.initializer.Normal(
                            mean=0.0, std=self.initializer_range
                        )
                    ),
                    bias_attr=None,
                    mesh_idx=mesh_index,
                    use_new_recompute=self.use_new_recompute,
                    recompute_granularity=self.recompute_granularity,
                )
            )

        Decoder = TransformerDecoder
        self.decoder = Decoder(
            decoder_layers,
            num_hidden_layers,
            norm="LayerNorm",
            hidden_size=hidden_size,
            use_new_recompute=self.use_new_recompute,
            recompute_granularity=self.recompute_granularity,
        )
        self.checkpoints = []

    def forward(
        self,
        input_ids,
        position_ids=None,
        attention_mask=None,
        use_cache=False,
        cache=None,
    ):
        self.checkpoints = []
        if position_ids is None:
            past_length = 0
            if cache is not None:
                past_length = paddle.shape(cache[0].k)[-2]
            position_ids = paddle.arange(
                past_length,
                paddle.shape(input_ids)[-1] + past_length,
                dtype='int64',
            )
            position_ids = position_ids.unsqueeze(0)
            position_ids = paddle.expand_as(position_ids, input_ids)
        embedding_output = self.embeddings(
            input_ids=input_ids, position_ids=position_ids
        )
        if _global_parallel_strategy == "pp":
            auto.shard_tensor(
                input_ids,
                PP_MESH_LIST[0],
                [None for i in range(len(input_ids.shape))],
            )
        if _global_parallel_strategy == "dp_pp":
            auto.shard_tensor(
                input_ids,
                DPPP_MESH_LIST[0],
                ["x"] + [None for i in range(len(input_ids.shape) - 1)],
            )
        if _global_parallel_strategy == "dp_mp_pp":
            auto.shard_tensor(
                input_ids,
                DPMPPP_MESH_LIST[0],
                ["x"] + [None for i in range(len(input_ids.shape) - 1)],
            )
        encoder_outputs = self.decoder(
            embedding_output,
            memory=None,
            tgt_mask=attention_mask,
            use_cache=use_cache,
            cache=cache,
        )
        if not self.use_new_recompute:
            self.checkpoints.extend(self.decoder.checkpoints)
        return encoder_outputs


class GPTForPretraining(nn.Layer):
    """
    The pretraining model of GPT.
    It returns logits and, when `use_cache` is True, the cached key/value
    tensors (cached_kvs).
    """

    def __init__(
        self,
        gpt,
        vocab_size=50304,
        hidden_size=768,
        initializer_range=0.02,
    ):
        super().__init__()
        self.gpt = gpt

    def forward(
        self,
        input_ids,
        position_ids=None,
        attention_mask=None,
        masked_positions=None,
        use_cache=False,
        cache=None,
    ):
        input_ids.stop_gradient = True
        position_ids.stop_gradient = True
        attention_mask.stop_gradient = True

        outputs = self.gpt(
            input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            use_cache=use_cache,
            cache=cache,
        )
        if use_cache:
            encoder_outputs, cached_kvs = outputs[:2]
        else:
            encoder_outputs = outputs

        x = encoder_outputs
        w = self.gpt.embeddings.word_embeddings.weight
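        # The output projection ties (reuses) the word embedding matrix; the
        # branches below choose the mesh and the dims mappings of x and w so the
        # logits matmul is annotated consistently with the chosen strategy.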

        mesh = None
        if _global_parallel_strategy == "pp":
            mesh = PP_MESH_LIST[-1]
            x_dims_mapping = [None for i in range(len(x.shape))]
            w_dims_mapping = [None for i in range(len(w.shape))]
        elif _global_parallel_strategy == "dp":
            mesh = _global_process_mesh
            x_dims_mapping = ["x"] + [None for i in range(len(x.shape) - 1)]
            w_dims_mapping = [None for i in range(len(w.shape))]
        elif _global_parallel_strategy == "mp":
            mesh = _global_process_mesh
            x_dims_mapping = [None for i in range(len(x.shape))]
            w_dims_mapping = ["x"] + [None for i in range(len(w.shape) - 1)]
        elif _global_parallel_strategy == "dp_mp":
            mesh = _global_process_mesh
            x_dims_mapping = ["x"] + [None for i in range(len(x.shape) - 1)]
            w_dims_mapping = ["y"] + [None for i in range(len(w.shape) - 1)]
        elif _global_parallel_strategy == "dp_pp":
            mesh = DPPP_MESH_LIST[-1]
            x_dims_mapping = ["x"] + [None for i in range(len(x.shape) - 1)]
            w_dims_mapping = [None for i in range(len(w.shape))]
        elif _global_parallel_strategy == "mp_pp":
            mesh = MPPP_MESH_LIST[-1]
            x_dims_mapping = [None for i in range(len(x.shape))]
            w_dims_mapping = ["x"] + [None for i in range(len(w.shape) - 1)]
        elif _global_parallel_strategy == "dp_mp_pp":
            mesh = DPMPPP_MESH_LIST[-1]
            x_dims_mapping = ["x"] + [None for i in range(len(x.shape) - 1)]
            w_dims_mapping = ["y"] + [None for i in range(len(w.shape) - 1)]

        if mesh:
            matmul = auto.shard_op(
                paddle.matmul, mesh, [x_dims_mapping, w_dims_mapping, None]
            )
            logits = matmul(x, w, transpose_y=True)
        else:
            logits = paddle.matmul(x, w, transpose_y=True)

        if use_cache:
            return logits, cached_kvs
        else:
            return logits


class GPTPretrainingCriterion(nn.Layer):
    """
    Criterion for GPT.
    It calculates the final loss.
    """

    def __init__(self):
        super().__init__()
        self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none")

    def forward(self, prediction_scores, masked_lm_labels, loss_mask):
        masked_lm_labels.stop_gradient = True
        loss_mask.stop_gradient = True

        mesh = None
        if _global_parallel_strategy == "dp":
            mesh = _global_process_mesh
            dims_mapping = ["x"] + [
                None for i in range(len(loss_mask.shape) - 1)
            ]
        elif _global_parallel_strategy == "dp_mp":
            mesh = _global_process_mesh
            dims_mapping = ["x"] + [
                None for i in range(len(loss_mask.shape) - 1)
            ]
        elif _global_parallel_strategy == "dp_pp":
            mesh = DPPP_MESH_LIST[-1]
            dims_mapping = ["x"] + [
                None for i in range(len(loss_mask.shape) - 1)
            ]
        elif _global_parallel_strategy == "dp_mp_pp":
            mesh = DPMPPP_MESH_LIST[-1]
            dims_mapping = ["x"] + [
                None for i in range(len(loss_mask.shape) - 1)
            ]

        if mesh:
            auto.shard_tensor(loss_mask, mesh, dims_mapping)

        masked_lm_loss = self.loss_func(
            prediction_scores, masked_lm_labels.unsqueeze(2)
        )
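        # Keep only the loss of unmasked tokens and normalize by their count.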
        loss_mask = loss_mask.reshape([-1])
        masked_lm_loss = paddle.sum(masked_lm_loss.reshape([-1]) * loss_mask)
        total_loss = masked_lm_loss / loss_mask.sum()
        return total_loss
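

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original tests):
# build a tiny single-card GPT in static mode and run one forward pass.
# All hyper-parameters, shapes and feed data below are assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np

    init_global()  # leaves _global_parallel_strategy as None: no sharding

    main_prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        tokens = paddle.static.data("tokens", shape=[2, 16], dtype="int64")
        positions = paddle.static.data(
            "positions", shape=[2, 16], dtype="int64"
        )
        attn_mask = paddle.static.data(
            "attn_mask", shape=[2, 1, 16, 16], dtype="float32"
        )
        gpt = GPTModel(
            vocab_size=1000,
            hidden_size=64,
            num_hidden_layers=2,
            num_attention_heads=4,
            intermediate_size=256,
            max_position_embeddings=64,
        )
        model = GPTForPretraining(gpt)
        logits = model(tokens, positions, attn_mask)

    exe = paddle.static.Executor(paddle.CPUPlace())
    exe.run(startup_prog)
    causal = np.tril(np.ones((16, 16), dtype="float32"))
    feed = {
        "tokens": np.random.randint(0, 1000, size=(2, 16)).astype("int64"),
        "positions": np.tile(np.arange(16, dtype="int64"), (2, 1)),
        "attn_mask": np.broadcast_to(
            (1.0 - causal)[None, None, :, :] * -1e4, (2, 1, 16, 16)
        ).astype("float32"),
    }
    (out,) = exe.run(main_prog, feed=feed, fetch_list=[logits])
    print("logits shape:", out.shape)  # expected (2, 16, 1000)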