# Source: bert.py (30.5 KB); the original blame view lists WangZhen as the committer.
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import warnings
from typing import Optional, Tuple

import numpy as np

import paddle
import paddle.nn.functional as F
from paddle import Tensor, nn
from paddle.distributed.fleet.utils import recompute
from paddle.fluid.data_feeder import convert_dtype
from paddle.io import DataLoader, Dataset
from paddle.nn import MultiHeadAttention

try:
    from paddle.incubate.nn import FusedTransformerEncoderLayer
except ImportError:
    FusedTransformerEncoderLayer = None


# Default vocabulary size; used by BertConfig.vocab_size and as the default
# `vocab_size` of BertPretrainingCriterion.
VOCAB_SIZE = 30522


class Stack(object):
    """Collate helper that stacks a sequence of equally-shaped arrays.

    Args:
        axis: Axis along which ``np.stack`` joins the samples.
        dtype: Optional dtype the stacked result is cast to. ``None`` keeps
            the dtype produced by ``np.stack``.
    """

    def __init__(self, axis=0, dtype=None):
        self._axis = axis
        self._dtype = dtype

    def __call__(self, data):
        """Stack ``data`` along the configured axis and return the array."""
        stacked = np.stack(data, axis=self._axis)
        # Compare against None instead of relying on truthiness: an
        # ``np.dtype`` instance may evaluate as falsy, which would silently
        # skip the requested cast.
        if self._dtype is not None:
            stacked = stacked.astype(self._dtype)
        return stacked


def is_tensor(x):
    """Return True when `x` is a `paddle.Tensor` or a `numpy.ndarray`."""
    return isinstance(x, (paddle.Tensor, np.ndarray))


class BertConfig:
    """Hyper-parameter container for the BERT model defined in this file.

    Every value is hard-coded in `__init__`. Several sizes are reduced from
    the full-size values kept in the adjacent comments to speed up unit tests.
    """

    def __init__(self):
        # Dropout rate passed to the encoder layers as attention dropout.
        self.attention_probs_dropout_prob = 0.1
        # When True, BertModel tries to build FusedTransformerEncoderLayer.
        self.fuse = False
        # Name looked up on `nn.functional` / passed as encoder activation.
        self.hidden_act = 'gelu'
        # Dropout rate for embeddings and encoder layers.
        self.hidden_dropout_prob = 0.1
        # Reduced to speed up the unit test (full-size value: 768).
        # self.hidden_size = 768
        self.hidden_size = 60
        self.initializer_range = 0.02
        # Feed-forward width handed to the encoder layers.
        self.intermediate_size = 3072
        self.layer_norm_eps = 1e-12
        # Number of rows in the position embedding table.
        self.max_position_embeddings = 512
        self.model_type = 'bert'
        # Reduced to speed up the unit test (full-size value: 12).
        # self.num_attention_heads = 12
        self.num_attention_heads = 6
        # Reduced to speed up the unit test (full-size value: 12).
        # self.num_hidden_layers = 12
        self.num_hidden_layers = 6
        # Token id treated as padding when BertModel builds a default mask.
        self.pad_token_id = 0
        self.paddlenlp_version = None
        # BertPooler applies its Tanh only when this equals "tanh".
        self.pool_act = 'tanh'
        # Number of rows in the token-type embedding table.
        self.type_vocab_size = 2
        self.vocab_size = VOCAB_SIZE
        # Defaults consulted by BertModel.forward when the matching
        # argument is passed as None.
        self.use_return_dict = False
        self.output_hidden_states = False
        self.output_attentions = False
        self.use_cache = False


class BertLMPredictionHead(nn.Layer):
    """Masked-LM head: dense transform, activation, layer norm and a
    projection onto the vocabulary.

    Args:
        config: Supplies `hidden_size`, `hidden_act` and `vocab_size`.
        embedding_weights: Optional weight used as the output projection in
            place of a freshly created `[vocab_size, hidden_size]` parameter.
    """

    def __init__(self, config: BertConfig, embedding_weights=None):
        super().__init__()

        hidden = config.hidden_size
        self.transform = nn.Linear(hidden, hidden)
        self.activation = getattr(nn.functional, config.hidden_act)
        self.layer_norm = nn.LayerNorm(hidden)

        # Reuse the supplied weight when given; otherwise own a new one.
        if embedding_weights is None:
            self.decoder_weight = self.create_parameter(
                shape=[config.vocab_size, hidden],
                dtype=self.transform.weight.dtype,
                is_bias=False,
            )
        else:
            self.decoder_weight = embedding_weights

        self.decoder_bias = self.create_parameter(
            shape=[config.vocab_size],
            dtype=self.decoder_weight.dtype,
            is_bias=True,
        )

    def forward(self, hidden_states, masked_positions=None):
        """Return vocabulary scores for `hidden_states`.

        When `masked_positions` is given, the hidden states are flattened to
        rows of size `hidden_states.shape[-1]` and only the gathered rows are
        scored.
        """
        if masked_positions is not None:
            flat_states = paddle.reshape(
                hidden_states, [-1, hidden_states.shape[-1]]
            )
            # Gathering the masked rows first keeps the projection small.
            hidden_states = paddle.tensor.gather(flat_states, masked_positions)

        transformed = self.layer_norm(
            self.activation(self.transform(hidden_states))
        )
        scores = paddle.tensor.matmul(
            transformed, self.decoder_weight, transpose_y=True
        )
        return scores + self.decoder_bias


class BertPretrainingHeads(nn.Layer):
    """Bundles the masked-LM head with the 2-way sequence-relationship head.

    Args:
        config: Supplies `hidden_size` (and the fields the LM head reads).
        embedding_weights: Forwarded unchanged to `BertLMPredictionHead`.
    """

    def __init__(self, config: BertConfig, embedding_weights=None):
        super().__init__()
        self.predictions = BertLMPredictionHead(config, embedding_weights)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output, masked_positions=None):
        """Return `(prediction_scores, seq_relationship_score)`."""
        return (
            self.predictions(sequence_output, masked_positions),
            self.seq_relationship(pooled_output),
        )


class BertEmbeddings(nn.Layer):
    """Sum of word, position and token-type embeddings, followed by layer
    norm and dropout.

    Args:
        config: Supplies the table sizes, `hidden_size` and
            `hidden_dropout_prob`.
    """

    def __init__(self, config: BertConfig):
        super().__init__()

        hidden = config.hidden_size
        self.word_embeddings = nn.Embedding(config.vocab_size, hidden)
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, hidden
        )
        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, hidden
        )
        self.layer_norm = nn.LayerNorm(hidden)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        input_ids: Tensor,
        token_type_ids: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        past_key_values_length: Optional[int] = None,
    ):
        """Embed `input_ids`.

        Missing `position_ids` default to 0-based positions along the last
        axis (shifted by `past_key_values_length` when given); missing
        `token_type_ids` default to zeros.
        """
        if position_ids is None:
            unit = paddle.ones_like(input_ids, dtype="int64")
            # cumsum over ones counts 1..L; subtracting ones makes it 0-based.
            position_ids = paddle.cumsum(unit, axis=-1) - unit
            if past_key_values_length is not None:
                position_ids += past_key_values_length
            position_ids.stop_gradient = True
        if token_type_ids is None:
            token_type_ids = paddle.zeros_like(input_ids, dtype="int64")

        summed = (
            self.word_embeddings(input_ids)
            + self.position_embeddings(position_ids)
            + self.token_type_embeddings(token_type_ids)
        )
        return self.dropout(self.layer_norm(summed))


class BertPooler(nn.Layer):
    """Projects the first token's hidden state through a dense layer and,
    when `config.pool_act` is "tanh", a Tanh activation."""

    def __init__(self, config: BertConfig):
        super().__init__()

        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()
        self.pool_act = config.pool_act

    def forward(self, hidden_states):
        """Return the pooled representation taken from index 0 of axis 1."""
        # "Pooling" here is simply picking the first token's hidden state.
        pooled_output = self.dense(hidden_states[:, 0])
        if self.pool_act != "tanh":
            return pooled_output
        return self.activation(pooled_output)


class BertModel(nn.Layer):
    """BERT backbone: embeddings, a transformer encoder stack and a pooler.

    Args:
        config: Model hyper-parameters.
        to_static: When truthy and the non-fused encoder is used, the encoder
            is wrapped with `paddle.jit.to_static`. Defaults to False.
    """

    def __init__(self, config: BertConfig, to_static=False):
        super().__init__()
        self.config = config
        self.pad_token_id = config.pad_token_id
        self.initializer_range = config.initializer_range
        self.embeddings = BertEmbeddings(config)
        if config.fuse and FusedTransformerEncoderLayer is None:
            warnings.warn(
                "FusedTransformerEncoderLayer is not supported by the running Paddle. "
                "The flag fuse_transformer will be ignored. Try Paddle >= 2.3.0"
            )
        # Fusion is only active when requested AND the fused layer imported.
        self.fuse = config.fuse and FusedTransformerEncoderLayer is not None
        if self.fuse:
            self.encoder = nn.LayerList(
                [
                    FusedTransformerEncoderLayer(
                        config.hidden_size,
                        config.num_attention_heads,
                        config.intermediate_size,
                        dropout_rate=config.hidden_dropout_prob,
                        activation=config.hidden_act,
                        attn_dropout_rate=config.attention_probs_dropout_prob,
                        act_dropout_rate=0.0,
                    )
                    for _ in range(config.num_hidden_layers)
                ]
            )
        else:
            encoder_layer = nn.TransformerEncoderLayer(
                config.hidden_size,
                config.num_attention_heads,
                config.intermediate_size,
                dropout=config.hidden_dropout_prob,
                activation=config.hidden_act,
                attn_dropout=config.attention_probs_dropout_prob,
                act_dropout=0,
            )
            self.encoder = nn.TransformerEncoder(
                encoder_layer, config.num_hidden_layers
            )
            # Only the non-fused encoder is converted to a static graph.
            if to_static:
                self.encoder = paddle.jit.to_static(self.encoder)
        self.pooler = BertPooler(config)
        # self.apply(self.init_weights)

    def get_input_embeddings(self):
        """Return the word embedding layer."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word embedding layer with `value`."""
        self.embeddings.word_embeddings = value

    def forward(
        self,
        input_ids: Tensor,
        token_type_ids: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[Tensor]]] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run the backbone.

        Flags passed as None fall back to the matching `config` attribute.

        Returns:
            Fused path: `(hidden_states, pooled_output)`, plus the list of
            per-layer hidden states when `output_hidden_states` is set.
            Non-fused path: `(sequence_output, pooled_output)` followed by
            any extra encoder outputs.
        """
        return_dict = (
            return_dict
            if return_dict is not None
            else self.config.use_return_dict
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        use_cache = (
            use_cache if use_cache is not None else self.config.use_cache
        )

        past_key_values_length = None
        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
        if attention_mask is None:
            # Default mask: padding tokens get an additive bias of -1e4,
            # everything else 0; two axes are inserted at positions 1 and 2.
            attention_mask = paddle.unsqueeze(
                (input_ids == self.pad_token_id).astype(
                    self.pooler.dense.weight.dtype
                )
                * -1e4,
                axis=[1, 2],
            )
            if past_key_values is not None:
                # Cached positions are never masked: prepend zeros for them.
                batch_size = past_key_values[0][0].shape[0]
                past_mask = paddle.zeros(
                    [batch_size, 1, 1, past_key_values_length],
                    dtype=attention_mask.dtype,
                )
                attention_mask = paddle.concat(
                    [past_mask, attention_mask], axis=-1
                )

        else:
            if attention_mask.ndim == 2:
                # attention_mask [batch_size, sequence_length] -> [batch_size, 1, 1, sequence_length]
                attention_mask = attention_mask.unsqueeze(axis=[1, 2]).astype(
                    paddle.get_default_dtype()
                )
                # Turn a 1/0 keep-mask into an additive 0/-1e4 bias.
                attention_mask = (1.0 - attention_mask) * -1e4

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            past_key_values_length=past_key_values_length,
        )
        if self.fuse:
            assert (
                not output_attentions
            ), "Not support attentions output currently."
            assert (
                past_key_values is None
            ), "Not support past_key_values currently."
            hidden_states = embedding_output
            all_hidden_states = [] if output_hidden_states else None
            for layer in self.encoder:
                hidden_states = layer(hidden_states, attention_mask)
                if output_hidden_states:
                    all_hidden_states.append(hidden_states)
            pooled_output = self.pooler(hidden_states)

            return (
                (hidden_states, pooled_output, all_hidden_states)
                if output_hidden_states
                else (hidden_states, pooled_output)
            )
        else:
            self.encoder._use_cache = use_cache  # To be consistent with HF
            encoder_outputs = self.encoder(
                embedding_output,
                src_mask=attention_mask,
                cache=past_key_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            # The encoder returns a bare tensor when it has nothing extra to
            # report, otherwise a tuple whose first item is the sequence output.
            if isinstance(encoder_outputs, type(embedding_output)):
                sequence_output = encoder_outputs
                pooled_output = self.pooler(sequence_output)
                return (sequence_output, pooled_output)
            else:
                sequence_output = encoder_outputs[0]
                pooled_output = self.pooler(sequence_output)
                return (sequence_output, pooled_output) + encoder_outputs[1:]


class Bert(nn.Layer):
    """BERT pre-training model: `BertModel` backbone plus pre-training heads.

    The masked-LM head is given the backbone's word-embedding weight as its
    output projection.

    Args:
        to_static: Forwarded to `BertModel`. Defaults to False so the model
            can be built without arguments.
    """

    def __init__(self, to_static=False):
        super().__init__()
        config = BertConfig()
        self.bert = BertModel(config, to_static)
        self.cls = BertPretrainingHeads(
            config,
            embedding_weights=self.bert.embeddings.word_embeddings.weight,
        )

        # self.apply(self.init_weights)

    def forward(
        self,
        input_ids: Tensor,
        token_type_ids: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        masked_positions: Optional[Tensor] = None,
        labels: Optional[Tensor] = None,
        next_sentence_label: Optional[Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run the backbone and both pre-training heads.

        Returns:
            `(prediction_scores, seq_relationship_score)` followed by any
            extra backbone outputs. When both `labels` and
            `next_sentence_label` are given, the summed masked-LM and
            next-sentence cross-entropy loss is prepended.
        """
        with paddle.static.amp.fp16_guard():
            outputs = self.bert(
                input_ids,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

            sequence_output, pooled_output = outputs[:2]
            prediction_scores, seq_relationship_score = self.cls(
                sequence_output, pooled_output, masked_positions
            )

            total_loss = None
            # The loss is only computed when both label tensors are supplied.
            if labels is not None and next_sentence_label is not None:
                loss_fct = paddle.nn.CrossEntropyLoss()
                masked_lm_loss = loss_fct(
                    prediction_scores.reshape(
                        (-1, prediction_scores.shape[-1])
                    ),
                    labels.reshape((-1,)),
                )
                next_sentence_loss = loss_fct(
                    seq_relationship_score.reshape((-1, 2)),
                    next_sentence_label.reshape((-1,)),
                )
                total_loss = masked_lm_loss + next_sentence_loss

            output = (prediction_scores, seq_relationship_score) + outputs[2:]
            return (
                ((total_loss,) + output) if total_loss is not None else output
            )


class BertPretrainingCriterion(paddle.nn.Layer):
    """Pre-training loss: scaled masked-LM loss plus next-sentence loss.

    Args:
        vocab_size: Stored on the instance; not read by `forward`.
    """

    def __init__(self, vocab_size=VOCAB_SIZE):
        super().__init__()
        self.vocab_size = vocab_size
        # CrossEntropyLoss is expensive since the inner reshape (copy)
        self.loss_fn = paddle.nn.loss.CrossEntropyLoss(ignore_index=-1)

    def forward(
        self,
        prediction_scores,
        seq_relationship_score,
        masked_lm_labels,
        next_sentence_labels,
        masked_lm_scale,
    ):
        """Return `sum(mlm_loss / masked_lm_scale) + mean(nsp_loss)`.

        Both cross-entropy terms are computed unreduced; masked-LM labels
        equal to -1 are ignored.
        """
        with paddle.static.amp.fp16_guard():
            per_token_mlm = F.cross_entropy(
                prediction_scores,
                masked_lm_labels,
                reduction="none",
                ignore_index=-1,
            )
            scaled_mlm = per_token_mlm / masked_lm_scale
            per_sample_nsp = F.cross_entropy(
                seq_relationship_score, next_sentence_labels, reduction="none"
            )
        return paddle.sum(scaled_mlm) + paddle.mean(per_sample_nsp)


def layer_init_wrapper(func):
    """Wrap an `__init__` so it accepts an extra `enable_recompute` keyword.

    The keyword is removed before `func` is called. Afterwards
    `self.enable_recompute` is set to the requested value in dynamic mode
    and to False otherwise.
    """

    @functools.wraps(func)
    def _impl(self, *args, **kwargs):
        requested = kwargs.pop("enable_recompute", False)
        func(self, *args, **kwargs)
        # Recompute is only honoured when Paddle runs in dynamic mode.
        self.enable_recompute = (
            requested if paddle.in_dynamic_mode() else False
        )

    return _impl


def _convert_attention_mask(attn_mask, dtype):
    if attn_mask is not None and attn_mask.dtype != dtype:
        attn_mask_dtype = convert_dtype(attn_mask.dtype)
        if attn_mask_dtype == 'bool' or 'int' in attn_mask_dtype:
            attn_mask = (paddle.cast(attn_mask, dtype) - 1.0) * 1e9
        else:
            attn_mask = paddle.cast(attn_mask, dtype)
    return attn_mask


def _transformer_encoder_layer_fwd(
    self, src, src_mask=None, cache=None, output_attentions=False
):
    """Replacement `forward` for an encoder layer.

    Runs the self-attention and feed-forward sub-layers, each with a
    residual connection and pre- or post-normalisation depending on
    `self.normalize_before`. Returns the hidden states alone, or a tuple of
    the hidden states followed by the attention's extra outputs in reversed
    order when the attention call returns a tuple.
    """
    self.self_attn.need_weights = output_attentions
    src_mask = _convert_attention_mask(src_mask, src.dtype)
    pre_norm = self.normalize_before

    # --- self-attention sub-layer ---
    attn_input = self.norm1(src) if pre_norm else src
    attn_outputs = self.self_attn(
        attn_input, attn_input, attn_input, src_mask, cache
    )
    if isinstance(attn_outputs, tuple):
        attn_hidden, extras = attn_outputs[0], attn_outputs[1:]
    else:
        attn_hidden, extras = attn_outputs, None

    hidden = src + self.dropout1(attn_hidden)
    if not pre_norm:
        hidden = self.norm1(hidden)

    # --- feed-forward sub-layer ---
    ffn_input = self.norm2(hidden) if pre_norm else hidden
    ffn_hidden = self.linear2(
        self.dropout(self.activation(self.linear1(ffn_input)))
    )
    hidden = hidden + self.dropout2(ffn_hidden)
    if not pre_norm:
        hidden = self.norm2(hidden)

    if extras is None:
        return hidden
    # hidden_states, cache, attentions
    return (hidden,) + extras[::-1]


def _transformer_decoder_layer_fwd(
    self,
    tgt,
    memory,
    tgt_mask=None,
    memory_mask=None,
    cache=None,
    output_attentions=False,
):
    """Replacement `forward` for a decoder layer.

    Runs self-attention, optional cross-attention (skipped when `memory` is
    None) and the feed-forward sub-layer, each with a residual connection
    and pre- or post-normalisation depending on `self.normalize_before`.

    Returns:
        `tgt` alone when `output_attentions` is falsy and `cache` is None.
        Otherwise a tuple starting with `tgt`, followed by
        `(self_attn_weights, cross_attn_weights_or_None)` when
        `output_attentions` is set, and by a
        `(incremental_cache, static_cache_or_None)` pair when `cache` is
        truthy.
    """
    residual = tgt

    # self attention
    self.self_attn.need_weights = output_attentions
    tgt_mask = _convert_attention_mask(tgt_mask, tgt.dtype)

    if self.normalize_before:
        tgt = self.norm1(tgt)

    self_attn_outputs = self.self_attn(
        tgt, tgt, tgt, tgt_mask, cache[0] if cache else None
    )
    # self_attn_outputs = (tgt, attn_weights, incremental_cache) or only tgt
    # NOTE(review): self_attn_weights / incremental_cache are only bound in
    # the else-branch below; the code further down relies on the attention
    # returning a non-tensor whenever weights or caches were requested.
    if isinstance(self_attn_outputs, type(tgt)):
        tgt = self_attn_outputs
    else:
        tgt = self_attn_outputs[0]
        if output_attentions:
            self_attn_weights = self_attn_outputs[1]
        if cache:
            incremental_cache = self_attn_outputs[-1]

    tgt = residual + self.dropout1(tgt)
    if not self.normalize_before:
        tgt = self.norm1(tgt)

    residual = tgt

    # cross attention
    if memory is not None:
        self.cross_attn.need_weights = output_attentions
        memory_mask = _convert_attention_mask(memory_mask, memory.dtype)

        if self.normalize_before:
            tgt = self.norm2(tgt)

        cross_attn_outputs = self.cross_attn(
            tgt, memory, memory, memory_mask, cache[1] if cache else None
        )
        # Same return convention as the self-attention call above.
        if isinstance(cross_attn_outputs, type(tgt)):
            tgt = cross_attn_outputs
        else:
            tgt = cross_attn_outputs[0]
            if output_attentions:
                cross_attn_weights = cross_attn_outputs[1]
            if cache:
                static_cache = cross_attn_outputs[-1]

        tgt = residual + self.dropout2(tgt)
        if not self.normalize_before:
            tgt = self.norm2(tgt)

        residual = tgt

    # feed-forward sub-layer
    if self.normalize_before:
        tgt = self.norm3(tgt)
    tgt = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
    tgt = residual + self.dropout3(tgt)
    if not self.normalize_before:
        tgt = self.norm3(tgt)

    # A bare tensor is returned only when nothing extra was requested.
    if not output_attentions and cache is None:
        return tgt
    else:
        outputs = (tgt,)
        if output_attentions:
            outputs += (
                self_attn_weights,
                cross_attn_weights if memory is not None else None,
            )
        if cache:
            # The two caches are appended as one nested pair.
            outputs += (
                (
                    incremental_cache,
                    static_cache if memory is not None else None,
                ),
            )
        return outputs


def _transformer_encoder_fwd(
    self,
    src,
    src_mask=None,
    cache=None,
    output_attentions=False,
    output_hidden_states=False,
    return_dict=False,
):
    """Replacement `forward` for the encoder stack.

    Runs every layer in `self.layers` (through `recompute` when
    `self.enable_recompute` is set), optionally collecting new caches,
    per-layer hidden states and attentions. `return_dict` is accepted but
    not used.

    Returns:
        The final hidden states alone when nothing else was collected,
        otherwise a tuple of the non-None items among
        `(output, new_caches, all_hidden_states, all_attentions)`.
    """
    src_mask = _convert_attention_mask(src_mask, src.dtype)

    output = src
    # To get cache from None when use_cache is True, which is compatible with
    # HF while HF requires decoder. The implementation here uses cache update
    # in the MultiHeadAttention not so efficiently, and maybe optimize it later.
    if cache is None and getattr(self, "_use_cache", False):
        cache = [tuple(self.layers[0].gen_cache(src))] * len(self.layers)
    # To be compatible with `TransformerEncoder.forward`, `_use_cache`
    # defaults to True when cache is not None.
    collect_caches = cache is not None and getattr(self, "_use_cache", True)
    new_caches = [] if collect_caches else None
    all_attentions = [] if output_attentions else None
    # NOTE: Also includes embedding output which is same as HF.
    all_hidden_states = [output] if output_hidden_states else None

    for i, mod in enumerate(self.layers):
        # Normalise this layer's cache entry to a MultiHeadAttention.Cache.
        if cache is None:
            layer_cache = None
        elif isinstance(cache[i], MultiHeadAttention.Cache):
            layer_cache = cache[i]
        else:
            layer_cache = MultiHeadAttention.Cache(*cache[i])

        if self.enable_recompute:
            # Note: recompute do not support pass as **kwargs yet.
            layer_outputs = recompute(
                mod, output, src_mask, layer_cache, output_attentions
            )
        else:
            layer_outputs = mod(
                output,
                src_mask=src_mask,
                cache=layer_cache,
                output_attentions=output_attentions,
            )

        if isinstance(layer_outputs, tuple):
            output, extras = layer_outputs[0], layer_outputs[1:]
        else:
            output, extras = layer_outputs, None

        if output_hidden_states:
            all_hidden_states.append(output)
        if output_attentions:
            all_attentions.append(extras[-1])
        if new_caches is not None:
            # Hand the cache back in the same form the caller supplied it.
            if isinstance(cache[i], MultiHeadAttention.Cache):
                new_caches.append(extras[0])
            else:
                new_caches.append(tuple(extras[0]))

    if self.norm is not None:
        output = self.norm(output)

        # The last recorded hidden state is replaced by its normalised form.
        if output_hidden_states:
            all_hidden_states[-1] = output

    results = tuple(
        tuple(item) if isinstance(item, list) else item
        for item in (output, new_caches, all_hidden_states, all_attentions)
        if item is not None
    )
    return output if len(results) == 1 else results


def _transformer_decoder_fwd(
    self,
    tgt,
    memory=None,
    tgt_mask=None,
    memory_mask=None,
    cache=None,
    output_attentions=False,
    output_hidden_states=False,
    return_dict=False,
):
    """Replacement `forward` for the decoder stack.

    Runs every layer in `self.layers`, optionally collecting new caches,
    hidden states and self-/cross-attention weights. `recompute` is only
    used when `cache` is None and `self.enable_recompute` is set.
    `return_dict` is accepted but not used.

    Returns:
        `tgt` alone when the last layer returned a bare tensor, otherwise a
        tuple of the non-None items among `(tgt, new_caches,
        all_hidden_states, all_self_attns, all_cross_attns)`.
    """
    tgt_mask = _convert_attention_mask(tgt_mask, tgt.dtype)
    if memory is not None:
        memory_mask = _convert_attention_mask(memory_mask, memory.dtype)

    new_caches = [] if cache else None
    all_hidden_states = [tgt] if output_hidden_states else None
    all_self_attns = [] if output_attentions else None
    all_cross_attns = [] if output_attentions else None

    for i, mod in enumerate(self.layers):
        if cache is None and self.enable_recompute:
            outputs = recompute(
                mod,
                tgt,
                memory,
                tgt_mask,
                memory_mask,
                None,
                output_attentions,
            )
        else:
            # A None or otherwise falsy cache is passed to the layer as None.
            layer_cache = cache[i] if cache else None
            outputs = mod(
                tgt,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                cache=layer_cache,
                output_attentions=output_attentions,
            )

        tgt = outputs if isinstance(outputs, type(tgt)) else outputs[0]
        if cache:
            new_caches.append(outputs[-1])
        if output_attentions:
            all_self_attns.append(outputs[1])
            all_cross_attns.append(outputs[2])
        if output_hidden_states:
            all_hidden_states.append(tgt)

    if self.norm is not None:
        tgt = self.norm(tgt)
        # The last recorded hidden state is replaced by its normalised form.
        if output_hidden_states:
            all_hidden_states[-1] = tgt

    if isinstance(outputs, type(tgt)):
        return tgt

    candidates = (
        tgt,
        new_caches if cache else None,
        all_hidden_states,
        all_self_attns,
        all_cross_attns,
    )
    return tuple(item for item in candidates if item is not None)


# Monkey-patch paddle.nn's Transformer classes at import time so their
# forward passes can return all hidden states and attentions. This affects
# every user of these classes in the process, not just this module.
paddle.nn.TransformerEncoderLayer.forward = _transformer_encoder_layer_fwd
paddle.nn.TransformerDecoderLayer.forward = _transformer_decoder_layer_fwd
paddle.nn.TransformerEncoder.forward = _transformer_encoder_fwd
paddle.nn.TransformerDecoder.forward = _transformer_decoder_fwd

# Keep references to the original constructors, then replace them with
# wrapped versions that accept an extra `enable_recompute` keyword.
_encoder_init = paddle.nn.TransformerEncoder.__init__
_decoder_init = paddle.nn.TransformerDecoder.__init__
paddle.nn.TransformerEncoder.__init__ = layer_init_wrapper(_encoder_init)
paddle.nn.TransformerDecoder.__init__ = layer_init_wrapper(_decoder_init)


class PretrainingDataset(Dataset):
    """Map-style dataset over a NumPy archive of pre-tokenised BERT samples.

    Args:
        input_file: Path (or file object) accepted by `np.load`. The archive
            must contain the keys "input_ids", "input_mask", "segment_ids",
            "masked_lm_positions", "masked_lm_ids" and
            "next_sentence_labels".
        max_pred_length: Number of masked positions used for a sample that
            contains no zero in its `masked_lm_positions`.
    """

    def __init__(self, input_file, max_pred_length):
        self.input_file = input_file
        self.max_pred_length = max_pred_length
        keys = [
            "input_ids",
            "input_mask",
            "segment_ids",
            "masked_lm_positions",
            "masked_lm_ids",
            "next_sentence_labels",
        ]
        # Read every array inside the `with` block so the archive's file
        # handle is closed deterministically instead of at garbage collection.
        with np.load(input_file) as archive:
            self.inputs = [archive[key] for key in keys]

    def __len__(self):
        "Denotes the total number of samples"
        return len(self.inputs[0])

    def __getitem__(self, index):
        """Return one sample as a list of int64/float32 NumPy arrays.

        Returns:
            `[input_ids, segment_ids, input_mask, masked_lm_positions,
            masked_lm_labels, next_sentence_labels]`, where `input_mask` is
            an additive float32 mask of shape `[1, 1, len(input_mask)]`
            (0 where the source mask is 1, -1e9 where it is 0) and the two
            label arrays carry a trailing axis of size 1.
        """
        (
            input_ids,
            input_mask,
            segment_ids,
            masked_lm_positions,
            masked_lm_ids,
            next_sentence_labels,
        ) = (field[index].astype(np.int64) for field in self.inputs)
        next_sentence_labels = np.asarray(next_sentence_labels)

        # TODO: whether to use reversed mask by changing 1s and 0s to be
        # consistent with nv bert
        input_mask = (
            1
            - np.reshape(
                input_mask.astype(np.float32), [1, 1, input_mask.shape[0]]
            )
        ) * -1e9

        # The first zero in masked_lm_positions marks the start of padding,
        # i.e. the number of real masked tokens; without any zero the sample
        # uses all max_pred_length slots.
        padded_mask_indices = (masked_lm_positions == 0).nonzero()[0]
        if len(padded_mask_indices) != 0:
            num_masked = padded_mask_indices[0].item()
        else:
            num_masked = self.max_pred_length

        masked_lm_labels = masked_lm_ids[:num_masked]
        masked_lm_positions = masked_lm_positions[:num_masked]
        # softmax_with_cross_entropy enforce last dim size equal 1
        masked_lm_labels = np.expand_dims(masked_lm_labels, axis=-1)
        next_sentence_labels = np.expand_dims(next_sentence_labels, axis=-1)
        return [
            input_ids,
            segment_ids,
            input_mask,
            masked_lm_positions,
            masked_lm_labels,
            next_sentence_labels,
        ]


def create_pretraining_dataset(
    input_file, max_pred_length, shared_list, batch_size, worker_init
):
    """Build a shuffling `DataLoader` over `PretrainingDataset(input_file)`.

    `shared_list` is accepted but not used. Each batch is a list of
    `[input_ids, segment_ids, input_mask, masked_lm_positions,
    masked_lm_labels, next_sentence_labels, mask_token_num]`.
    """
    dataset = PretrainingDataset(
        input_file=input_file, max_pred_length=max_pred_length
    )
    # files have been sharded, no need to dispatch again
    batch_sampler = paddle.io.BatchSampler(
        dataset, batch_size=batch_size, shuffle=True
    )

    # DataLoader cannot be pickled because of its place.
    # If it can be pickled, use global function instead of lambda and use
    # ProcessPoolExecutor instead of ThreadPoolExecutor to prefetch.
    def _collate_data(data, stack_fn=Stack()):
        """Merge per-sample lists into batch arrays."""
        out = [None] * len(data[0])
        # Fields stacked as-is: input_ids, segment_ids, input_mask and
        # next_sentence_labels.
        for field in (0, 1, 2, 5):
            out[field] = stack_fn([sample[field] for sample in data])
        _, seq_length = out[0].shape

        size = sum(len(sample[3]) for sample in data)
        # Padding for divisibility by 8 for fp16 or int8 usage
        remainder = size % 8
        if remainder != 0:
            size += 8 - remainder

        # masked_lm_positions: one flat 1D array, suitable for gather on a
        # batch flattened to rows; padded slots stay 0.
        positions = np.full(size, 0, dtype=np.int32)
        # masked_lm_labels: padded slots stay -1.
        labels = np.full([size, 1], -1, dtype=np.int64)

        mask_token_num = 0
        for row, sample in enumerate(data):
            for j, pos in enumerate(sample[3]):
                # Offset each position by its row so it indexes the
                # flattened batch.
                positions[mask_token_num] = row * seq_length + pos
                labels[mask_token_num] = sample[4][j]
                mask_token_num += 1

        out[3] = positions
        out[4] = labels
        # mask_token_num
        out.append(np.asarray([mask_token_num], dtype=np.float32))
        return out

    return DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        collate_fn=_collate_data,
        num_workers=0,
        worker_init_fn=worker_init,
        return_list=True,
    )


if __name__ == '__main__':
    # Pass `to_static` explicitly: calling Bert() without it raises
    # TypeError when the constructor declares it as a required parameter.
    bert = Bert(to_static=False)