# auto_parallel_gpt_model.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections

import paddle
import paddle.nn.functional as F
from paddle import nn, tensor
from paddle.distributed.fleet import auto
from paddle.nn.layer.transformer import _convert_param_attr_to_list

# Switch Paddle to static-graph mode as an import-time side effect; the
# layers below are built for that mode (e.g. they collect ``output.name``).
paddle.enable_static()


def init_global():
    """Reset the module-level parallel configuration to "no parallelism".

    The layers in this module read these globals at call time. Callers
    (presumably the test drivers) assign the real strategy and meshes after
    calling this function.

    Every global is bound here (to ``None``) rather than only declared, so
    calling this function actually clears state left by a previous
    configuration, and the names always exist at module level.
    """
    global _global_parallel_strategy
    global _global_process_mesh
    global PP_MESH_LIST
    global DPPP_MESH_LIST
    global MPPP_MESH_LIST
    global DPMPPP_MESH_LIST
    # Strategy string compared against below ("dp", "mp", "pp", "dp_mp",
    # "dp_pp", "mp_pp", "dp_mp_pp"); None means no sharding annotations.
    _global_parallel_strategy = None
    # Single mesh used by the "dp", "mp" and "dp_mp" strategies.
    _global_process_mesh = None
    # Per-stage mesh lists for the pipeline strategies, indexed by a layer's
    # ``mesh_idx`` (or [0] / [-1] for the first / last stage).
    PP_MESH_LIST = None
    DPPP_MESH_LIST = None
    MPPP_MESH_LIST = None
    DPMPPP_MESH_LIST = None


class MultiHeadAttention(nn.Layer):
    """
    Multi-head attention with auto-parallel sharding annotations.

    Queries are attended against key/value pairs in ``num_heads`` parallel
    heads of size ``embed_dim // num_heads``. The merged head outputs are
    then projected back to ``embed_dim``. When the module-level
    ``_global_parallel_strategy`` has a model-parallel component, the
    projection weights are annotated with ``auto.shard_tensor``: q/k/v
    weights use spec ``[None, axis]`` and the output projection uses
    ``[axis, None]``.
    """

    Cache = collections.namedtuple("Cache", ["k", "v"])
    StaticCache = collections.namedtuple("StaticCache", ["k", "v"])

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        kdim=None,
        vdim=None,
        need_weights=False,
        weight_attr=None,
        bias_attr=None,
        fuse=False,
        mesh_idx=None,
        use_new_recompute=False,
        recompute_granularity="full",
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = embed_dim if kdim is None else kdim
        self.vdim = embed_dim if vdim is None else vdim
        self.num_heads = num_heads
        self.dropout = dropout
        self.need_weights = need_weights
        self.fuse = fuse
        # Index into the per-stage mesh lists (pipeline strategies only).
        self.mesh_idx = mesh_idx
        self.use_new_recompute = use_new_recompute
        self.recompute_granularity = recompute_granularity

        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        def linear(in_dim, out_dim):
            return nn.Linear(
                in_dim, out_dim, weight_attr=weight_attr, bias_attr=bias_attr
            )

        if self.fuse:
            assert self.kdim == embed_dim
            assert self.vdim == embed_dim
            # A single projection that yields q, k and v side by side.
            self.qkv_proj = linear(embed_dim, 3 * embed_dim)
        else:
            self.q_proj = linear(embed_dim, embed_dim)
            self.k_proj = linear(self.kdim, embed_dim)
            self.v_proj = linear(self.vdim, embed_dim)
        self.out_proj = linear(embed_dim, embed_dim)

    def _shard_weight(self, layer, split_output):
        """
        Annotate ``layer.weight`` for the active model-parallel strategy.

        With ``split_output=True`` the spec is ``[None, axis]``; otherwise it
        is ``[axis, None]``. Strategies without a model-parallel component
        leave the weight unannotated.
        """
        strategy = _global_parallel_strategy
        if strategy == "mp":
            mesh, axis = _global_process_mesh, "x"
        elif strategy == "dp_mp":
            mesh, axis = _global_process_mesh, "y"
        elif strategy == "mp_pp":
            mesh, axis = MPPP_MESH_LIST[self.mesh_idx], "x"
        elif strategy == "dp_mp_pp":
            mesh, axis = DPMPPP_MESH_LIST[self.mesh_idx], "y"
        else:
            return
        spec = [None, axis] if split_output else [axis, None]
        auto.shard_tensor(layer.weight, mesh, spec)

    def _split_heads(self, x):
        """Reshape to ``[0, 0, num_heads, head_dim]`` and swap axes 1 and 2."""
        x = tensor.reshape(x=x, shape=[0, 0, self.num_heads, self.head_dim])
        return tensor.transpose(x=x, perm=[0, 2, 1, 3])

    def _fuse_prepare_qkv(self, query):
        """Compute q, k and v from the single fused ``qkv_proj``."""
        mixed = paddle.reshape_(
            self.qkv_proj(query), [0, 0, self.num_heads, 3 * self.head_dim]
        )
        mixed = paddle.transpose(mixed, [0, 2, 1, 3])
        q, k, v = paddle.split(mixed, num_or_sections=3, axis=-1)
        return q, k, v

    def _prepare_qkv(self, query, key, value, use_cache=False, cache=None):
        """
        Project and split heads for q, k and v, honouring ``cache``.

        A ``StaticCache`` supplies precomputed k/v. A ``Cache`` has its k/v
        prepended (axis 2) to the freshly computed ones. Returns
        ``(q, k, v)`` when ``use_cache`` is False and ``(q, k, v, cache)``
        otherwise.
        """
        q = self.q_proj(query)
        self._shard_weight(self.q_proj, split_output=True)
        q = self._split_heads(q)

        if isinstance(cache, self.StaticCache):
            # Encoder-decoder attention at inference: k/v already computed.
            k, v = cache.k, cache.v
        else:
            k, v = self.compute_kv(key, value)

        if isinstance(cache, self.Cache):
            # Decoder self-attention at inference: extend the past k/v.
            k = tensor.concat([cache.k, k], axis=2)
            v = tensor.concat([cache.v, v], axis=2)

        if use_cache is False:
            return q, k, v
        if use_cache is True:
            cache = self.Cache(k, v)
        return q, k, v, cache

    def compute_kv(self, key, value):
        """
        Project ``key``/``value`` and split them into heads.

        Exposed separately so callers can precompute k/v, e.g. to build a
        ``StaticCache`` via ``gen_cache``.
        """
        k = self.k_proj(key)
        self._shard_weight(self.k_proj, split_output=True)
        v = self.v_proj(value)
        self._shard_weight(self.v_proj, split_output=True)
        return self._split_heads(k), self._split_heads(v)

    def gen_cache(self, key, value=None, type=Cache):
        """
        Build a cache for inference-time ``forward`` calls.

        - ``type=StaticCache``: k/v are computed once from ``key``/``value``.
        - ``value`` given: a ``Cache`` seeded with ``key``/``value`` as-is.
        - otherwise: a ``Cache`` whose k/v have length 0 on axis 2.
        """
        if type == MultiHeadAttention.StaticCache:  # static_kv
            return self.StaticCache(*self.compute_kv(key, value))
        if value is not None:
            # incremental_state with initial value, e.g. UniLM-style usage
            return self.Cache(key, value)
        # incremental_state starting empty; _prepare_qkv grows it on axis 2
        batch_size = paddle.shape(key)[0].item()
        fill_shape = [batch_size, self.num_heads, 0, self.head_dim]
        k = paddle.full(shape=fill_shape, fill_value=0, dtype=key.dtype)
        v = paddle.full(shape=fill_shape, fill_value=0, dtype=key.dtype)
        return self.Cache(k, v)

    def core_attn(self, q, k, v, attn_mask):
        """Scaled dot-product attention over split heads; returns (out, weights)."""
        scores = paddle.matmul(x=q, y=k, transpose_y=True)
        scale = paddle.to_tensor([self.head_dim**-0.5], dtype=scores.dtype)
        scores = paddle.multiply(scores, scale)
        if attn_mask is not None:
            scores = scores + attn_mask
        weights = F.softmax(scores)
        if self.dropout:
            weights = F.dropout(
                weights,
                self.dropout,
                training=self.training,
                mode="upscale_in_train",
            )
        out = tensor.matmul(weights, v)
        # Merge heads: swap axes 1/2 back, then fold the last two axes.
        out = tensor.transpose(out, perm=[0, 2, 1, 3])
        merged_width = out.shape[2] * out.shape[3]
        out = tensor.reshape(x=out, shape=[0, 0, merged_width])
        return out, weights

    def forward(
        self, query, key, value, attn_mask=None, use_cache=False, cache=None
    ):
        """
        Attend ``query`` to ``key``/``value`` (each defaults to ``query``).

        Returns the projected output alone. If ``need_weights`` and/or
        ``use_cache`` are set, returns a tuple that also holds the attention
        weights and/or the updated cache, in that order.
        """
        key = query if key is None else key
        value = query if value is None else value

        # NOTE(review): with fuse=True and use_cache, _prepare_qkv reads
        # q_proj, which is only created when fuse is False.
        if use_cache is not False:
            q, k, v, cache = self._prepare_qkv(
                query, key, value, use_cache, cache
            )
        elif self.fuse:
            q, k, v = self._fuse_prepare_qkv(query)
        else:
            q, k, v = self._prepare_qkv(query, key, value, use_cache, cache)

        attention = self.core_attn
        if self.use_new_recompute and self.recompute_granularity == "core_attn":
            attention = auto.recompute(attention)
        out, weights = attention(q, k, v, attn_mask)

        # Project back to embed_dim.
        out = self.out_proj(out)
        self._shard_weight(self.out_proj, split_output=False)

        extras = []
        if self.need_weights:
            extras.append(weights)
        if use_cache:
            extras.append(cache)
        return (out, *extras) if extras else out


class TransformerDecoder(nn.Layer):
    """
    A stack of decoder layers, optionally followed by a final LayerNorm.

    Under a pipeline strategy each layer is wrapped with ``auto.shard_op`` on
    the mesh picked by its ``mesh_idx``. When the new recompute API is not
    used, the name of every layer output is recorded in ``checkpoints``.
    """

    def __init__(
        self,
        decoder_layers,
        num_layers,
        norm=None,
        hidden_size=None,
        use_new_recompute=False,
        recompute_granularity="full",
    ):
        super().__init__()

        self.num_layers = num_layers
        self.layers = decoder_layers
        self.norm = norm
        self.use_new_recompute = use_new_recompute
        self.recompute_granularity = recompute_granularity
        # ``norm`` is given by name; only "LayerNorm" (or None) is accepted.
        if norm is not None and norm != "LayerNorm":
            raise ValueError("Only support LayerNorm")
        if norm == "LayerNorm":
            self.norm = nn.LayerNorm(hidden_size)
        self.checkpoints = []

    @staticmethod
    def _place_on_stage(mod):
        """Wrap ``mod`` with ``auto.shard_op`` on its pipeline-stage mesh."""
        strategy = _global_parallel_strategy
        if strategy == "pp":
            return auto.shard_op(mod, PP_MESH_LIST[mod.mesh_idx])
        if strategy == "dp_pp":
            return auto.shard_op(mod, DPPP_MESH_LIST[mod.mesh_idx])
        if strategy == "mp_pp":
            return auto.shard_op(mod, MPPP_MESH_LIST[mod.mesh_idx])
        if strategy == "dp_mp_pp":
            return auto.shard_op(mod, DPMPPP_MESH_LIST[mod.mesh_idx])
        return mod

    def forward(
        self,
        tgt,
        memory,
        tgt_mask=None,
        memory_mask=None,
        use_cache=False,
        cache=None,
    ):
        """
        Run ``tgt`` through every layer, then through ``norm`` if present.

        Returns the output, or ``(output, new_caches)`` when ``use_cache`` is
        not False. ``memory_mask`` is accepted but unused.
        """
        hidden = tgt
        new_caches = []
        self.checkpoints = []

        for idx, layer in enumerate(self.layers):
            layer = self._place_on_stage(layer)
            if self.use_new_recompute and self.recompute_granularity == "full":
                layer = auto.recompute(layer)

            if cache is None and not use_cache:
                hidden = layer(hidden, memory, tgt_mask, use_cache, cache)
            else:
                layer_cache = None if cache is None else cache[idx]
                hidden, layer_new_cache = layer(
                    hidden,
                    memory,
                    tgt_mask=tgt_mask,
                    use_cache=use_cache,
                    cache=layer_cache,
                )
                new_caches.append(layer_new_cache)

            if not self.use_new_recompute:
                self.checkpoints.append(hidden.name)

        if self.norm is not None:
            hidden = self.norm(hidden)
        return hidden if use_cache is False else (hidden, new_caches)

    def gen_cache(self, memory, do_zip=False):
        """
        Return a list holding each layer's ``gen_cache(memory)`` result.

        With ``do_zip`` the per-layer results are transposed with
        ``zip(*...)`` into a list of tuples.
        """
        per_layer = [layer.gen_cache(memory) for layer in self.layers]
        return list(zip(*per_layer)) if do_zip else per_layer


class TransformerDecoderLayer(nn.Layer):
    """
    One decoder block: self-attention followed by a two-layer feed-forward
    network. Each sub-layer has a residual connection, with LayerNorm applied
    before it (``normalize_before=True``) or after it.
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward,
        dropout=0.1,
        activation="gelu",
        attn_dropout=None,
        act_dropout=None,
        normalize_before=True,
        weight_attr=None,
        bias_attr=None,
        mesh_idx=None,
        use_new_recompute=False,
        recompute_granularity="full",
    ):
        # Snapshot the constructor arguments before any other local exists.
        self._config = locals()
        self._config.pop("self")
        self._config.pop("__class__", None)  # py3
        self.mesh_idx = mesh_idx
        super().__init__()
        if attn_dropout is None:
            attn_dropout = dropout
        if act_dropout is None:
            act_dropout = dropout
        self.normalize_before = normalize_before
        self.use_new_recompute = use_new_recompute
        self.recompute_granularity = recompute_granularity

        weight_attrs = _convert_param_attr_to_list(weight_attr, 3)
        bias_attrs = _convert_param_attr_to_list(bias_attr, 3)

        self.self_attn = MultiHeadAttention(
            d_model,
            nhead,
            dropout=attn_dropout,
            weight_attr=weight_attrs[0],
            bias_attr=bias_attrs[0],
            mesh_idx=self.mesh_idx,
            use_new_recompute=self.use_new_recompute,
            recompute_granularity=self.recompute_granularity,
        )
        # Both feed-forward projections use the third attribute slot.
        ffn_weight_attr, ffn_bias_attr = weight_attrs[2], bias_attrs[2]
        self.linear1 = nn.Linear(
            d_model, dim_feedforward, ffn_weight_attr, bias_attr=ffn_bias_attr
        )
        self.linear2 = nn.Linear(
            dim_feedforward, d_model, ffn_weight_attr, bias_attr=ffn_bias_attr
        )
        self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5)
        self.norm2 = nn.LayerNorm(d_model, epsilon=1e-5)
        self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train")
        self.dropout2 = nn.Dropout(act_dropout, mode="upscale_in_train")
        # NOTE(review): stored but unused; forward() always applies
        # F.gelu(approximate=True) regardless of ``activation``.
        self.activation = getattr(F, activation)

    def _shard_ffn_weights(self):
        """Annotate linear1 with ``[None, axis]`` and linear2 with ``[axis, None]``."""
        strategy = _global_parallel_strategy
        if strategy == "mp":
            mesh, axis = _global_process_mesh, "x"
        elif strategy == "dp_mp":
            mesh, axis = _global_process_mesh, "y"
        elif strategy == "mp_pp":
            mesh, axis = MPPP_MESH_LIST[self.mesh_idx], "x"
        elif strategy == "dp_mp_pp":
            mesh, axis = DPMPPP_MESH_LIST[self.mesh_idx], "y"
        else:
            return
        auto.shard_tensor(self.linear1.weight, mesh, [None, axis])
        auto.shard_tensor(self.linear2.weight, mesh, [axis, None])

    def forward(self, tgt, memory, tgt_mask=None, use_cache=False, cache=None):
        """
        Apply the block to ``tgt``; ``memory`` is accepted but unused.

        Returns the output, or ``(output, incremental_cache)`` when
        ``use_cache`` is not False.
        """
        # --- self-attention sub-layer ---
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)

        attn = self.self_attn
        if self.use_new_recompute and self.recompute_granularity == "full_attn":
            attn = auto.recompute(attn)

        attn_result = attn(tgt, tgt, tgt, tgt_mask, use_cache, cache)
        if use_cache is False:
            tgt = attn_result
        else:
            tgt, incremental_cache = attn_result

        tgt = residual + self.dropout1(tgt)
        if not self.normalize_before:
            tgt = self.norm1(tgt)

        # --- feed-forward sub-layer ---
        residual = tgt
        if self.normalize_before:
            tgt = self.norm2(tgt)
        self._shard_ffn_weights()
        hidden = F.gelu(self.linear1(tgt), approximate=True)
        tgt = residual + self.dropout2(self.linear2(hidden))
        if not self.normalize_before:
            tgt = self.norm2(tgt)

        if use_cache is False:
            return tgt
        return tgt, incremental_cache

    def gen_cache(self, memory):
        """Return an empty incremental self-attention ``Cache`` built from ``memory``."""
        return self.self_attn.gen_cache(memory, type=self.self_attn.Cache)


class GPTEmbeddings(nn.Layer):
    """
    Sum of word and position embeddings, followed by dropout.

    ``type_vocab_size`` is accepted for signature compatibility, but no
    token-type embedding is created.
    """

    def __init__(
        self,
        vocab_size,
        hidden_size=768,
        hidden_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        initializer_range=0.02,
    ):
        super().__init__()

        def normal_attr(param_name):
            # Named parameter initialized from N(0, initializer_range).
            return paddle.ParamAttr(
                name=param_name,
                initializer=nn.initializer.Normal(
                    mean=0.0, std=initializer_range
                ),
            )

        self.word_embeddings = nn.Embedding(
            vocab_size,
            hidden_size,
            weight_attr=normal_attr("word_embeddings"),
        )
        self.position_embeddings = nn.Embedding(
            max_position_embeddings,
            hidden_size,
            weight_attr=normal_attr("pos_embeddings"),
        )
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def _shard_word_embeddings(self):
        """Annotate the word-embedding table with ``[axis, None]`` under mp strategies."""
        strategy = _global_parallel_strategy
        if strategy == "mp":
            mesh, axis = _global_process_mesh, "x"
        elif strategy == "dp_mp":
            mesh, axis = _global_process_mesh, "y"
        elif strategy == "mp_pp":
            mesh, axis = MPPP_MESH_LIST[0], "x"
        elif strategy == "dp_mp_pp":
            mesh, axis = DPMPPP_MESH_LIST[0], "y"
        else:
            return
        auto.shard_tensor(self.word_embeddings.weight, mesh, [axis, None])

    def forward(self, input_ids, position_ids=None):
        """Embed ``input_ids`` (positions default to 0, 1, 2, ... on the last axis)."""
        if position_ids is None:
            # Running count of ones minus one gives 0, 1, 2, ...
            ones = paddle.ones_like(input_ids, dtype="int64")
            position_ids = paddle.cumsum(ones, axis=-1) - ones
        word_vectors = self.word_embeddings(input_ids)
        self._shard_word_embeddings()
        position_vectors = self.position_embeddings(position_ids)
        return self.dropout(word_vectors + position_vectors)


class GPTModel(nn.Layer):
    """
    The base GPT model: embeddings, ``num_hidden_layers`` decoder layers and
    a final LayerNorm.

    With ``pp_degree > 1`` the decoder layers are split into ``pp_degree``
    contiguous pipeline stages. Each layer's stage index is passed as its
    ``mesh_idx`` and used to pick that layer's mesh from the per-stage mesh
    lists.
    """

    def __init__(
        self,
        vocab_size=50304,
        hidden_size=1024,
        num_hidden_layers=24,
        num_attention_heads=16,
        intermediate_size=4096,
        hidden_act="gelu",
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        max_position_embeddings=512,
        type_vocab_size=16,
        initializer_range=0.02,
        pad_token_id=0,
        eos_token_id=7,
        bos_token_id=0,
        eol_token_id=3,
        pp_degree=None,
        use_new_recompute=False,
        recompute_granularity="full",
    ):
        super().__init__()
        self.pad_token_id = pad_token_id
        self.initializer_range = initializer_range
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.use_new_recompute = use_new_recompute
        self.recompute_granularity = recompute_granularity

        self.layer_per_stage = None
        # Attribute name keeps its original spelling for compatibility.
        self.pipline_mode = pp_degree is not None and pp_degree > 1
        if self.pipline_mode:
            if pp_degree > num_hidden_layers:
                # Otherwise layer_per_stage would be 0 and the stage index
                # computation below would divide by zero.
                raise ValueError(
                    f"pp_degree ({pp_degree}) must not exceed "
                    f"num_hidden_layers ({num_hidden_layers})"
                )
            self.layer_per_stage = num_hidden_layers // pp_degree
        self.embeddings = GPTEmbeddings(
            vocab_size,
            hidden_size,
            hidden_dropout_prob,
            max_position_embeddings,
            type_vocab_size,
            self.initializer_range,
        )

        decoder_layers = nn.LayerList()
        for i in range(num_hidden_layers):
            mesh_index = None
            if self.layer_per_stage is not None:
                # Clamp so that, when num_hidden_layers is not divisible by
                # pp_degree, the leftover layers join the last stage instead
                # of getting an index past it (assumes the per-stage mesh
                # lists hold pp_degree entries).
                mesh_index = min(i // self.layer_per_stage, pp_degree - 1)
            decoder_layers.append(
                TransformerDecoderLayer(
                    d_model=hidden_size,
                    nhead=num_attention_heads,
                    dim_feedforward=intermediate_size,
                    dropout=hidden_dropout_prob,
                    activation=hidden_act,
                    attn_dropout=attention_probs_dropout_prob,
                    act_dropout=hidden_dropout_prob,
                    weight_attr=paddle.ParamAttr(
                        initializer=nn.initializer.Normal(
                            mean=0.0, std=self.initializer_range
                        )
                    ),
                    bias_attr=None,
                    mesh_idx=mesh_index,
                    use_new_recompute=self.use_new_recompute,
                    recompute_granularity=self.recompute_granularity,
                )
            )

        self.decoder = TransformerDecoder(
            decoder_layers,
            num_hidden_layers,
            norm="LayerNorm",
            hidden_size=hidden_size,
            use_new_recompute=self.use_new_recompute,
            recompute_granularity=self.recompute_granularity,
        )
        # Output names collected from the decoder on each forward call.
        self.checkpoints = []

    @staticmethod
    def _shard_input_ids(input_ids):
        """Annotate ``input_ids`` on the first pipeline stage's mesh."""
        strategy = _global_parallel_strategy
        rank = len(input_ids.shape)
        if strategy == "pp":
            auto.shard_tensor(input_ids, PP_MESH_LIST[0], [None] * rank)
        elif strategy == "dp_pp":
            auto.shard_tensor(
                input_ids, DPPP_MESH_LIST[0], ["x"] + [None] * (rank - 1)
            )
        elif strategy == "dp_mp_pp":
            auto.shard_tensor(
                input_ids, DPMPPP_MESH_LIST[0], ["x"] + [None] * (rank - 1)
            )

    def forward(
        self,
        input_ids,
        position_ids=None,
        attention_mask=None,
        use_cache=False,
        cache=None,
    ):
        """
        Run embeddings and the decoder stack.

        If ``position_ids`` is None, positions start after the cached length
        (axis -2 of ``cache[0].k``) and are broadcast to ``input_ids``.
        Returns the decoder output, or ``(output, new_caches)`` when
        ``use_cache`` is not False.
        """
        self.checkpoints = []
        if position_ids is None:
            past_length = 0
            if cache is not None:
                past_length = paddle.shape(cache[0].k)[-2]
            position_ids = paddle.arange(
                past_length,
                paddle.shape(input_ids)[-1] + past_length,
                dtype='int64',
            )
            position_ids = position_ids.unsqueeze(0)
            position_ids = paddle.expand_as(position_ids, input_ids)
        embedding_output = self.embeddings(
            input_ids=input_ids, position_ids=position_ids
        )
        self._shard_input_ids(input_ids)
        encoder_outputs = self.decoder(
            embedding_output,
            memory=None,
            tgt_mask=attention_mask,
            use_cache=use_cache,
            cache=cache,
        )
        if not self.use_new_recompute:
            self.checkpoints.extend(self.decoder.checkpoints)
        return encoder_outputs


class GPTForPretraining(nn.Layer):
    """
    Language-model head on top of a ``GPTModel``.

    Logits are the backbone output multiplied by the transposed
    word-embedding matrix of ``gpt``, so the weights are tied. Returns the
    logits, plus the cached key/values when ``use_cache`` is true.
    """

    def __init__(
        self,
        gpt,
        vocab_size=50304,
        hidden_size=768,
        initializer_range=0.02,
    ):
        super().__init__()
        # vocab_size / hidden_size / initializer_range are accepted for
        # signature compatibility; the head reuses gpt's embedding weights.
        self.gpt = gpt

    @staticmethod
    def _logits_sharding(x_rank, w_rank):
        """
        Return ``(mesh, x_spec, w_spec)`` for the logits matmul under the
        active strategy, or ``(None, None, None)`` when no annotation applies.

        Every unsharded dimension is ``None``; sharded ones name a mesh axis.
        """
        strategy = _global_parallel_strategy
        replicated_x = [None] * x_rank
        batch_split_x = ["x"] + [None] * (x_rank - 1)
        replicated_w = [None] * w_rank
        if strategy == "pp":
            return PP_MESH_LIST[-1], replicated_x, replicated_w
        if strategy == "dp":
            return _global_process_mesh, batch_split_x, replicated_w
        if strategy == "mp":
            return (
                _global_process_mesh,
                replicated_x,
                ["x"] + [None] * (w_rank - 1),
            )
        if strategy == "dp_mp":
            return (
                _global_process_mesh,
                batch_split_x,
                ["y"] + [None] * (w_rank - 1),
            )
        if strategy == "dp_pp":
            return DPPP_MESH_LIST[-1], batch_split_x, replicated_w
        if strategy == "mp_pp":
            # Trailing dims use None like every other branch (the spec takes
            # mesh axis names or None, not -1).
            return (
                MPPP_MESH_LIST[-1],
                replicated_x,
                ["x"] + [None] * (w_rank - 1),
            )
        if strategy == "dp_mp_pp":
            return (
                DPMPPP_MESH_LIST[-1],
                batch_split_x,
                ["y"] + [None] * (w_rank - 1),
            )
        return None, None, None

    def forward(
        self,
        input_ids,
        position_ids=None,
        attention_mask=None,
        masked_positions=None,
        use_cache=False,
        cache=None,
    ):
        """
        Compute vocabulary logits; ``masked_positions`` is accepted but unused.

        ``position_ids`` and ``attention_mask`` may be None: GPTModel builds
        default position ids itself, and the decoder accepts a None mask.
        """
        # Inputs are data, never trained; skip the optional ones when absent
        # instead of dereferencing None.
        for data in (input_ids, position_ids, attention_mask):
            if data is not None:
                data.stop_gradient = True

        outputs = self.gpt(
            input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            use_cache=use_cache,
            cache=cache,
        )
        if use_cache:
            encoder_outputs, cached_kvs = outputs[:2]
        else:
            encoder_outputs = outputs

        x = encoder_outputs
        w = self.gpt.embeddings.word_embeddings.weight
        mesh, x_spec, w_spec = self._logits_sharding(len(x.shape), len(w.shape))

        with paddle.static.name_scope('skip_quant'):
            if mesh is not None:
                matmul = auto.shard_op(
                    paddle.matmul, mesh, [x_spec, w_spec, None]
                )
                logits = matmul(x, w, transpose_y=True)
            else:
                logits = paddle.matmul(x, w, transpose_y=True)

        if use_cache:
            return logits, cached_kvs
        return logits


class GPTPretrainingCriterion(nn.Layer):
    """
    Masked language-model loss for GPT pretraining.

    The per-token cross entropy is weighted by ``loss_mask``, summed, and
    divided by the sum of ``loss_mask``.
    """

    def __init__(self):
        super().__init__()
        self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none")

    def forward(self, prediction_scores, masked_lm_labels, loss_mask):
        """Return the scalar masked loss for a batch."""
        masked_lm_labels.stop_gradient = True
        loss_mask.stop_gradient = True

        strategy = _global_parallel_strategy
        if strategy in ("dp", "dp_mp"):
            mesh = _global_process_mesh
        elif strategy == "dp_pp":
            mesh = DPPP_MESH_LIST[-1]
        elif strategy == "dp_mp_pp":
            mesh = DPMPPP_MESH_LIST[-1]
        else:
            mesh = None
        if mesh:
            # Split the mask's first axis over mesh axis "x"; replicate the rest.
            spec = ["x"] + [None] * (len(loss_mask.shape) - 1)
            auto.shard_tensor(loss_mask, mesh, spec)

        per_token_loss = self.loss_func(
            prediction_scores, masked_lm_labels.unsqueeze(2)
        )
        flat_mask = loss_mask.reshape([-1])
        weighted_sum = paddle.sum(per_token_loss.reshape([-1]) * flat_mask)
        return weighted_sum / flat_mask.sum()