# auto_parallel_gpt_model.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.tensor as tensor
from paddle.distributed.fleet import auto
from paddle.fluid import layers
from paddle.nn.layer.transformer import _convert_param_attr_to_list

# Module-level side effect: importing this file switches Paddle to
# static-graph mode.
paddle.enable_static()


def init_global():
    """
    Reset the module-level parallel strategy to ``None``.

    The process-mesh globals are only declared here, not assigned, so this
    function does not create or modify them.
    NOTE(review): presumably callers assign the mesh globals on the module
    before building the model -- confirm against the call sites.
    """
    global _global_parallel_strategy
    global _global_process_mesh
    global PP_MESH_LIST
    global DPPP_MESH_LIST
    global MPPP_MESH_LIST
    global DPMPPP_MESH_LIST

    _global_parallel_strategy = None


class MultiHeadAttention(nn.Layer):
    """
    Attention maps queries and a set of key-value pairs to outputs, and
    Multi-Head Attention performs multiple parallel attention to jointly attend
    to information from different representation subspaces.

    The projection weights are annotated with `auto.shard_tensor` according
    to the module-level `_global_parallel_strategy`.
    """

    Cache = collections.namedtuple("Cache", ["k", "v"])
    StaticCache = collections.namedtuple("StaticCache", ["k", "v"])

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        kdim=None,
        vdim=None,
        need_weights=False,
        weight_attr=None,
        bias_attr=None,
        fuse=False,
        mesh_idx=None,
        use_new_recompute=False,
        recompute_granularity="full",
    ):
        """
        Build the projection layers.

        `kdim` / `vdim` default to `embed_dim`. With `fuse=True` a single
        `qkv_proj` layer is created instead of separate `q_proj`, `k_proj`
        and `v_proj` layers. `mesh_idx` selects the entry of the
        pipeline mesh lists used when sharding weights.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = embed_dim if kdim is None else kdim
        self.vdim = embed_dim if vdim is None else vdim
        self.num_heads = num_heads
        self.dropout = dropout
        self.need_weights = need_weights
        self.fuse = fuse
        self.mesh_idx = mesh_idx
        self.use_new_recompute = use_new_recompute
        self.recompute_granularity = recompute_granularity

        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        if self.fuse:
            # The fused projection maps one input to q, k and v at once, so
            # all three must share the same input width.
            assert self.kdim == embed_dim
            assert self.vdim == embed_dim
            self.qkv_proj = nn.Linear(
                embed_dim, 3 * embed_dim, weight_attr, bias_attr=bias_attr
            )
        else:
            self.q_proj = nn.Linear(
                embed_dim,
                embed_dim,
                weight_attr=weight_attr,
                bias_attr=bias_attr,
            )
            self.k_proj = nn.Linear(
                self.kdim,
                embed_dim,
                weight_attr=weight_attr,
                bias_attr=bias_attr,
            )
            self.v_proj = nn.Linear(
                self.vdim,
                embed_dim,
                weight_attr=weight_attr,
                bias_attr=bias_attr,
            )
        self.out_proj = nn.Linear(
            embed_dim, embed_dim, weight_attr=weight_attr, bias_attr=bias_attr
        )

    def _shard_weight(self, weight, shard_last_dim=True):
        """
        Annotate `weight` for the current model-parallel strategy.

        The mesh axis is "x" for "mp" / "mp_pp" and "y" for "dp_mp" /
        "dp_mp_pp". With `shard_last_dim=True` the spec is `[None, axis]`,
        otherwise `[axis, None]`. Any other strategy leaves the weight
        untouched. Mesh globals are read only inside the matching branch,
        because the others may be undefined.
        """
        strategy = _global_parallel_strategy
        if strategy == "mp":
            mesh, axis = _global_process_mesh, "x"
        elif strategy == "dp_mp":
            mesh, axis = _global_process_mesh, "y"
        elif strategy == "mp_pp":
            mesh, axis = MPPP_MESH_LIST[self.mesh_idx], "x"
        elif strategy == "dp_mp_pp":
            mesh, axis = DPMPPP_MESH_LIST[self.mesh_idx], "y"
        else:
            return
        spec = [None, axis] if shard_last_dim else [axis, None]
        auto.shard_tensor(weight, mesh, spec)

    def _fuse_prepare_qkv(self, query):
        """
        Project `query` with the fused layer and split the result into
        q, k and v along the last axis after moving heads to axis 1.
        """
        mix_layer = self.qkv_proj(query)
        mix_layer = paddle.reshape_(
            mix_layer, [0, 0, self.num_heads, 3 * self.head_dim]
        )
        mix_layer = paddle.transpose(mix_layer, [0, 2, 1, 3])
        q, k, v = paddle.split(mix_layer, num_or_sections=3, axis=-1)
        return q, k, v

    def _prepare_qkv(self, query, key, value, use_cache=False, cache=None):
        """
        Prepares linear projected queries, keys and values for usage of
        subsequent multiple parallel attention. If `cache` is not None, using
        cached results to reduce redundant calculations.

        Returns `(q, k, v)` when `use_cache` is False, otherwise
        `(q, k, v, cache)`.
        """
        q = self.q_proj(query)
        self._shard_weight(self.q_proj.weight)

        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q = tensor.transpose(x=q, perm=[0, 2, 1, 3])

        if isinstance(cache, self.StaticCache):
            # for encoder-decoder attention in inference and has cached
            k, v = cache.k, cache.v
        else:
            k, v = self.compute_kv(key, value)

        if isinstance(cache, self.Cache):
            # for decoder self-attention in inference
            k = tensor.concat([cache.k, k], axis=2)
            v = tensor.concat([cache.v, v], axis=2)
        if use_cache is True:
            cache = self.Cache(k, v)

        return (q, k, v) if use_cache is False else (q, k, v, cache)

    def compute_kv(self, key, value):
        """
        Applies linear projection on input keys and values, then splits heads
        (reshape and transpose) to get keys and values from different
        representation subspaces. The results are used as key-values pairs for
        subsequent multiple parallel attention.
        It is part of calculations in multi-head attention, and is provided as
        a method to pre-compute and prefetch these results, thus we can use
        them to construct cache for inference.
        """
        k = self.k_proj(key)
        self._shard_weight(self.k_proj.weight)
        v = self.v_proj(value)
        self._shard_weight(self.v_proj.weight)

        head_shape = [0, 0, self.num_heads, self.head_dim]
        k = tensor.reshape(x=k, shape=head_shape)
        k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=head_shape)
        v = tensor.transpose(x=v, perm=[0, 2, 1, 3])
        return k, v

    def gen_cache(self, key, value=None, type=Cache):
        """
        Generates cache for `forward` usage in inference according to
        arguments. The generated cache is an instance of
        `MultiHeadAttention.Cache` or an instance of
        `MultiHeadAttention.StaticCache`.
        """
        if type == MultiHeadAttention.StaticCache:  # static_kv
            k, v = self.compute_kv(key, value)
            return self.StaticCache(k, v)

        if value is None:  # incremental_state
            # Start from empty (length-0) key/value tensors shaped after
            # the batch of `key`.
            empty_shape = [-1, self.num_heads, 0, self.head_dim]
            k = layers.fill_constant_batch_size_like(
                input=key,
                shape=empty_shape,
                dtype=key.dtype,
                value=0,
            )
            v = layers.fill_constant_batch_size_like(
                input=key,
                shape=empty_shape,
                dtype=key.dtype,
                value=0,
            )
            return self.Cache(k, v)

        # incremental_state with initial value, mainly for usage like UniLM
        return self.Cache(key, value)

    def core_attn(self, q, k, v, attn_mask):
        """
        Compute scaled dot-product attention and merge the heads.

        Returns `(out, weights)` where `weights` are the post-softmax
        (and post-dropout, if enabled) attention weights.
        """
        product = layers.matmul(
            x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5
        )
        if attn_mask is not None:
            product = product + attn_mask
        weights = F.softmax(product)
        if self.dropout:
            weights = F.dropout(
                weights,
                self.dropout,
                training=self.training,
                mode="upscale_in_train",
            )
        out = tensor.matmul(weights, v)
        # combine heads
        out = tensor.transpose(out, perm=[0, 2, 1, 3])
        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        return out, weights

    def forward(
        self, query, key, value, attn_mask=None, use_cache=False, cache=None
    ):
        """
        Applies multi-head attention to map queries and a set of key-value
        pairs to outputs.

        Returns the output tensor alone, or a tuple that additionally holds
        the attention weights (when `need_weights`) and the updated cache
        (when `use_cache`), in that order.
        """
        key = query if key is None else key
        value = query if value is None else value

        # compute q, k, v
        if use_cache is False:
            if self.fuse:
                q, k, v = self._fuse_prepare_qkv(query)
            else:
                q, k, v = self._prepare_qkv(query, key, value, use_cache, cache)
        else:
            # NOTE(review): the cached path always uses the unfused
            # projections, which do not exist when `fuse=True`.
            q, k, v, cache = self._prepare_qkv(
                query, key, value, use_cache, cache
            )

        if self.use_new_recompute and self.recompute_granularity == "core_attn":
            out, weights = auto.recompute(self.core_attn)(q, k, v, attn_mask)
        else:
            out, weights = self.core_attn(q, k, v, attn_mask)

        # project to output
        out = self.out_proj(out)
        self._shard_weight(self.out_proj.weight, shard_last_dim=False)

        outs = [out]
        if self.need_weights:
            outs.append(weights)
        if use_cache:
            outs.append(cache)
        return out if len(outs) == 1 else tuple(outs)


class TransformerDecoder(nn.Layer):
    """
    TransformerDecoder is a stack of N decoder layers.
    """

    def __init__(
        self,
        decoder_layers,
        num_layers,
        norm=None,
        hidden_size=None,
        use_new_recompute=False,
        recompute_granularity="full",
    ):
        """
        Store the layer stack and optionally create the final LayerNorm.

        Raises:
            ValueError: if `norm` is neither None nor "LayerNorm".
        """
        super().__init__()

        self.num_layers = num_layers
        self.layers = decoder_layers
        self.use_new_recompute = use_new_recompute
        self.recompute_granularity = recompute_granularity
        if norm == "LayerNorm":
            self.norm = nn.LayerNorm(hidden_size)
        elif norm is None:
            self.norm = None
        else:
            raise ValueError("Only support LayerNorm")
        # Output names of each layer, refilled on every forward call when the
        # new recompute API is not used.
        self.checkpoints = []

    def forward(
        self,
        tgt,
        memory,
        tgt_mask=None,
        memory_mask=None,
        use_cache=False,
        cache=None,
    ):
        """
        Applies a stack of N Transformer decoder layers on inputs. If `norm` is
        provided, also applies layer normalization on the output of last decoder
        layer.

        Returns `output` when `use_cache` is False, otherwise
        `(output, new_caches)`. `memory_mask` is accepted but unused.
        """
        output = tgt
        new_caches = []
        self.checkpoints = []

        # Annotate the stack input on the first pipeline-stage mesh. The
        # strategies are mutually exclusive; each mesh list is read only in
        # its own branch because the others may be undefined.
        strategy = _global_parallel_strategy
        rank = len(output.shape)
        if strategy == "pp":
            auto.shard_tensor(output, PP_MESH_LIST[0], [None] * rank)
        elif strategy == "dp_pp":
            auto.shard_tensor(
                output, DPPP_MESH_LIST[0], ["x"] + [None] * (rank - 1)
            )
        elif strategy == "mp_pp":
            auto.shard_tensor(output, MPPP_MESH_LIST[0], [None] * rank)
        elif strategy == "dp_mp_pp":
            auto.shard_tensor(
                output, DPMPPP_MESH_LIST[0], ["x"] + [None] * (rank - 1)
            )

        for i, mod in enumerate(self.layers):
            if self.use_new_recompute and self.recompute_granularity == "full":
                mod = auto.recompute(mod)

            if cache is None:
                if use_cache:
                    output, new_cache = mod(
                        output,
                        memory,
                        tgt_mask=tgt_mask,
                        use_cache=use_cache,
                        cache=cache,
                    )
                    new_caches.append(new_cache)
                else:
                    output = mod(output, memory, tgt_mask, use_cache, cache)
            else:
                # A per-layer cache was supplied; the layer is expected to
                # return (output, new_cache).
                output, new_cache = mod(
                    output,
                    memory,
                    tgt_mask=tgt_mask,
                    use_cache=use_cache,
                    cache=cache[i],
                )
                new_caches.append(new_cache)

            if not self.use_new_recompute:
                self.checkpoints.append(output.name)

        if self.norm is not None:
            output = self.norm(output)
        return output if use_cache is False else (output, new_caches)

    def gen_cache(self, memory, do_zip=False):
        """
        Generates cache for `forward` usage. The generated cache is a list
        with one entry per layer, each produced by that layer's `gen_cache`.
        If `do_zip` is True, apply `zip` on these entries and return the
        result as a list.
        """
        cache = [layer.gen_cache(memory) for layer in self.layers]
        if do_zip:
            cache = list(zip(*cache))
        return cache


class TransformerDecoderLayer(nn.Layer):
    """
    The transformer decoder layer.
    It contains multiheadattention and some linear layers.
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward,
        dropout=0.1,
        activation="gelu",
        attn_dropout=None,
        act_dropout=None,
        normalize_before=True,
        weight_attr=None,
        bias_attr=None,
        mesh_idx=None,
        use_new_recompute=False,
        recompute_granularity="full",
    ):
        """
        Build self-attention, the two feed-forward linears, two LayerNorms
        and two dropouts. `attn_dropout` / `act_dropout` fall back to
        `dropout` when None.
        """
        # Must stay the first statement so that only the constructor
        # arguments are captured.
        self._config = locals()
        self._config.pop("self")
        self._config.pop("__class__", None)  # py3
        self.mesh_idx = mesh_idx
        super().__init__()
        attn_dropout = dropout if attn_dropout is None else attn_dropout
        act_dropout = dropout if act_dropout is None else act_dropout
        self.normalize_before = normalize_before
        self.use_new_recompute = use_new_recompute
        self.recompute_granularity = recompute_granularity

        weight_attrs = _convert_param_attr_to_list(weight_attr, 3)
        bias_attrs = _convert_param_attr_to_list(bias_attr, 3)

        self.self_attn = MultiHeadAttention(
            d_model,
            nhead,
            dropout=attn_dropout,
            weight_attr=weight_attrs[0],
            bias_attr=bias_attrs[0],
            mesh_idx=self.mesh_idx,
            use_new_recompute=self.use_new_recompute,
            recompute_granularity=self.recompute_granularity,
        )
        self.linear1 = nn.Linear(
            d_model, dim_feedforward, weight_attrs[2], bias_attr=bias_attrs[2]
        )
        self.linear2 = nn.Linear(
            dim_feedforward, d_model, weight_attrs[2], bias_attr=bias_attrs[2]
        )
        self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5)
        self.norm2 = nn.LayerNorm(d_model, epsilon=1e-5)
        self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train")
        self.dropout2 = nn.Dropout(act_dropout, mode="upscale_in_train")
        # NOTE(review): stored but not used by `forward`, which always
        # applies `F.gelu(..., approximate=True)`.
        self.activation = getattr(F, activation)

    def _shard_weight(self, weight, shard_last_dim=True):
        """
        Annotate `weight` for the current model-parallel strategy.

        The mesh axis is "x" for "mp" / "mp_pp" and "y" for "dp_mp" /
        "dp_mp_pp". With `shard_last_dim=True` the spec is `[None, axis]`,
        otherwise `[axis, None]`. Any other strategy leaves the weight
        untouched. Mesh globals are read only inside the matching branch,
        because the others may be undefined.
        """
        strategy = _global_parallel_strategy
        if strategy == "mp":
            mesh, axis = _global_process_mesh, "x"
        elif strategy == "dp_mp":
            mesh, axis = _global_process_mesh, "y"
        elif strategy == "mp_pp":
            mesh, axis = MPPP_MESH_LIST[self.mesh_idx], "x"
        elif strategy == "dp_mp_pp":
            mesh, axis = DPMPPP_MESH_LIST[self.mesh_idx], "y"
        else:
            return
        spec = [None, axis] if shard_last_dim else [axis, None]
        auto.shard_tensor(weight, mesh, spec)

    def forward(self, tgt, memory, tgt_mask=None, use_cache=False, cache=None):
        """
        Run self-attention and the feed-forward block, each wrapped in a
        residual connection, with LayerNorm applied before or after each
        sub-block depending on `normalize_before`.

        Returns `tgt` when `use_cache` is False, otherwise
        `(tgt, incremental_cache)`. `memory` is accepted but unused.
        """
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)

        if self.use_new_recompute and self.recompute_granularity == "full_attn":
            self_attn = auto.recompute(self.self_attn)
        else:
            self_attn = self.self_attn

        if use_cache is False:
            tgt = self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache)
        else:
            tgt, incremental_cache = self_attn(
                tgt, tgt, tgt, tgt_mask, use_cache, cache
            )

        tgt = residual + self.dropout1(tgt)
        if not self.normalize_before:
            tgt = self.norm1(tgt)

        residual = tgt
        if self.normalize_before:
            tgt = self.norm2(tgt)

        # linear1 is sharded along its last dim, linear2 along its first.
        self._shard_weight(self.linear1.weight)
        self._shard_weight(self.linear2.weight, shard_last_dim=False)

        tgt = self.dropout2(
            self.linear2(F.gelu(self.linear1(tgt), approximate=True))
        )
        tgt = residual + tgt
        if not self.normalize_before:
            tgt = self.norm2(tgt)
        return tgt if use_cache is False else (tgt, incremental_cache)

    def gen_cache(self, memory):
        """
        Build the incremental self-attention cache for this layer from
        `memory`.
        """
        incremental_cache = self.self_attn.gen_cache(
            memory, type=self.self_attn.Cache
        )
        return incremental_cache


class GPTEmbeddings(nn.Layer):
    """
    Word and position embeddings followed by dropout.

    `type_vocab_size` is accepted by the constructor but no token-type
    embedding is created.
    """

    def __init__(
        self,
        vocab_size,
        hidden_size=768,
        hidden_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        initializer_range=0.02,
    ):
        """
        Create the two embedding tables, both initialised from a normal
        distribution with mean 0 and std `initializer_range`.
        """
        super().__init__()
        self.word_embeddings = nn.Embedding(
            vocab_size,
            hidden_size,
            weight_attr=paddle.ParamAttr(
                name="word_embeddings",
                initializer=nn.initializer.Normal(
                    mean=0.0, std=initializer_range
                ),
            ),
        )
        self.position_embeddings = nn.Embedding(
            max_position_embeddings,
            hidden_size,
            weight_attr=paddle.ParamAttr(
                name="pos_embeddings",
                initializer=nn.initializer.Normal(
                    mean=0.0, std=initializer_range
                ),
            ),
        )
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(self, input_ids, position_ids=None):
        """
        Return dropout(word_embedding + position_embedding).

        When `position_ids` is None, positions 0, 1, 2, ... are derived
        along the last axis of `input_ids`.
        """
        if position_ids is None:
            # cumsum of ones gives 1..n; subtracting ones gives 0..n-1.
            ones = paddle.ones_like(input_ids, dtype="int64")
            seq_length = paddle.cumsum(ones, axis=-1)
            position_ids = seq_length - ones

        word_embeds = self.word_embeddings(input_ids)

        # Shard the word-embedding table along its first dim for the
        # model-parallel strategies. Pipeline variants use the first-stage
        # mesh.
        strategy = _global_parallel_strategy
        table = self.word_embeddings.weight
        if strategy == "mp":
            auto.shard_tensor(table, _global_process_mesh, ["x", None])
        elif strategy == "dp_mp":
            auto.shard_tensor(table, _global_process_mesh, ["y", None])
        elif strategy == "mp_pp":
            auto.shard_tensor(table, MPPP_MESH_LIST[0], ["x", None])
        elif strategy == "dp_mp_pp":
            auto.shard_tensor(table, DPMPPP_MESH_LIST[0], ["y", None])

        position_embeds = self.position_embeddings(position_ids)
        embeddings = word_embeds + position_embeds
        embeddings = self.dropout(embeddings)
        return embeddings


class GPTModel(nn.Layer):
    """
    The base model of gpt.
    """

    def __init__(
        self,
        vocab_size=50304,
        hidden_size=1024,
        num_hidden_layers=24,
        num_attention_heads=16,
        intermediate_size=4096,
        hidden_act="gelu",
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        max_position_embeddings=512,
        type_vocab_size=16,
        initializer_range=0.02,
        pad_token_id=0,
        eos_token_id=7,
        bos_token_id=0,
        eol_token_id=3,
        pp_degree=None,
        use_new_recompute=False,
        recompute_granularity="full",
    ):
        """
        Build the embeddings and a stack of `num_hidden_layers` decoder
        layers.

        When `pp_degree` is greater than 1, each layer is assigned a
        pipeline-stage index (`mesh_idx`) of
        `i // (num_hidden_layers // pp_degree)`, clamped to the last stage so
        that a layer count not divisible by `pp_degree` cannot produce an
        out-of-range stage.
        """
        super().__init__()
        self.pad_token_id = pad_token_id
        self.initializer_range = initializer_range
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.use_new_recompute = use_new_recompute
        self.recompute_granularity = recompute_granularity

        self.layer_per_stage = None
        self.pipline_mode = pp_degree is not None and pp_degree > 1
        if self.pipline_mode:
            self.layer_per_stage = num_hidden_layers // pp_degree
        self.embeddings = GPTEmbeddings(
            vocab_size,
            hidden_size,
            hidden_dropout_prob,
            max_position_embeddings,
            type_vocab_size,
            self.initializer_range,
        )

        decoder_layers = nn.LayerList()
        for i in range(num_hidden_layers):
            mesh_index = None
            if self.layer_per_stage is not None:
                # Leftover layers (non-divisible layer count) go to the last
                # stage instead of indexing past the mesh list.
                mesh_index = min(i // self.layer_per_stage, pp_degree - 1)
            decoder_layers.append(
                TransformerDecoderLayer(
                    d_model=hidden_size,
                    nhead=num_attention_heads,
                    dim_feedforward=intermediate_size,
                    dropout=hidden_dropout_prob,
                    activation=hidden_act,
                    attn_dropout=attention_probs_dropout_prob,
                    act_dropout=hidden_dropout_prob,
                    weight_attr=paddle.ParamAttr(
                        initializer=nn.initializer.Normal(
                            mean=0.0, std=self.initializer_range
                        )
                    ),
                    bias_attr=None,
                    mesh_idx=mesh_index,
                    use_new_recompute=self.use_new_recompute,
                    recompute_granularity=self.recompute_granularity,
                )
            )

        self.decoder = TransformerDecoder(
            decoder_layers,
            num_hidden_layers,
            norm="LayerNorm",
            hidden_size=hidden_size,
            use_new_recompute=self.use_new_recompute,
            recompute_granularity=self.recompute_granularity,
        )
        self.checkpoints = []

    def forward(
        self,
        input_ids,
        position_ids=None,
        attention_mask=None,
        use_cache=False,
        cache=None,
    ):
        """
        Embed `input_ids` and run the decoder stack.

        When `position_ids` is None they are generated from the sequence
        length, offset by the cached length when `cache` is given. Returns
        whatever the decoder returns: the output tensor, or
        `(output, new_caches)` when `use_cache` is set.
        """
        self.checkpoints = []
        if position_ids is None:
            past_length = 0
            if cache is not None:
                past_length = paddle.shape(cache[0].k)[-2]
            position_ids = paddle.arange(
                past_length,
                paddle.shape(input_ids)[-1] + past_length,
                dtype='int64',
            )
            position_ids = position_ids.unsqueeze(0)
            position_ids = paddle.fluid.layers.expand_as(
                position_ids, input_ids
            )
        embedding_output = self.embeddings(
            input_ids=input_ids, position_ids=position_ids
        )

        # Annotate the input ids on the first pipeline-stage mesh.
        # NOTE(review): "mp_pp" is not handled here although the decoder
        # handles it for its own input -- confirm whether that is intended.
        strategy = _global_parallel_strategy
        ids_rank = len(input_ids.shape)
        if strategy == "pp":
            auto.shard_tensor(input_ids, PP_MESH_LIST[0], [None] * ids_rank)
        elif strategy == "dp_pp":
            auto.shard_tensor(
                input_ids,
                DPPP_MESH_LIST[0],
                ["x"] + [None] * (ids_rank - 1),
            )
        elif strategy == "dp_mp_pp":
            auto.shard_tensor(
                input_ids,
                DPMPPP_MESH_LIST[0],
                ["x"] + [None] * (ids_rank - 1),
            )

        encoder_outputs = self.decoder(
            embedding_output,
            memory=None,
            tgt_mask=attention_mask,
            use_cache=use_cache,
            cache=cache,
        )
        if not self.use_new_recompute:
            self.checkpoints.extend(self.decoder.checkpoints)
        return encoder_outputs


class GPTForPretraining(nn.Layer):
    """
    The pretraining model of GPT.
    It returns some logits and cached_kvs.
    """

    def __init__(
        self,
        gpt,
        vocab_size=50304,
        hidden_size=768,
        initializer_range=0.02,
    ):
        """
        Wrap a GPT backbone. `vocab_size`, `hidden_size` and
        `initializer_range` are accepted but not used here.
        """
        super().__init__()
        self.gpt = gpt

    def forward(
        self,
        input_ids,
        position_ids=None,
        attention_mask=None,
        masked_positions=None,
        use_cache=False,
        cache=None,
    ):
        """
        Run the backbone and project its output onto the word-embedding
        table (`x @ w^T`) to obtain logits.

        Returns `logits`, or `(logits, cached_kvs)` when `use_cache` is set.
        `masked_positions` is accepted but unused.
        """
        input_ids.stop_gradient = True
        # Both are optional (default None), so only mark them when present.
        if position_ids is not None:
            position_ids.stop_gradient = True
        if attention_mask is not None:
            attention_mask.stop_gradient = True

        outputs = self.gpt(
            input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            use_cache=use_cache,
            cache=cache,
        )
        if use_cache:
            encoder_outputs, cached_kvs = outputs[:2]
        else:
            encoder_outputs = outputs

        x = encoder_outputs
        w = self.gpt.embeddings.word_embeddings.weight

        # Pick the mesh and the mesh axis (or None) used for the first dim of
        # x and of w. Pipeline variants use the last-stage mesh. Mesh lists
        # are read only in their own branch because the others may be
        # undefined.
        strategy = _global_parallel_strategy
        mesh = None
        x_axis = None
        w_axis = None
        if strategy == "pp":
            mesh = PP_MESH_LIST[-1]
        elif strategy == "dp":
            mesh = _global_process_mesh
            x_axis = "x"
        elif strategy == "mp":
            mesh = _global_process_mesh
            w_axis = "x"
        elif strategy == "dp_mp":
            mesh = _global_process_mesh
            x_axis, w_axis = "x", "y"
        elif strategy == "dp_pp":
            mesh = DPPP_MESH_LIST[-1]
            x_axis = "x"
        elif strategy == "mp_pp":
            mesh = MPPP_MESH_LIST[-1]
            w_axis = "x"
        elif strategy == "dp_mp_pp":
            mesh = DPMPPP_MESH_LIST[-1]
            x_axis, w_axis = "x", "y"

        if mesh:
            # Only the first dim may be sharded; the remaining dims are
            # always None (never -1), matching every other spec in this file.
            x_dims_mapping = [x_axis] + [None] * (len(x.shape) - 1)
            w_dims_mapping = [w_axis] + [None] * (len(w.shape) - 1)
            matmul = auto.shard_op(
                paddle.matmul, mesh, [x_dims_mapping, w_dims_mapping, None]
            )
            logits = matmul(x, w, transpose_y=True)
        else:
            logits = paddle.matmul(x, w, transpose_y=True)

        if use_cache:
            return logits, cached_kvs
        return logits


class GPTPretrainingCriterion(nn.Layer):
    """
    Criterion for GPT.
    It calculates the final loss.
    """

    def __init__(self):
        """Create an unreduced cross-entropy loss."""
        super().__init__()
        self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none")

    def forward(self, prediction_scores, masked_lm_labels, loss_mask):
        """
        Return the masked mean loss: the per-element cross entropy is
        multiplied by `loss_mask`, summed, and divided by `loss_mask.sum()`.
        """
        masked_lm_labels.stop_gradient = True
        loss_mask.stop_gradient = True

        # Only the data-parallel strategies annotate the mask. Pipeline
        # variants use the last-stage mesh.
        strategy = _global_parallel_strategy
        mesh = None
        if strategy in ("dp", "dp_mp"):
            mesh = _global_process_mesh
        elif strategy == "dp_pp":
            mesh = DPPP_MESH_LIST[-1]
        elif strategy == "dp_mp_pp":
            mesh = DPMPPP_MESH_LIST[-1]

        if mesh:
            # First dim sharded on mesh axis "x", the rest left unsharded.
            dims_mapping = ["x"] + [None] * (len(loss_mask.shape) - 1)
            auto.shard_tensor(loss_mask, mesh, dims_mapping)

        masked_lm_loss = self.loss_func(
            prediction_scores, masked_lm_labels.unsqueeze(2)
        )
        loss_mask = loss_mask.reshape([-1])
        masked_lm_loss = paddle.sum(masked_lm_loss.reshape([-1]) * loss_mask)
        total_loss = masked_lm_loss / loss_mask.sum()
        return total_loss