modeling.py 25.8 KB
Newer Older
Z
Zeyu Chen 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
from typing import Optional, Tuple
from collections import OrderedDict

import paddle
import paddle.nn as nn
import paddle.tensor as tensor
import paddle.nn.functional as F

from .. import PretrainedModel, register_base_model

# Names exported by ``from <module> import *``.
__all__ = [
    'ElectraModel', 'ElectraForTotalPretraining', 'ElectraDiscriminator',
    'ElectraGenerator', 'ElectraClassificationHead',
    'ElectraForSequenceClassification', 'ElectraForTokenClassification',
    'ElectraPretrainingCriterion'
]


def get_activation(activation_string):
    """Return the activation callable registered under ``activation_string``.

    Args:
        activation_string: Key to look up in the module-level ``ACT2FN``
            mapping.

    Returns:
        The callable stored in ``ACT2FN`` for that key.

    Raises:
        KeyError: If ``activation_string`` is not a key of ``ACT2FN``. The
            message lists the available keys.
    """
    # Guard clause: reject unknown names first, then do the plain lookup.
    if activation_string not in ACT2FN:
        raise KeyError("function {} not found in ACT2FN mapping {}".format(
            activation_string, list(ACT2FN.keys())))
    return ACT2FN[activation_string]


def mish(x):
    """Mish activation: ``x * tanh(softplus(x))``."""
    gate = F.tanh(F.softplus(x))
    return x * gate


def linear_act(x):
    """Identity ("linear") activation: hand the input back untouched."""
    unchanged = x
    return unchanged


def swish(x):
    """Swish activation: ``x * sigmoid(x)``."""
    gate = F.sigmoid(x)
    return x * gate


# Registry mapping activation names to callables; ``get_activation`` reads it.
# Insertion order matches the original literal because ``get_activation``
# lists the keys in its error message.
ACT2FN = dict(
    relu=F.relu,
    gelu=F.gelu,
    tanh=F.tanh,
    sigmoid=F.sigmoid,
    mish=mish,
    linear=linear_act,
    swish=swish,
)


class ElectraEmbeddings(nn.Layer):
    """Construct the embeddings from word, position and token_type embeddings.

    The three embedding lookups are summed, then passed through layer
    normalization and dropout.

    Args:
        vocab_size: Number of rows in the word embedding table.
        embedding_size: Width of every embedding vector.
        hidden_dropout_prob: Dropout probability applied to the summed
            embeddings.
        max_position_embeddings: Number of rows in the position embedding
            table.
        type_vocab_size: Number of rows in the token-type embedding table.
    """

    def __init__(self, vocab_size, embedding_size, hidden_dropout_prob,
                 max_position_embeddings, type_vocab_size):
        super(ElectraEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(vocab_size, embedding_size)
        self.position_embeddings = nn.Embedding(max_position_embeddings,
                                                embedding_size)
        self.token_type_embeddings = nn.Embedding(type_vocab_size,
                                                  embedding_size)

        self.layer_norm = nn.LayerNorm(embedding_size, epsilon=1e-12)
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None, position_ids=None):
        """Embed ``input_ids`` and return the normalized, dropped-out sum.

        Args:
            input_ids: Integer token ids. The default position ids are built
                with a cumulative sum along axis 1, so axis 1 is treated as
                the sequence axis (assumes a ``[batch, seq_len]`` layout --
                TODO confirm with callers).
            token_type_ids: Optional token-type ids; all zeros when ``None``.
            position_ids: Optional position ids; ``0, 1, 2, ...`` along axis 1
                when ``None``.

        Returns:
            The sum of the three embeddings after layer norm and dropout.
        """
        if position_ids is None:
            # cumsum of ones along axis 1 gives 1..L; subtracting ones
            # shifts that to 0..L-1.
            ones = paddle.ones_like(input_ids, dtype="int64")
            seq_length = paddle.cumsum(ones, axis=1)
            position_ids = seq_length - ones
            position_ids.stop_gradient = True

        if token_type_ids is None:
            token_type_ids = paddle.zeros_like(input_ids, dtype="int64")

        input_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = input_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class ElectraDiscriminatorPredictions(nn.Layer):
    """Prediction module for the discriminator, made up of two dense layers.

    Args:
        hidden_size: Width of the incoming hidden states and of the first
            dense layer.
        hidden_act: Name of the activation, resolved through
            ``get_activation``.
    """

    def __init__(self, hidden_size, hidden_act):
        super(ElectraDiscriminatorPredictions, self).__init__()

        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dense_prediction = nn.Linear(hidden_size, 1)
        self.act = get_activation(hidden_act)

    def forward(self, discriminator_hidden_states):
        """Map hidden states to one logit per position.

        Args:
            discriminator_hidden_states: Hidden states whose last axis has
                size ``hidden_size`` (presumably ``[batch, seq_len,
                hidden_size]`` -- verify against caller).

        Returns:
            The logits with the trailing size-1 axis removed.
        """
        hidden_states = self.dense(discriminator_hidden_states)
        hidden_states = self.act(hidden_states)
        # Squeeze only the trailing size-1 axis produced by dense_prediction.
        # A bare squeeze() would also drop a batch or sequence axis of size 1
        # and change the rank of the result.
        logits = self.dense_prediction(hidden_states).squeeze(axis=-1)

        return logits


class ElectraGeneratorPredictions(nn.Layer):
    """Prediction module for the generator, made up of two dense layers.

    This layer itself holds one dense projection followed by an activation
    and layer normalization; the projection to the vocabulary is done by the
    caller.

    Args:
        embedding_size: Output width of the dense projection and the size
            normalized by the layer norm.
        hidden_size: Width of the incoming hidden states.
        hidden_act: Name of the activation, resolved through
            ``get_activation``.
    """

    def __init__(self, embedding_size, hidden_size, hidden_act):
        super(ElectraGeneratorPredictions, self).__init__()

        self.layer_norm = nn.LayerNorm(embedding_size)
        self.dense = nn.Linear(hidden_size, embedding_size)
        self.act = get_activation(hidden_act)

    def forward(self, generator_hidden_states):
        """Project hidden states to ``embedding_size``, activate and normalize.

        Args:
            generator_hidden_states: Hidden states whose last axis has size
                ``hidden_size``.

        Returns:
            The transformed states, last axis of size ``embedding_size``.
        """
        hidden_states = self.dense(generator_hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.layer_norm(hidden_states)

        return hidden_states


class ElectraPretrainedModel(PretrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.

    Subclasses get weight initialization (``init_weights``) and optional
    input/output embedding tying (``tie_weights``), plus the class-level
    tables describing the built-in pretrained configurations and where their
    weights are downloaded from.
    """
    base_model_prefix = "electra"
    model_config_file = "model_config.json"

    # pretrained general configuration
    gen_weight = 1.0
    disc_weight = 50.0
    tie_word_embeddings = True
    untied_generator_embeddings = False
    use_softmax_sample = True

    # model init configuration
    pretrained_init_configuration = {
        "electra-small-generator": {
            "attention_probs_dropout_prob": 0.1,
            "embedding_size": 128,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.1,
            "hidden_size": 256,
            "initializer_range": 0.02,
            "intermediate_size": 1024,
            "max_position_embeddings": 512,
            "num_attention_heads": 4,
            "num_hidden_layers": 12,
            "pad_token_id": 0,
            "type_vocab_size": 2,
            "vocab_size": 30522
        },
        "electra-base-generator": {
            "attention_probs_dropout_prob": 0.1,
            "embedding_size": 768,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.1,
            "hidden_size": 256,
            "initializer_range": 0.02,
            "intermediate_size": 1024,
            "max_position_embeddings": 512,
            "num_attention_heads": 4,
            "num_hidden_layers": 12,
            "pad_token_id": 0,
            "type_vocab_size": 2,
            "vocab_size": 30522
        },
        "electra-large-generator": {
            "attention_probs_dropout_prob": 0.1,
            "embedding_size": 1024,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.1,
            "hidden_size": 256,
            "initializer_range": 0.02,
            "intermediate_size": 1024,
            "max_position_embeddings": 512,
            "num_attention_heads": 4,
            "num_hidden_layers": 24,
            "pad_token_id": 0,
            "type_vocab_size": 2,
            "vocab_size": 30522
        },
        "electra-small-discriminator": {
            "attention_probs_dropout_prob": 0.1,
            "embedding_size": 128,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.1,
            "hidden_size": 256,
            "initializer_range": 0.02,
            "intermediate_size": 1024,
            "max_position_embeddings": 512,
            "num_attention_heads": 4,
            "num_hidden_layers": 12,
            "pad_token_id": 0,
            "type_vocab_size": 2,
            "vocab_size": 30522
        },
        "electra-base-discriminator": {
            "attention_probs_dropout_prob": 0.1,
            "embedding_size": 768,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.1,
            "hidden_size": 768,
            "initializer_range": 0.02,
            "intermediate_size": 3072,
            "max_position_embeddings": 512,
            "num_attention_heads": 12,
            "num_hidden_layers": 12,
            "pad_token_id": 0,
            "type_vocab_size": 2,
            "vocab_size": 30522
        },
        "electra-large-discriminator": {
            "attention_probs_dropout_prob": 0.1,
            "embedding_size": 1024,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.1,
            "hidden_size": 1024,
            "initializer_range": 0.02,
            "intermediate_size": 4096,
            "max_position_embeddings": 512,
            "num_attention_heads": 16,
            "num_hidden_layers": 24,
            "pad_token_id": 0,
            "type_vocab_size": 2,
            "vocab_size": 30522
        },
        "chinese-electra-discriminator-small": {
            "attention_probs_dropout_prob": 0.1,
            "embedding_size": 128,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.1,
            "hidden_size": 256,
            "initializer_range": 0.02,
            "intermediate_size": 1024,
            "max_position_embeddings": 512,
            "num_attention_heads": 4,
            "num_hidden_layers": 12,
            "pad_token_id": 0,
            "type_vocab_size": 2,
            "vocab_size": 21128,
        },
        "chinese-electra-discriminator-base": {
            "attention_probs_dropout_prob": 0.1,
            "embedding_size": 768,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.1,
            "hidden_size": 768,
            "initializer_range": 0.02,
            "intermediate_size": 3072,
            "max_position_embeddings": 512,
            "num_attention_heads": 12,
            "num_hidden_layers": 12,
            "pad_token_id": 0,
            "type_vocab_size": 2,
            "vocab_size": 21128,
        },
    }
    resource_files_names = {"model_state": "model_state.pdparams"}
    pretrained_resource_files_map = {
        "model_state": {
            "electra-small-generator":
            "http://paddlenlp.bj.bcebos.com/models/transformers/electra/electra-small-generator.pdparams",
            "electra-base-generator":
            "http://paddlenlp.bj.bcebos.com/models/transformers/electra/electra-base-generator.pdparams",
            "electra-large-generator":
            "http://paddlenlp.bj.bcebos.com/models/transformers/electra/electra-large-generator.pdparams",
            "electra-small-discriminator":
            "http://paddlenlp.bj.bcebos.com/models/transformers/electra/electra-small-discriminator.pdparams",
            "electra-base-discriminator":
            "http://paddlenlp.bj.bcebos.com/models/transformers/electra/electra-base-discriminator.pdparams",
            # Extension corrected from ".pdparamss" so it matches every other
            # entry in this table.
            "electra-large-discriminator":
            "http://paddlenlp.bj.bcebos.com/models/transformers/electra/electra-large-discriminator.pdparams",
            "chinese-electra-discriminator-small":
            "http://paddlenlp.bj.bcebos.com/models/transformers/chinese-electra-discriminator-small/chinese-electra-discriminator-small.pdparams",
            "chinese-electra-discriminator-base":
            "http://paddlenlp.bj.bcebos.com/models/transformers/chinese-electra-discriminator-base/chinese-electra-discriminator-base.pdparams",
        }
    }

    def init_weights(self):
        """
        Initializes and tie weights if needed.
        """
        # Initialize weights
        self.apply(self._init_weights)
        # Tie weights if needed
        self.tie_weights()

    def tie_weights(self):
        """
        Tie the weights between the input embeddings and the output embeddings.

        Does nothing unless the instance defines both
        ``get_output_embeddings`` and ``get_input_embeddings`` and the former
        returns something other than ``None``.
        """
        if hasattr(self, "get_output_embeddings") and hasattr(
                self, "get_input_embeddings"):
            output_embeddings = self.get_output_embeddings()
            if output_embeddings is not None:
                self._tie_or_clone_weights(output_embeddings,
                                           self.get_input_embeddings())

    def _init_weights(self, module):
        """Initialize the weights of a single sub-layer.

        ``nn.Linear`` / ``nn.Embedding`` weights are drawn from a normal
        distribution with mean 0 and std ``initializer_range``;
        ``nn.LayerNorm`` gets bias 0 and weight 1; ``nn.Linear`` biases are
        zeroed.

        Args:
            module: The sub-layer handed in by ``self.apply``.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Wrapper models do not store initializer_range themselves, so
            # fall back to the wrapped base model's config.
            module.weight.set_value(
                paddle.tensor.normal(
                    mean=0.0,
                    std=self.initializer_range
                    if hasattr(self, "initializer_range") else
                    self.electra.config["initializer_range"],
                    shape=module.weight.shape))
        elif isinstance(module, nn.LayerNorm):
            module.bias.set_value(paddle.zeros_like(module.bias))
            module.weight.set_value(paddle.full_like(module.weight, 1.0))
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.set_value(paddle.zeros_like(module.bias))

    def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
        """Tie or clone module weights.

        When the two weights have the same shape, ``output_embeddings.weight``
        is rebound to the input embedding's weight object. When the output
        weight has the shape of the transposed input weight, the transposed
        values are written into it with ``set_value``.

        Args:
            output_embeddings: Layer whose ``weight`` is overwritten.
            input_embeddings: Layer whose ``weight`` is the source.

        Raises:
            ValueError: If the shapes match neither directly nor transposed,
                or if ``output_embeddings`` has a bias whose length differs
                from the last dimension of its weight.
        """
        if output_embeddings.weight.shape == input_embeddings.weight.shape:
            output_embeddings.weight = input_embeddings.weight
        elif output_embeddings.weight.shape == input_embeddings.weight.t(
        ).shape:
            output_embeddings.weight.set_value(input_embeddings.weight.t())
        else:
            raise ValueError(
                "when tie input/output embeddings, the shape of output embeddings: {} "
                "should be equal to shape of input embeddings: {} "
                "or should be equal to the shape of transpose input embeddings: {}".
                format(output_embeddings.weight.shape, input_embeddings.weight.
                       shape, input_embeddings.weight.t().shape))
        if getattr(output_embeddings, "bias", None) is not None:
            if output_embeddings.weight.shape[
                    -1] != output_embeddings.bias.shape[0]:
                raise ValueError(
                    "the weight last shape: {} of output_embeddings is not equal to the bias shape: {}, "
                    "please check output_embeddings configuration".format(
                        output_embeddings.weight.shape[
                            -1], output_embeddings.bias.shape[0]))


@register_base_model
class ElectraModel(ElectraPretrainedModel):
    """Base ELECTRA encoder: embeddings, optional projection, Transformer stack.

    Args:
        vocab_size: Size of the word embedding table.
        embedding_size: Width of the embedding vectors.
        hidden_size: Width of the Transformer encoder.
        num_hidden_layers: Number of stacked encoder layers.
        num_attention_heads: Number of attention heads per encoder layer.
        intermediate_size: Feed-forward width passed to each encoder layer.
        hidden_act: Activation name passed to each encoder layer.
        hidden_dropout_prob: Dropout probability for the embeddings and the
            encoder layers.
        attention_probs_dropout_prob: Dropout probability passed as
            ``attn_dropout`` to each encoder layer.
        max_position_embeddings: Size of the position embedding table.
        type_vocab_size: Size of the token-type embedding table.
        initializer_range: Std of the normal distribution used by
            ``_init_weights``.
        pad_token_id: Token id treated as padding when building the default
            attention mask.
    """

    def __init__(self, vocab_size, embedding_size, hidden_size,
                 num_hidden_layers, num_attention_heads, intermediate_size,
                 hidden_act, hidden_dropout_prob, attention_probs_dropout_prob,
                 max_position_embeddings, type_vocab_size, initializer_range,
                 pad_token_id):
        super(ElectraModel, self).__init__()
        self.pad_token_id = pad_token_id
        self.initializer_range = initializer_range
        self.embeddings = ElectraEmbeddings(
            vocab_size, embedding_size, hidden_dropout_prob,
            max_position_embeddings, type_vocab_size)

        # The projection exists only when the widths differ; forward()
        # checks for the attribute with hasattr.
        if embedding_size != hidden_size:
            self.embeddings_project = nn.Linear(embedding_size, hidden_size)

        encoder_layer = nn.TransformerEncoderLayer(
            hidden_size,
            num_attention_heads,
            intermediate_size,
            dropout=hidden_dropout_prob,
            activation=hidden_act,
            attn_dropout=attention_probs_dropout_prob,
            act_dropout=0)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_hidden_layers)

        self.init_weights()

    def get_input_embeddings(self):
        """Return the word embedding layer."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word embedding layer with ``value``."""
        self.embeddings.word_embeddings = value

    def forward(self,
                input_ids,
                token_type_ids=None,
                position_ids=None,
                attention_mask=None):
        """Encode ``input_ids`` and return the encoder output.

        Args:
            input_ids: Integer token ids (assumes ``[batch, seq_len]`` --
                TODO confirm with callers).
            token_type_ids: Optional token-type ids, forwarded to the
                embedding layer.
            position_ids: Optional position ids, forwarded to the embedding
                layer.
            attention_mask: Optional mask handed to the encoder. When
                ``None``, one is built from ``pad_token_id``.

        Returns:
            Whatever ``self.encoder`` returns for the embedded input.
        """
        if attention_mask is None:
            # Additive mask: -1e9 where the token equals pad_token_id and 0
            # elsewhere, with two singleton axes inserted at positions 1, 2.
            attention_mask = paddle.unsqueeze(
                (input_ids == self.pad_token_id).astype("float32") * -1e9,
                axis=[1, 2])

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids)

        if hasattr(self, "embeddings_project"):
            embedding_output = self.embeddings_project(embedding_output)

        encoder_outputs = self.encoder(embedding_output, attention_mask)

        return encoder_outputs
Z
Zeyu Chen 已提交
413 414


415
class ElectraDiscriminator(ElectraPretrainedModel):
    """ELECTRA base model followed by the discriminator prediction head.

    Args:
        electra: The base model to wrap. Its ``config`` mapping supplies
            ``hidden_size`` and ``hidden_act`` for the head.
    """

    def __init__(self, electra):
        super(ElectraDiscriminator, self).__init__()

        self.electra = electra
        self.discriminator_predictions = ElectraDiscriminatorPredictions(
            self.electra.config["hidden_size"],
            self.electra.config["hidden_act"])
        self.init_weights()

    def forward(self,
                input_ids,
                token_type_ids=None,
                position_ids=None,
                attention_mask=None):
        """Run the base model and the discriminator head.

        All arguments are passed positionally to ``self.electra``.

        Returns:
            The logits produced by ``self.discriminator_predictions``.
        """
        discriminator_sequence_output = self.electra(
            input_ids, token_type_ids, position_ids, attention_mask)

        logits = self.discriminator_predictions(discriminator_sequence_output)

        return logits


class ElectraGenerator(ElectraPretrainedModel):
    """ELECTRA base model followed by the generator (masked LM) head.

    When the class attribute ``tie_word_embeddings`` is true, vocabulary
    scores are computed against the transposed word embedding matrix plus a
    separate bias parameter. Otherwise a dedicated ``nn.Linear`` head is used.

    Args:
        electra: The base model to wrap. Its ``config`` mapping supplies
            ``embedding_size``, ``hidden_size``, ``hidden_act`` and
            ``vocab_size``.
    """

    def __init__(self, electra):
        super(ElectraGenerator, self).__init__()

        self.electra = electra
        self.generator_predictions = ElectraGeneratorPredictions(
            self.electra.config["embedding_size"],
            self.electra.config["hidden_size"],
            self.electra.config["hidden_act"])

        if not self.tie_word_embeddings:
            self.generator_lm_head = nn.Linear(
                self.electra.config["embedding_size"],
                self.electra.config["vocab_size"])
        else:
            # Created through the layer's own create_parameter rather than
            # the legacy paddle.fluid.layers helper. Assigning it to an
            # attribute registers it on this layer in both cases.
            self.generator_lm_head_bias = self.create_parameter(
                shape=[self.electra.config["vocab_size"]],
                dtype='float32',
                is_bias=True)
        self.init_weights()

    def get_input_embeddings(self):
        """Return the wrapped base model's word embedding layer."""
        return self.electra.embeddings.word_embeddings

    def forward(self,
                input_ids=None,
                token_type_ids=None,
                position_ids=None,
                attention_mask=None):
        """Run the base model and the generator head.

        All arguments are passed positionally to ``self.electra``.

        Returns:
            The prediction scores; the last axis has size ``vocab_size``.
        """
        generator_sequence_output = self.electra(input_ids, token_type_ids,
                                                 position_ids, attention_mask)

        prediction_scores = self.generator_predictions(
            generator_sequence_output)
        if not self.tie_word_embeddings:
            prediction_scores = self.generator_lm_head(prediction_scores)
        else:
            # Tied head: scores = hidden @ word_embeddings^T + bias.
            prediction_scores = paddle.add(
                paddle.matmul(
                    prediction_scores,
                    self.get_input_embeddings().weight,
                    transpose_y=True),
                self.generator_lm_head_bias)

        return prediction_scores
Z
Zeyu Chen 已提交
484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499


# class ElectraClassificationHead and ElectraForSequenceClassification for fine-tuning
class ElectraClassificationHead(nn.Layer):
    """Head for sentence-level classification tasks.

    Args:
        hidden_size: Width of the incoming features and of the dense layer.
        hidden_dropout_prob: Dropout probability applied before the dense
            layer and again before the output projection.
        num_labels: Number of output classes.
    """

    def __init__(self, hidden_size, hidden_dropout_prob, num_labels):
        super(ElectraClassificationHead, self).__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.out_proj = nn.Linear(hidden_size, num_labels)

    def forward(self, features, **kwargs):
        """Classify from the features at sequence position 0.

        Args:
            features: Rank-3 tensor indexed as ``features[:, 0, :]``.
            **kwargs: Accepted and ignored.

        Returns:
            Logits whose last axis has size ``num_labels``.
        """
        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        x = self.dropout(x)
        x = self.dense(x)
        x = get_activation("gelu")(x)  # Electra paper used gelu here
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


class ElectraForSequenceClassification(ElectraPretrainedModel):
    """ELECTRA base model with a sentence-level classification head.

    Args:
        electra: The base model to wrap. Its ``config`` mapping supplies
            ``hidden_size`` and ``hidden_dropout_prob`` for the head.
        num_labels: Number of output classes.
    """

    def __init__(self, electra, num_labels):
        super(ElectraForSequenceClassification, self).__init__()
        self.num_labels = num_labels
        self.electra = electra
        self.classifier = ElectraClassificationHead(
            self.electra.config["hidden_size"],
            self.electra.config["hidden_dropout_prob"], self.num_labels)

        self.init_weights()

    def forward(self,
                input_ids=None,
                token_type_ids=None,
                position_ids=None,
                attention_mask=None):
        """Run the base model and the classification head.

        All arguments are passed positionally to ``self.electra``.

        Returns:
            The logits produced by ``self.classifier``.
        """
        sequence_output = self.electra(input_ids, token_type_ids, position_ids,
                                       attention_mask)

        logits = self.classifier(sequence_output)

        return logits
Z
Zeyu Chen 已提交
529 530 531 532 533 534 535 536 537 538 539 540


class ElectraForTokenClassification(ElectraPretrainedModel):
    """ELECTRA base model with a per-token linear classifier.

    Args:
        electra: The base model to wrap. Its ``config`` mapping supplies
            ``hidden_size`` and ``hidden_dropout_prob``.
        num_labels: Number of output classes per token.
    """

    def __init__(self, electra, num_labels):
        super(ElectraForTokenClassification, self).__init__()
        self.num_labels = num_labels
        self.electra = electra
        self.dropout = nn.Dropout(self.electra.config["hidden_dropout_prob"])
        self.classifier = nn.Linear(self.electra.config["hidden_size"],
                                    self.num_labels)
        self.init_weights()

    def forward(self,
                input_ids=None,
                token_type_ids=None,
                position_ids=None,
                attention_mask=None):
        """Run the base model, apply dropout, then classify every position.

        All arguments are passed positionally to ``self.electra``.

        Returns:
            Logits whose last axis has size ``num_labels``.
        """
        sequence_output = self.electra(input_ids, token_type_ids, position_ids,
                                       attention_mask)

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        return logits
Z
Zeyu Chen 已提交
554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576


class ElectraForTotalPretraining(ElectraPretrainedModel):
    """Joint generator + discriminator model for ELECTRA pre-training.

    The generator predicts tokens for the masked input. Tokens sampled from
    its output replace the masked positions, and the discriminator is run on
    the result.

    Args:
        generator: Generator model; called as ``generator(input_ids,
            token_type_ids, position_ids, attention_mask)``.
        discriminator: Discriminator model, called the same way. Its
            ``electra.initializer_range`` is reused for weight init.
    """

    def __init__(self, generator, discriminator):
        super(ElectraForTotalPretraining, self).__init__()

        self.generator = generator
        self.discriminator = discriminator
        self.initializer_range = discriminator.electra.initializer_range
        self.init_weights()

    def get_input_embeddings(self):
        """Return the generator's word embeddings, or ``None`` when untied."""
        if not self.untied_generator_embeddings:
            return self.generator.electra.embeddings.word_embeddings
        else:
            return None

    def get_output_embeddings(self):
        """Return the discriminator's word embeddings, or ``None`` when untied."""
        if not self.untied_generator_embeddings:
            return self.discriminator.electra.embeddings.word_embeddings
        else:
            return None

    def get_discriminator_inputs(self, inputs, raw_inputs, gen_logits,
                                 gen_labels, use_softmax_sample):
        """Sample from the generator to create discriminator input.

        Args:
            inputs: Masked token ids fed to the generator, ``[B, L]``.
            raw_inputs: Token ids compared against the updated inputs to
                derive the labels (presumably the unmasked originals --
                verify against the data collator).
            gen_logits: Generator prediction scores.
            gen_labels: ``[B, L]``; -100 at unmasked positions, the token
                value at masked positions.
            use_softmax_sample: Whether to add Gumbel noise when sampling.

        Returns:
            A tuple ``(updated_inputs, labels, sampled_tokids)``.
            ``updated_inputs`` has the sampled tokens written at masked
            positions. ``labels`` is 1 where a masked position now differs
            from ``raw_inputs`` and 0 elsewhere.
        """
        # get generator token result
        sampled_tokens = (self.sample_from_softmax(gen_logits,
                                                   use_softmax_sample)).detach()
        sampled_tokids = paddle.argmax(sampled_tokens, axis=-1)
        # update token only at mask position
        # gen_labels : [B, L], L contains -100(unmasked) or token value(masked)
        # mask_positions : [B, L], L contains 0(unmasked) or 1(masked)
        umask_positions = paddle.zeros_like(gen_labels)
        mask_positions = paddle.ones_like(gen_labels)
        mask_positions = paddle.where(gen_labels == -100, umask_positions,
                                      mask_positions)
        updated_inputs = self.update_inputs(inputs, sampled_tokids,
                                            mask_positions)
        # use inputs and updated_input to get discriminator labels
        # Cast the comparison result to the inputs' dtype so both operands of
        # the subtraction share one dtype.
        labels = mask_positions * (paddle.ones_like(inputs) - paddle.equal(
            updated_inputs, raw_inputs).astype(inputs.dtype))
        return updated_inputs, labels, sampled_tokids

    def sample_from_softmax(self, logits, use_softmax_sample=True):
        """Pick one token per position and return it one-hot encoded.

        Args:
            logits: Scores whose last axis is the vocabulary axis.
            use_softmax_sample: When true, Gumbel noise built from uniform
                random numbers is added before the argmax. When false the
                plain argmax is taken.

        Returns:
            One-hot tensor with ``logits.shape[-1]`` classes.
        """
        if use_softmax_sample:
            uniform_noise = paddle.rand(logits.shape, dtype="float32")
            # The 1e-9 terms keep both logarithms away from log(0).
            gumbel_noise = -paddle.log(-paddle.log(uniform_noise + 1e-9) + 1e-9)
        else:
            gumbel_noise = paddle.zeros_like(logits)
        # softmax_sample equal to sampled_tokids.unsqueeze(-1)
        softmax_sample = paddle.argmax(
            F.softmax(logits + gumbel_noise), axis=-1)
        # one hot
        return F.one_hot(softmax_sample, logits.shape[-1])

    def update_inputs(self, sequence, updates, positions):
        """Overwrite ``sequence`` with ``updates`` where ``positions`` is 1.

        Args:
            sequence: ``[B, L]`` token ids.
            updates: Replacement token ids.
            positions: 0/1 mask whose second dimension must equal ``L``.

        Returns:
            ``(1 - positions) * sequence + positions * updates``.
        """
        shape = sequence.shape
        assert (len(shape) == 2), "the dimension of inputs should be [B, L]"
        seq_len = shape[1]
        mask_len = positions.shape[1]
        assert (
            mask_len == seq_len
        ), "the dimension of inputs and mask should be same as [B, L]"

        updated_sequence = ((
            (paddle.ones_like(sequence) - positions) * sequence) +
                            (positions * updates))

        return updated_sequence

    def forward(self,
                input_ids=None,
                token_type_ids=None,
                position_ids=None,
                attention_mask=None,
                raw_input_ids=None,
                gen_labels=None):
        """Run the generator, build discriminator inputs, run the discriminator.

        Args:
            input_ids: Masked token ids fed to the generator.
            token_type_ids: Passed to both sub-models.
            position_ids: Passed to both sub-models.
            attention_mask: Passed to both sub-models.
            raw_input_ids: Token ids used to derive the discriminator labels.
            gen_labels: Generator labels; must not be ``None``.

        Returns:
            A tuple ``(gen_logits, disc_logits, disc_labels)``.

        Raises:
            AssertionError: If ``gen_labels`` is ``None``.
        """
        assert (
            gen_labels is not None
        ), "gen_labels should not be None, please check DataCollatorForLanguageModeling"

        gen_logits = self.generator(input_ids, token_type_ids, position_ids,
                                    attention_mask)

        # The sampled token ids (third element) are not needed here.
        disc_inputs, disc_labels, _ = self.get_discriminator_inputs(
            input_ids, raw_input_ids, gen_logits, gen_labels,
            self.use_softmax_sample)

        disc_logits = self.discriminator(disc_inputs, token_type_ids,
                                         position_ids, attention_mask)

        return gen_logits, disc_logits, disc_labels


class ElectraPretrainingCriterion(paddle.nn.Layer):
    """Weighted sum of the generator and discriminator pre-training losses.

    Args:
        vocab_size: Size of the last axis the generator scores are reshaped
            to.
        gen_weight: Multiplier applied to the generator loss.
        disc_weight: Multiplier applied to the discriminator loss.
    """

    def __init__(self, vocab_size, gen_weight, disc_weight):
        super(ElectraPretrainingCriterion, self).__init__()

        self.vocab_size = vocab_size
        self.gen_weight = gen_weight
        self.disc_weight = disc_weight
        self.gen_loss_fct = nn.CrossEntropyLoss(reduction='none')
        self.disc_loss_fct = nn.BCEWithLogitsLoss()

    def forward(self, generator_prediction_scores,
                discriminator_prediction_scores, generator_labels,
                discriminator_labels):
        """Compute ``gen_weight * gen_loss + disc_weight * disc_loss``.

        Args:
            generator_prediction_scores: Generator scores, reshaped to
                ``[-1, vocab_size]``.
            discriminator_prediction_scores: Discriminator logits, reshaped
                to ``[-1, seq_length]``.
            generator_labels: Generator targets; -100 marks positions that
                are not counted in the normalizer.
            discriminator_labels: Discriminator targets; axis 1 supplies
                ``seq_length``.

        Returns:
            The combined loss.
        """
        # generator loss: per-token cross entropy, summed, then divided by
        # the number of positions whose label is not -100.
        flat_scores = paddle.reshape(generator_prediction_scores,
                                     [-1, self.vocab_size])
        flat_labels = paddle.reshape(generator_labels, [-1])
        per_token_loss = self.gen_loss_fct(flat_scores, flat_labels)
        # todo: we can remove 4 lines after when CrossEntropyLoss(reduction='mean') improved
        not_counted = paddle.zeros_like(generator_labels).astype("float32")
        counted = paddle.ones_like(generator_labels).astype("float32")
        counted = paddle.where(generator_labels == -100, not_counted, counted)
        gen_loss = per_token_loss.sum() / counted.sum()

        # discriminator loss
        seq_length = discriminator_labels.shape[1]
        disc_scores = paddle.reshape(discriminator_prediction_scores,
                                     [-1, seq_length])
        disc_targets = discriminator_labels.astype("float32")
        disc_loss = self.disc_loss_fct(disc_scores, disc_targets)

        weighted_gen = self.gen_weight * gen_loss
        weighted_disc = self.disc_weight * disc_loss
        return weighted_gen + weighted_disc