diff --git a/PaddleNLP/examples/electra/run_glue.py b/PaddleNLP/examples/electra/run_glue.py index 93f6e5236bddd7770add0624084833234b3b36dc..f05e38d0c82da93ea87de28d644ea86fb839a257 100644 --- a/PaddleNLP/examples/electra/run_glue.py +++ b/PaddleNLP/examples/electra/run_glue.py @@ -25,25 +25,27 @@ from functools import partial import numpy as np import paddle from paddle.io import DataLoader +from paddle.metric import Metric, Accuracy, Precision, Recall from paddlenlp.datasets import GlueCoLA, GlueSST2, GlueMRPC, GlueSTSB, GlueQQP, GlueMNLI, GlueQNLI, GlueRTE from paddlenlp.data import Stack, Tuple, Pad from paddlenlp.data.sampler import SamplerHelper from paddlenlp.transformers import ElectraForSequenceClassification, ElectraTokenizer +from paddlenlp.metrics import AccuracyAndF1, Mcc, PearsonAndSpearman FORMAT = '%(asctime)s-%(levelname)s: %(message)s' logging.basicConfig(level=logging.INFO, format=FORMAT) logger = logging.getLogger(__name__) TASK_CLASSES = { - "cola": (GlueCoLA, paddle.metric.Accuracy), - "sst-2": (GlueSST2, paddle.metric.Accuracy), - "mrpc": (GlueMRPC, paddle.metric.Accuracy), - "sts-b": (GlueSTSB, paddle.metric.Accuracy), - "qqp": (GlueQQP, paddle.metric.Accuracy), - "mnli": (GlueMNLI, paddle.metric.Accuracy), - "qnli": (GlueQNLI, paddle.metric.Accuracy), - "rte": (GlueRTE, paddle.metric.Accuracy), + "cola": (GlueCoLA, Mcc), + "sst-2": (GlueSST2, Accuracy), + "mrpc": (GlueMRPC, AccuracyAndF1), + "sts-b": (GlueSTSB, PearsonAndSpearman), + "qqp": (GlueQQP, AccuracyAndF1), + "mnli": (GlueMNLI, Accuracy), + "qnli": (GlueQNLI, Accuracy), + "rte": (GlueRTE, Accuracy), } MODEL_CLASSES = { @@ -57,21 +59,17 @@ def set_seed(args): paddle.seed(args.seed + paddle.distributed.get_rank()) -def evaluate(model, loss_fct, metric, data_loader, return_dict): +def evaluate(model, loss_fct, metric, data_loader): model.eval() metric.reset() for batch in data_loader: input_ids, segment_ids, labels = batch - model_output = model(input_ids=input_ids, token_type_ids=segment_ids) - if not return_dict: - logits = model_output[0] - else: - logits = model_output.logits + logits = model(input_ids=input_ids, token_type_ids=segment_ids) loss = loss_fct(logits, labels) correct = metric.compute(logits, labels) metric.update(correct) - accu = metric.accumulate() - print("eval loss: %f, accu: %f, " % (loss.numpy(), accu), end='') + acc = metric.accumulate() + print("eval loss: %f, acc: %s, " % (loss.numpy(), acc), end='') model.train() @@ -218,9 +216,10 @@ def do_train(args): num_workers=0, return_list=True) + num_labels = 1 if train_dataset.get_labels() == None else len( + train_dataset.get_labels()) model = model_class.from_pretrained( - args.model_name_or_path, num_labels=len(train_dataset.get_labels())) - return_dict = model.return_dict + args.model_name_or_path, num_labels=num_labels) if paddle.distributed.get_world_size() > 1: model = paddle.DataParallel(model) @@ -267,14 +266,14 @@ def do_train(args): tic_train = time.time() for epoch in range(args.num_train_epochs): for step, batch in enumerate(train_data_loader): + global_step += 1 input_ids, segment_ids, labels = batch - model_output = model( - input_ids=input_ids, token_type_ids=segment_ids) - if not return_dict: - logits = model_output[0] - else: - logits = model_output.logits + logits = model(input_ids=input_ids, token_type_ids=segment_ids) loss = loss_fct(logits, labels) + loss.backward() + optimizer.step() + lr_scheduler.step() + optimizer.clear_gradients() if global_step % args.logging_steps == 0: print( "global step %d/%d, epoch: %d, batch: %d, rank_id: %s, loss: %f, lr: %.10f, speed: %.4f step/s" @@ -282,21 +281,15 @@ def do_train(args): paddle.distributed.get_rank(), loss, optimizer.get_lr(), args.logging_steps / (time.time() - tic_train))) tic_train = time.time() - loss.backward() - optimizer.step() - lr_scheduler.step() - optimizer.clear_gradients() - if global_step > 1 and global_step % args.save_steps == 0: + if global_step % args.save_steps == 0: tic_eval = time.time() if args.task_name == "mnli": - evaluate(model, loss_fct, metric, dev_data_loader_matched, - return_dict) + evaluate(model, loss_fct, metric, dev_data_loader_matched) evaluate(model, loss_fct, metric, - dev_data_loader_mismatched, return_dict) + dev_data_loader_mismatched) print("eval done total : %s s" % (time.time() - tic_eval)) else: - evaluate(model, loss_fct, metric, dev_data_loader, - return_dict) + evaluate(model, loss_fct, metric, dev_data_loader) print("eval done total : %s s" % (time.time() - tic_eval)) if (not args.n_gpu > 1) or paddle.distributed.get_rank() == 0: output_dir = os.path.join(args.output_dir, @@ -309,7 +302,6 @@ def do_train(args): model, paddle.DataParallel) else model model_to_save.save_pretrained(output_dir) tokenizer.save_pretrained(output_dir) - global_step += 1 def get_md5sum(file_path): @@ -374,7 +366,7 @@ if __name__ == "__main__": "than this will be truncated, sequences shorter will be padded.", ) parser.add_argument( "--learning_rate", - default=3e-4, + default=1e-4, type=float, help="The initial learning rate for Adam.") parser.add_argument( diff --git a/PaddleNLP/examples/electra/run_pretrain.py b/PaddleNLP/examples/electra/run_pretrain.py new file mode 100644 index 0000000000000000000000000000000000000000..acca2e45baf954aaca6cf700f62c4ef86995d583 --- /dev/null +++ b/PaddleNLP/examples/electra/run_pretrain.py @@ -0,0 +1,552 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import collections +import itertools +import logging +import os +import io +import random +import time +from functools import partial +from concurrent.futures import ThreadPoolExecutor + +import numpy as np + +import paddle +import paddle.distributed as dist +from paddle.io import DataLoader, Dataset + +from paddlenlp.transformers import ElectraForTotalPretraining, ElectraModel, ElectraPretrainingCriterion +from paddlenlp.transformers import ElectraDiscriminator, ElectraGenerator +from paddlenlp.transformers import ElectraTokenizer + +FORMAT = '%(asctime)s-%(levelname)s: %(message)s' +logging.basicConfig(level=logging.INFO, format=FORMAT) +logger = logging.getLogger(__name__) + +MODEL_CLASSES = {"electra": (ElectraForTotalPretraining, ElectraTokenizer), } + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_type", + default="electra", + type=str, + help="Model type selected in the list: " + + ", ".join(MODEL_CLASSES.keys()), ) + parser.add_argument( + "--model_name_or_path", + default="electra-small", + type=str, + help="Path to pre-trained model or shortcut name selected in the list: " + + ", ".join( + sum([ + list(classes[-1].pretrained_init_configuration.keys()) + for classes in MODEL_CLASSES.values() + ], [])), ) + parser.add_argument( + "--input_dir", + default=None, + type=str, + required=True, + help="The input directory where the data will be read from.", ) + parser.add_argument( + "--output_dir", + default=None, + type=str, + required=True, + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--max_seq_length", + default=128, + type=int, + help="max length of each sequence") + parser.add_argument( + "--mask_prob", + default=0.15, + type=float, + help="the probability of one word to be mask") + parser.add_argument( + "--train_batch_size", + default=96, + type=int, + help="Batch size per GPU/CPU for training.", ) + parser.add_argument( + "--eval_batch_size", + default=96, + type=int, + help="Batch size per GPU/CPU for training.", ) + parser.add_argument( + "--learning_rate", + default=5e-4, + type=float, + help="The initial learning rate for Adam.") + parser.add_argument( + "--weight_decay", + default=0.01, + type=float, + help="Weight decay if we apply some.") + parser.add_argument( + "--adam_epsilon", + default=1e-6, + type=float, + help="Epsilon for Adam optimizer.") + parser.add_argument( + "--num_train_epochs", + default=4, + type=int, + help="Total number of training epochs to perform.", ) + parser.add_argument( + "--max_steps", + default=-1, + type=int, + help="If > 0: set total number of training steps to perform. Override num_train_epochs.", + ) + parser.add_argument( + "--warmup_steps", + default=10000, + type=int, + help="Linear warmup over warmup_steps.") + + parser.add_argument( + "--logging_steps", + type=int, + default=100, + help="Log every X updates steps.") + parser.add_argument( + "--save_steps", + type=int, + default=1000, + help="Save checkpoint every X updates steps.") + parser.add_argument( + "--init_from_ckpt", + type=bool, + default=False, + help="Whether to load model checkpoint. if True, args.model_name_or_path must be dir store ckpt" + ) + parser.add_argument( + "--seed", type=int, default=42, help="random seed for initialization") + parser.add_argument( + "--eager_run", type=bool, default=True, help="Use dygraph mode.") + parser.add_argument( + "--n_gpu", + type=int, + default=1, + help="number of gpus to use, 0 for cpu.") + args = parser.parse_args() + return args + + +def set_seed(args): + random.seed(args.seed + paddle.distributed.get_rank()) + np.random.seed(args.seed + paddle.distributed.get_rank()) + paddle.seed(args.seed + paddle.distributed.get_rank()) + + +class WorkerInitObj(object): + def __init__(self, seed): + self.seed = seed + + def __call__(self, id): + np.random.seed(seed=self.seed + id) + random.seed(self.seed + id) + + +class BookCorpus(paddle.io.Dataset): + """ + https://web.eecs.umich.edu/~lahiri/gutenberg_dataset.html + Args: + data_path (:obj:`str`) : The dataset file path, which contains train.tsv, dev.tsv and test.tsv. + tokenizer (:obj:`class PretrainedTokenizer`) : The tokenizer to split word and convert word to id. + max_seq_length (:obj:`int`) : max length for each sequence. + mode (:obj:`str`, `optional`, defaults to `train`): + It identifies the dataset mode (train, test or dev). + """ + + def __init__( + self, + data_path, + tokenizer, + max_seq_length, + mode='train', ): + if mode == 'train': + data_file = 'train.data' + elif mode == 'test': + data_file = 'test.data' + else: + data_file = 'dev.data' + + self.data_file = os.path.join(data_path, data_file) + self.tokenizer = tokenizer + self.max_seq_length = max_seq_length + self.raw_examples = self._read_file(self.data_file) + + def _read_file(self, input_file): + """ + Reads a text file. + + Args: + input_file (:obj:`str`) : The file to be read. + + Returns: + examples (:obj:`list`): All the input data. + """ + if not os.path.exists(input_file): + raise RuntimeError("The file {} is not found.".format(input_file)) + else: + with io.open(input_file, "r", encoding="UTF-8") as f: + examples = [] + for line in f.read().splitlines(): + if (len(line) > 0 and not line.isspace()): + tokens = self.tokenizer(line) + ids = self.tokenizer.convert_tokens_to_ids(tokens) + example = self.truncation_ids(ids, self.max_seq_length) + examples.append(example) + return examples + + def truncation_ids(self, ids, max_seq_length): + if len(ids) <= (max_seq_length - 2): + return ids + else: + return ids[:(max_seq_length - 2)] + + def __getitem__(self, idx): + return self.raw_examples[idx] + + def __len__(self): + return len(self.raw_examples) + + +class DataCollatorForElectra(object): + """ + pads, gets batch of tensors and preprocesses batches for masked language modeling + when dataloader num_worker > 0, this collator may trigger some bugs, for safe, be sure dataloader num_worker=0 + """ + + def __init__(self, + tokenizer, + max_seq_length, + mlm=True, + mlm_probability=0.15): + self.tokenizer = tokenizer + self.max_seq_length = max_seq_length + self.mlm = True + self.mlm_probability = mlm_probability + + def __call__(self, examples): + if self.mlm: + inputs, raw_inputs, labels = self.mask_tokens(examples) + return inputs, raw_inputs, labels + else: + raw_inputs, _ = self.add_special_tokens_and_set_maskprob( + examples, True, self.max_seq_length) + raw_inputs = self.tensorize_batch(raw_inputs, "int64") + inputs = raw_inputs.clone().detach() + labels = raw_inputs.clone().detach() + if self.tokenizer.pad_token is not None: + pad_token_id = self.tokenizer.convert_tokens_to_ids( + self.tokenizer.pad_token) + labels[labels == pad_token_id] = -100 + return batch, raw_inputs, labels + + def tensorize_batch(self, examples, dtype): + if isinstance(examples[0], (list, tuple)): + examples = [paddle.to_tensor(e, dtype=dtype) for e in examples] + length_of_first = examples[0].shape[0] + are_tensors_same_length = all(x.shape[0] == length_of_first + for x in examples) + if are_tensors_same_length: + return paddle.stack(examples, axis=0) + else: + raise ValueError( + "the tensor in examples not have same shape, please check input examples" + ) + + def add_special_tokens_and_set_maskprob(self, inputs, truncation, + max_seq_length): + sep_token_id = self.tokenizer.convert_tokens_to_ids( + self.tokenizer.sep_token) + pad_token_id = self.tokenizer.convert_tokens_to_ids( + self.tokenizer.pad_token) + cls_token_id = self.tokenizer.convert_tokens_to_ids( + self.tokenizer.cls_token) + full_inputs = [] + full_maskprob = [] + max_length = 0 + for ids in inputs: + if len(ids) > max_length: + max_length = len(ids) + max_length = min(max_length + 2, max_seq_length) + + for ids in inputs: + if len(ids) <= (max_length - 2): + padding_num = max_length - len(ids) - 2 + full_inputs.append([cls_token_id] + ids + [sep_token_id] + ( + [pad_token_id] * padding_num)) + full_maskprob.append([0] + ([self.mlm_probability] * len(ids)) + + [0] + ([0] * padding_num)) + else: + if truncation: + full_inputs.append([cls_token_id] + ids[:(max_length - 2)] + + [sep_token_id]) + full_maskprob.append([0] + ([self.mlm_probability] * ( + max_length - 2)) + [0]) + else: + full_inputs.append([cls_token_id] + ids + [sep_token_id]) + full_maskprob.append([0] + ([self.mlm_probability] * len( + ids)) + [0]) + return full_inputs, full_maskprob + + def mask_tokens(self, examples): + if self.tokenizer.mask_token is None: + raise ValueError( + "the tokenizer does not have mask_token, please check!") + mask_token_id = self.tokenizer.convert_tokens_to_ids( + self.tokenizer.mask_token) + + raw_inputs, probability_matrix = self.add_special_tokens_and_set_maskprob( + examples, True, self.max_seq_length) + raw_inputs = self.tensorize_batch(raw_inputs, "int64") + probability_matrix = self.tensorize_batch(probability_matrix, "float32") + inputs = raw_inputs.clone() + labels = raw_inputs.clone() + + total_indices = paddle.bernoulli(probability_matrix).astype("bool") + unuse_labels = paddle.full(labels.shape, -100).astype("int64") + labels = paddle.where(total_indices, labels, unuse_labels) + + # 80% MASK + indices_mask = paddle.bernoulli(paddle.full(labels.shape, 0.8)).astype( + "bool").logical_and(total_indices) + masked_inputs = paddle.full(inputs.shape, mask_token_id).astype("int64") + inputs = paddle.where(indices_mask, masked_inputs, inputs) + + # 10% Random + indices_random = paddle.bernoulli(paddle.full( + labels.shape, 0.5)).astype("bool").logical_and( + total_indices).logical_and(indices_mask.logical_not()) + random_words = paddle.randint( + low=0, + high=self.tokenizer.vocab_size, + shape=labels.shape, + dtype="int64") + inputs = paddle.where(indices_random, random_words, inputs) + + # 10% Original + return inputs, raw_inputs, labels + + +def create_dataloader(dataset, + mode='train', + batch_size=1, + use_gpu=True, + data_collator=None): + """ + Creats dataloader. + + Args: + dataset(obj:`paddle.io.Dataset`): + Dataset instance. + mode(obj:`str`, optional, defaults to obj:`train`): + If mode is 'train', it will shuffle the dataset randomly. + batch_size(obj:`int`, optional, defaults to 1): + The sample number of a mini-batch. + use_gpu(obj:`bool`, optional, defaults to obj:`True`): + Whether to use gpu to run. + + Returns: + dataloader(obj:`paddle.io.DataLoader`): The dataloader which generates batches. + """ + + #print("%s.data has batch_size : %s" % (mode, batch_size)) + if mode == 'train' and use_gpu: + sampler = paddle.io.DistributedBatchSampler( + dataset=dataset, batch_size=batch_size, shuffle=True) + dataloader = paddle.io.DataLoader( + dataset, + batch_sampler=sampler, + return_list=True, + collate_fn=data_collator, + num_workers=0) + else: + shuffle = True if mode == 'train' else False + sampler = paddle.io.BatchSampler( + dataset=dataset, batch_size=batch_size, shuffle=shuffle) + dataloader = paddle.io.DataLoader( + dataset, + batch_sampler=sampler, + return_list=True, + collate_fn=data_collator, + num_workers=0) + + return dataloader + + +def do_train(args): + paddle.enable_static() if not args.eager_run else None + paddle.set_device("gpu" if args.n_gpu else "cpu") + if paddle.distributed.get_world_size() > 1: + paddle.distributed.init_parallel_env() + + set_seed(args) + worker_init = WorkerInitObj(args.seed + paddle.distributed.get_rank()) + + args.model_type = args.model_type.lower() + model_class, tokenizer_class = MODEL_CLASSES[args.model_type] + + # Loads or initializes a model. + pretrained_models = list(tokenizer_class.pretrained_init_configuration.keys( + )) + if args.model_name_or_path in pretrained_models: + tokenizer = tokenizer_class.from_pretrained("./") + #tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) + generator = ElectraGenerator( + ElectraModel(**ElectraModel.pretrained_init_configuration[ + args.model_name_or_path + "-generator"])) + discriminator = ElectraDiscriminator( + ElectraModel(**ElectraModel.pretrained_init_configuration[ + args.model_name_or_path + "-discriminator"])) + model = model_class(generator, discriminator) + else: + if os.path.isdir(args.model_name_or_path) and args.init_from_ckpt: + # load checkpoint + tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) + for file_id, file_name in model_class.resource_files_names.items(): + full_file_name = os.path.join(args.model_name_or_path, + file_name) + # to be write : load model ckpt file + else: + raise ValueError("initialize a model need identifier or the " + "path to a directory instead. The supported model " + "identifiers are as follows: {}".format( + model_class.pretrained_init_configuration.keys( + ))) + + criterion = ElectraPretrainingCriterion( + getattr(model.generator, + ElectraGenerator.base_model_prefix).config["vocab_size"], + model.gen_weight, model.disc_weight) + if paddle.distributed.get_world_size() > 1: + model = paddle.DataParallel(model) + + # Loads dataset. + tic_load_data = time.time() + print("start load data : %s" % + (time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))) + train_dataset = BookCorpus( + data_path=args.input_dir, + tokenizer=tokenizer, + max_seq_length=args.max_seq_length, + mode='train') + print("load data done, total : %s s" % (time.time() - tic_load_data)) + + # Reads data and generates mini-batches. + data_collator = DataCollatorForElectra( + tokenizer=tokenizer, + max_seq_length=args.max_seq_length, + mlm=True, + mlm_probability=args.mask_prob) + + train_data_loader = create_dataloader( + train_dataset, + batch_size=args.train_batch_size, + mode='train', + use_gpu=True if args.n_gpu else False, + data_collator=data_collator) + + num_training_steps = args.max_steps if args.max_steps > 0 else ( + len(train_data_loader) * args.num_train_epochs) + lr_scheduler = paddle.optimizer.lr.LambdaDecay( + args.learning_rate, + lambda current_step, num_warmup_steps=args.warmup_steps, + num_training_steps=num_training_steps: float( + current_step) / float(max(1, num_warmup_steps)) + if current_step < num_warmup_steps else max( + 0.0, + float(num_training_steps - current_step) / float( + max(1, num_training_steps - num_warmup_steps)))) + + optimizer = paddle.optimizer.AdamW( + learning_rate=lr_scheduler, + epsilon=args.adam_epsilon, + parameters=model.parameters(), + weight_decay=args.weight_decay, + apply_decay_param_fun=lambda x: x in [ + p.name for n, p in model.named_parameters() + if not any(nd in n for nd in ["bias", "norm"]) + ]) + + print("start train : %s" % + (time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))) + global_step = 0 + tic_train = time.time() + for epoch in range(args.num_train_epochs): + for step, batch in enumerate(train_data_loader): + global_step += 1 + input_ids, raw_input_ids, gen_labels = batch + gen_logits, disc_logits, disc_labels = model( + input_ids=input_ids, + raw_input_ids=raw_input_ids, + gen_labels=gen_labels) + loss = criterion(gen_logits, disc_logits, gen_labels, disc_labels) + loss.backward() + optimizer.step() + lr_scheduler.step() + optimizer.clear_gradients() + #print("backward done, total %s s" % (time.time() - tic_train)) + #tic_train = time.time() + if global_step % args.logging_steps == 0: + print( + "global step %d/%d, epoch: %d, batch: %d, rank_id: %s, loss: %f, lr: %.10f, speed: %.4f step/s" + % (global_step, num_training_steps, epoch, step, + paddle.distributed.get_rank(), loss, optimizer.get_lr(), + args.logging_steps / (time.time() - tic_train))) + tic_train = time.time() + if global_step % args.save_steps == 0: + if (not args.n_gpu > 1) or paddle.distributed.get_rank() == 0: + output_dir = os.path.join(args.output_dir, + "model_%d.pdparams" % global_step) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + # need better way to get inner model of DataParallel + model_to_save = model._layers if isinstance( + model, paddle.DataParallel) else model + #model_to_save.save_pretrained(output_dir) + paddle.save(model.state_dict(), + os.path.join(output_dir, + "model_state.pdparams")) + tokenizer.save_pretrained(output_dir) + paddle.save(optimizer.state_dict(), + os.path.join(output_dir, "model_state.pdopt")) + + +def print_arguments(args): + """print arguments""" + print('----------- Configuration Arguments -----------') + for arg, value in sorted(vars(args).items()): + print('%s: %s' % (arg, value)) + print('------------------------------------------------') + + +if __name__ == "__main__": + args = parse_args() + print_arguments(args) + if args.n_gpu > 1: + paddle.distributed.spawn(do_train, args=(args, ), nprocs=args.n_gpu) + else: + do_train(args) diff --git a/PaddleNLP/paddlenlp/metrics/__init__.py b/PaddleNLP/paddlenlp/metrics/__init__.py index 6a692461b287806c0fb5eeaf87884289dbfef4d7..86998090368b4c93a9ba510bce27dfe0cc434f9d 100644 --- a/PaddleNLP/paddlenlp/metrics/__init__.py +++ b/PaddleNLP/paddlenlp/metrics/__init__.py @@ -14,4 +14,5 @@ from .perplexity import Perplexity from .chunk import ChunkEvaluator -from .bleu import BLEU \ No newline at end of file +from .bleu import BLEU +from .glue import AccuracyAndF1, Mcc, PearsonAndSpearman diff --git a/PaddleNLP/paddlenlp/metrics/glue.py b/PaddleNLP/paddlenlp/metrics/glue.py new file mode 100644 index 0000000000000000000000000000000000000000..aecf2cbc587d9c015467ea9274954f161357e2dd --- /dev/null +++ b/PaddleNLP/paddlenlp/metrics/glue.py @@ -0,0 +1,226 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import math +from functools import partial + +import numpy as np +import paddle +from paddle.metric import Metric, Accuracy, Precision, Recall + +__all__ = ['AccuracyAndF1', 'Mcc', 'PearsonAndSpearman'] + + +class AccuracyAndF1(Metric): + """ + Encapsulates Accuracy, Precision, Recall and F1 metric logic. + """ + + def __init__(self, + topk=(1, ), + pos_label=1, + name='acc_and_f1', + *args, + **kwargs): + super(AccuracyAndF1, self).__init__(*args, **kwargs) + self.topk = topk + self.pos_label = pos_label + self._name = name + self.acc = Accuracy(self.topk, *args, **kwargs) + self.precision = Precision(*args, **kwargs) + self.recall = Recall(*args, **kwargs) + self.reset() + + def compute(self, pred, label, *args): + self.label = label + self.preds_pos = paddle.nn.functional.softmax(pred)[:, self.pos_label] + return self.acc.compute(pred, label) + + def update(self, correct, *args): + self.acc.update(correct) + self.precision.update(self.preds_pos, self.label) + self.recall.update(self.preds_pos, self.label) + + def accumulate(self): + acc = self.acc.accumulate() + precision = self.precision.accumulate() + recall = self.recall.accumulate() + if precision == 0.0 or recall == 0.0: + f1 = 0.0 + else: + # 1/f1 = 1/2 * (1/precision + 1/recall) + f1 = (2 * precision * recall) / (precision + recall) + return ( + acc, + precision, + recall, + f1, + (acc + f1) / 2, ) + + def reset(self): + self.acc.reset() + self.precision.reset() + self.recall.reset() + self.label = None + self.preds_pos = None + + def name(self): + """ + Return name of metric instance. + """ + return self._name + + +class Mcc(Metric): + """ + Matthews correlation coefficient + https://en.wikipedia.org/wiki/Matthews_correlation_coefficient. + """ + + def __init__(self, name='mcc', *args, **kwargs): + super(Mcc, self).__init__(*args, **kwargs) + self._name = name + self.tp = 0 # true positive + self.fp = 0 # false positive + self.tn = 0 # true negative + self.fn = 0 # false negative + + def compute(self, pred, label, *args): + preds = paddle.argsort(pred, descending=True)[:, :1] + return (preds, label) + + def update(self, preds_and_labels): + preds = preds_and_labels[0] + preds = preds.numpy() + labels = preds_and_labels[1] + labels = labels.numpy().reshape(-1, 1) + sample_num = labels.shape[0] + for i in range(sample_num): + pred = preds[i] + label = labels[i] + if pred == 1: + if pred == label: + self.tp += 1 + else: + self.fp += 1 + else: + if pred == label: + self.tn += 1 + else: + self.fn += 1 + + def accumulate(self): + if self.tp == 0 or self.fp == 0 or self.tn == 0 or self.fn == 0: + mcc = 0.0 + else: + # mcc = (tp*tn-fp*fn)/ sqrt(tp+fp)(tp+fn)(tn+fp)(tn+fn)) + mcc = (self.tp * self.tn - self.fp * self.fn) / math.sqrt( + (self.tp + self.fp) * (self.tp + self.fn) * + (self.tn + self.fp) * (self.tn + self.fn)) + return (mcc, ) + + def reset(self): + self.tp = 0 # true positive + self.fp = 0 # false positive + self.tn = 0 # true negative + self.fn = 0 # false negative + + def name(self): + """ + Return name of metric instance. + """ + return self._name + + +class PearsonAndSpearman(Metric): + """ + Pearson correlation coefficient + https://en.wikipedia.org/wiki/Pearson_correlation_coefficient + Spearman's rank correlation coefficient + https://en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient. + """ + + def __init__(self, name='mcc', *args, **kwargs): + super(PearsonAndSpearman, self).__init__(*args, **kwargs) + self._name = name + self.preds = [] + self.labels = [] + + def update(self, preds_and_labels): + preds = preds_and_labels[0] + preds = np.squeeze(preds.numpy().reshape(-1, 1)).tolist() + labels = preds_and_labels[1] + labels = np.squeeze(labels.numpy().reshape(-1, 1)).tolist() + self.preds.append(preds) + self.labels.append(labels) + + def accumulate(self): + preds = [item for sublist in self.preds for item in sublist] + labels = [item for sublist in self.labels for item in sublist] + #import pdb; pdb.set_trace() + pearson = self.pearson(preds, labels) + spearman = self.spearman(preds, labels) + return ( + pearson, + spearman, + (pearson + spearman) / 2, ) + + def pearson(self, preds, labels): + n = len(preds) + #simple sums + sum1 = sum(float(preds[i]) for i in range(n)) + sum2 = sum(float(labels[i]) for i in range(n)) + #sum up the squares + sum1_pow = sum([pow(v, 2.0) for v in preds]) + sum2_pow = sum([pow(v, 2.0) for v in labels]) + #sum up the products + p_sum = sum([preds[i] * labels[i] for i in range(n)]) + + numerator = p_sum - (sum1 * sum2 / n) + denominator = math.sqrt( + (sum1_pow - pow(sum1, 2) / n) * (sum2_pow - pow(sum2, 2) / n)) + if denominator == 0: + return 0.0 + return numerator / denominator + + def spearman(self, preds, labels): + preds_rank = self.get_rank(preds) + labels_rank = self.get_rank(labels) + + total = 0 + n = len(preds) + for i in range(n): + total += pow((preds_rank[i] - labels_rank[i]), 2) + spearman = 1 - float(6 * total) / (n * (pow(n, 2) - 1)) + return spearman + + def get_rank(self, raw_list): + x = np.array(raw_list) + r_x = np.empty(x.shape, dtype=int) + y = np.argsort(-x) + for i, k in enumerate(y): + r_x[k] = i + 1 + return r_x + + def reset(self): + self.preds = [] + self.labels = [] + + def name(self): + """ + Return name of metric instance. + """ + return self._name diff --git a/PaddleNLP/paddlenlp/transformers/electra/modeling.py b/PaddleNLP/paddlenlp/transformers/electra/modeling.py index 7e4b1dab629498aeac83320991c24f392f7917ec..2ac0a9fea369a7416bf2cf4eb594b4bd1eb8adf7 100644 --- a/PaddleNLP/paddlenlp/transformers/electra/modeling.py +++ b/PaddleNLP/paddlenlp/transformers/electra/modeling.py @@ -13,7 +13,6 @@ # limitations under the License. import os import time -from dataclasses import dataclass from typing import Optional, Tuple from collections import OrderedDict @@ -25,14 +24,10 @@ import paddle.nn.functional as F from .. import PretrainedModel, register_base_model __all__ = [ - 'ElectraModel', - 'ElectraForTotalPretraining', - 'ElectraForPretraining', - 'ElectraForMaskedLM', - 'ElectraClassificationHead', - 'ElectraForSequenceClassification', - 'ElectraForTokenClassification', - 'ElectraModelOutput', + 'ElectraModel', 'ElectraForTotalPretraining', 'ElectraDiscriminator', + 'ElectraGenerator', 'ElectraClassificationHead', + 'ElectraForSequenceClassification', 'ElectraForTokenClassification', + 'ElectraPretrainingCriterion' ] @@ -70,13 +65,8 @@ ACT2FN = { class ElectraEmbeddings(nn.Layer): """Construct the embeddings from word, position and token_type embeddings.""" - def __init__(self, - vocab_size, - embedding_size, - hidden_dropout_prob, - max_position_embeddings, - type_vocab_size, - layer_norm_eps=1e-12): + def __init__(self, vocab_size, embedding_size, hidden_dropout_prob, + max_position_embeddings, type_vocab_size): super(ElectraEmbeddings, self).__init__() self.word_embeddings = nn.Embedding(vocab_size, embedding_size) self.position_embeddings = nn.Embedding(max_position_embeddings, @@ -84,33 +74,24 @@ class ElectraEmbeddings(nn.Layer): self.token_type_embeddings = nn.Embedding(type_vocab_size, embedding_size) - self.layer_norm = nn.LayerNorm(embedding_size, epsilon=layer_norm_eps) + self.layer_norm = nn.LayerNorm(embedding_size, epsilon=1e-12) self.dropout = nn.Dropout(hidden_dropout_prob) - def forward(self, - input_ids=None, - token_type_ids=None, - position_ids=None, - inputs_embeds=None): - if input_ids is not None: - input_shape = input_ids.shape - else: - input_shape = inputs_embeds.shape[:-1] - - seq_length = input_shape[1] - + def forward(self, input_ids, token_type_ids=None, position_ids=None): if position_ids is None: - position_ids = paddle.arange(0, seq_length, dtype="int64") + ones = paddle.ones_like(input_ids, dtype="int64") + seq_length = paddle.cumsum(ones, axis=1) + position_ids = seq_length - ones + position_ids.stop_gradient = True if token_type_ids is None: token_type_ids = paddle.zeros_like(input_ids, dtype="int64") - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) + input_embeddings = self.word_embeddings(input_ids) position_embeddings = self.position_embeddings(position_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + position_embeddings + token_type_embeddings + embeddings = input_embeddings + position_embeddings + token_type_embeddings embeddings = self.layer_norm(embeddings) embeddings = self.dropout(embeddings) return embeddings @@ -140,14 +121,14 @@ class ElectraGeneratorPredictions(nn.Layer): def __init__(self, embedding_size, hidden_size, hidden_act): super(ElectraGeneratorPredictions, self).__init__() - self.LayerNorm = nn.LayerNorm(embedding_size) + self.layer_norm = nn.LayerNorm(embedding_size) self.dense = nn.Linear(hidden_size, embedding_size) self.act = get_activation(hidden_act) def forward(self, generator_hidden_states): hidden_states = self.dense(generator_hidden_states) hidden_states = self.act(hidden_states) - hidden_states = self.LayerNorm(hidden_states) + hidden_states = self.layer_norm(hidden_states) return hidden_states @@ -165,16 +146,11 @@ class ElectraPretrainedModel(PretrainedModel): disc_weight = 50.0 tie_word_embeddings = True untied_generator_embeddings = False - untied_generator = True - output_hidden_states = False - output_attentions = False - return_dict = False use_softmax_sample = True # model init configuration pretrained_init_configuration = { "electra-small-generator": { - "architectures": ["ElectraForMaskedLM"], "attention_probs_dropout_prob": 0.1, "embedding_size": 128, "hidden_act": "gelu", @@ -182,9 +158,7 @@ class ElectraPretrainedModel(PretrainedModel): "hidden_size": 256, "initializer_range": 0.02, "intermediate_size": 1024, - "layer_norm_eps": 1e-12, "max_position_embeddings": 512, - "model_type": "electra", "num_attention_heads": 4, "num_hidden_layers": 12, "pad_token_id": 0, @@ -192,7 +166,6 @@ class ElectraPretrainedModel(PretrainedModel): "vocab_size": 30522 }, "electra-base-generator": { - "architectures": ["ElectraForMaskedLM"], "attention_probs_dropout_prob": 0.1, "embedding_size": 768, "hidden_act": "gelu", @@ -200,9 +173,7 @@ class ElectraPretrainedModel(PretrainedModel): "hidden_size": 256, "initializer_range": 0.02, "intermediate_size": 1024, - "layer_norm_eps": 1e-12, "max_position_embeddings": 512, - "model_type": "electra", "num_attention_heads": 4, "num_hidden_layers": 12, "pad_token_id": 0, @@ -210,7 +181,6 @@ class ElectraPretrainedModel(PretrainedModel): "vocab_size": 30522 }, "electra-large-generator": { - "architectures": ["ElectraForMaskedLM"], "attention_probs_dropout_prob": 0.1, "embedding_size": 1024, "hidden_act": "gelu", @@ -218,9 +188,7 @@ class ElectraPretrainedModel(PretrainedModel): "hidden_size": 256, "initializer_range": 0.02, "intermediate_size": 1024, - "layer_norm_eps": 1e-12, "max_position_embeddings": 512, - "model_type": "electra", "num_attention_heads": 4, "num_hidden_layers": 24, "pad_token_id": 0, @@ -228,7 +196,6 @@ class ElectraPretrainedModel(PretrainedModel): "vocab_size": 30522 }, "electra-small-discriminator": { - "architectures": ["ElectraForPretraining"], "attention_probs_dropout_prob": 0.1, "embedding_size": 128, "hidden_act": "gelu", @@ -236,9 +203,7 @@ class ElectraPretrainedModel(PretrainedModel): "hidden_size": 256, "initializer_range": 0.02, "intermediate_size": 1024, - "layer_norm_eps": 1e-12, "max_position_embeddings": 512, - "model_type": "electra", "num_attention_heads": 4, "num_hidden_layers": 12, "pad_token_id": 0, @@ -246,7 +211,6 @@ class ElectraPretrainedModel(PretrainedModel): "vocab_size": 30522 }, "electra-base-discriminator": { - "architectures": ["ElectraForPretraining"], "attention_probs_dropout_prob": 0.1, "embedding_size": 768, "hidden_act": "gelu", @@ -254,9 +218,7 @@ class ElectraPretrainedModel(PretrainedModel): "hidden_size": 768, "initializer_range": 0.02, "intermediate_size": 3072, - "layer_norm_eps": 1e-12, "max_position_embeddings": 512, - "model_type": "electra", "num_attention_heads": 12, "num_hidden_layers": 12, "pad_token_id": 0, @@ -264,7 +226,6 @@ class ElectraPretrainedModel(PretrainedModel): "vocab_size": 30522 }, "electra-large-discriminator": { - "architectures": ["ElectraForPretraining"], "attention_probs_dropout_prob": 0.1, "embedding_size": 1024, "hidden_act": "gelu", @@ -272,9 +233,7 @@ class ElectraPretrainedModel(PretrainedModel): "hidden_size": 1024, "initializer_range": 0.02, "intermediate_size": 4096, - "layer_norm_eps": 1e-12, "max_position_embeddings": 512, - "model_type": "electra", "num_attention_heads": 16, "num_hidden_layers": 24, "pad_token_id": 0, @@ -282,7 +241,6 @@ class ElectraPretrainedModel(PretrainedModel): "vocab_size": 30522 }, "chinese-electra-discriminator-small": { - "architectures": ["ElectraForPretraining"], "attention_probs_dropout_prob": 0.1, "embedding_size": 128, "hidden_act": "gelu", @@ -290,9 +248,7 @@ class ElectraPretrainedModel(PretrainedModel): "hidden_size": 256, "initializer_range": 0.02, "intermediate_size": 1024, - "layer_norm_eps": 1e-12, "max_position_embeddings": 512, - "model_type": "electra", "num_attention_heads": 4, "num_hidden_layers": 12, "pad_token_id": 0, @@ -300,7 +256,6 @@ class ElectraPretrainedModel(PretrainedModel): "vocab_size": 21128, }, "chinese-electra-discriminator-base": { - "architectures": ["ElectraForPretraining"], "attention_probs_dropout_prob": 0.1, "embedding_size": 768, "hidden_act": "gelu", @@ -308,9 +263,7 @@ class ElectraPretrainedModel(PretrainedModel): "hidden_size": 768, "initializer_range": 0.02, "intermediate_size": 3072, - "layer_norm_eps": 1e-12, "max_position_embeddings": 512, - "model_type": "electra", "num_attention_heads": 12, "num_hidden_layers": 12, "pad_token_id": 0, @@ -399,177 +352,20 @@ class ElectraPretrainedModel(PretrainedModel): output_embeddings.weight.shape[ -1], output_embeddings.bias.shape[0])) - def get_extended_attention_mask(self, attention_mask, input_shape, place): - """ - Makes broadcastable attention and causal masks so that future and masked tokens are ignored. - Arguments: - attention_mask (:obj:`paddle.Tensor`): - Mask with ones indicating tokens to attend to, zeros for tokens to ignore. - input_shape (:obj:`Tuple[int]`): - The shape of the input to the model. - place: (:obj:`paddle.Tensor.place`): - The place of the input to the model. - Returns: - :obj:`paddle.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. - """ - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - if attention_mask.dim() == 3: - #extended_attention_mask = attention_mask[:, None, :, :] - extended_attention_mask = attention_mask.unsqueeze(1) - elif attention_mask.dim() == 2: - # Provided a padding mask of dimensions [batch_size, seq_length] - extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) - else: - raise ValueError( - "Wrong shape for input_ids (shape {}) or attention_mask (shape {})". - format(input_shape, attention_mask.shape)) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - #extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility - extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 - return extended_attention_mask - - def get_head_mask(self, - head_mask, - num_hidden_layers, - is_attention_chunked=False): - """ - Prepare the head mask if needed. - Args: - head_mask (:obj:`paddle.Tensor` with shape :obj:`[num_heads]` - or :obj:`[num_hidden_layers x num_heads]`, `optional`): - The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard). - num_hidden_layers (:obj:`int`): - The number of hidden layers in the model. - is_attention_chunked: (:obj:`bool`, `optional, defaults to :obj:`False`): - Whether or not the attentions scores are computed by chunks or not. - Returns: - :obj:`paddle.Tensor` with shape :obj:`[num_hidden_layers x batch x num_heads x seq_length x seq_length] - or list with :obj:`[None]` for each layer. - """ - if head_mask is not None: - head_mask = self._convert_head_mask_to_5d(head_mask, - num_hidden_layers) - if is_attention_chunked is True: - head_mask = head_mask.unsqueeze(-1) - else: - head_mask = [None] * num_hidden_layers - - return head_mask - - def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers): - """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]""" - if head_mask.dim() == 1: - head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze( - -1).unsqueeze(-1) - head_mask = paddle.expand(head_mask, - [num_hidden_layers, -1, -1, -1, -1]) - elif head_mask.dim() == 2: - # We can specify head_mask for each layer - head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) - assert (head_mask.dim() == 5 - ), "head_mask.dim != 5, instead {head_mask.dim()}" - #head_mask = head_mask.to(dtype=self.dtype) # switch to float if need + fp16 compatibility - return head_mask - - -@dataclass -class ElectraModelOutput(OrderedDict): - """ - Output type of :class:`ElectraPretrainedModel`. - Args: - loss (`optional`, returned when ``labels`` is provided, ``paddle.Tensor`` of shape :obj:`(1,)`): - Total loss of the ELECTRA objective. - logits (:obj:`paddle.Tensor` dtype=float32 of shape :obj:`(batch_size, sequence_length)`): - Prediction scores of the head (scores for each token before SoftMax). - hidden_states (:obj:`tuple(paddle.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``output_hidden_states=True``): - Tuple of :obj:`paddle.Tensor` dtype=float32 - (one for the output of the embeddings + one for the output of each layer) - of shape :obj:`(batch_size, sequence_length, hidden_size)`. - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (:obj:`tuple(paddle.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``output_attentions=True``): - Tuple of :obj:`paddle.Tensor` dtype=float32 (one for each layer) of shape :obj:`(batch_size, num_heads, - sequence_length, sequence_length)`. - Attentions weights after the attention softmax, - used to compute the weighted average in the self-attention heads. - """ - loss = None - logits = None - hidden_states = None - attentions = None - - -ELECTRA_START_DOCSTRING = r""" - This model inherits from :class:`ElectraPretrainedModel`. Check the superclass documentation for the generic - methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, - pruning heads etc.) - This model is also a Paddle `paddle.nn.Layer `__ - subclass. Use it as a regular Paddle Module and refer to the Padddle documentation for all matter related to - general usage and behavior. - Parameters: -""" - -ELECTRA_INPUTS_DOCSTRING = r""" - Args: - input_ids (:obj:`paddle.Tensor` of shape :obj:`({0})`): - Indices of input sequence tokens in the vocabulary. - attention_mask (:obj:`paddle.Tensor` dtype=float32 of shape :obj:`({0})`, `optional`): - Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - token_type_ids (:obj:`paddle.Tensor` dtype=int64 of shape :obj:`({0})`, `optional`): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, - 1]``: - - 0 corresponds to a `sentence A` token, - - 1 corresponds to a `sentence B` token. - position_ids (:obj:`paddle.Tensor` dtype=int64 of shape :obj:`({0})`, `optional`): - Indices of positions of each input sequence tokens in the position embeddings. - Selected in the range ``[0, max_position_embeddings - 1]``. - head_mask (:obj:`paddle.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): - Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - inputs_embeds (:obj:`paddle.Tensor` of shape :obj:`({0}, hidden_size)`, `optional`): - Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert :obj:`input_ids` indices into associated - vectors than the model's internal embedding lookup matrix. - encoder_hidden_states (:obj:`paddle.Tensor` of shape :obj:`({0}, hidden_size)`, `optional`): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (:obj:`paddle.Tensor` of shape :obj:`({0})`, `optional`): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - output_attentions (:obj:`bool`, `optional`): - Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned - tensors for more detail. - output_hidden_states (:obj:`bool`, `optional`): - Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for - more detail. - return_dict (:obj:`bool`, `optional`): - Whether or not to return a :class:`ElectraModelOutput` instead of a plain tuple. -""" - @register_base_model class ElectraModel(ElectraPretrainedModel): def __init__(self, vocab_size, embedding_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, - max_position_embeddings, type_vocab_size, layer_norm_eps, - pad_token_id, initializer_range, model_type, architectures): + max_position_embeddings, type_vocab_size, initializer_range, + pad_token_id): super(ElectraModel, self).__init__() self.pad_token_id = pad_token_id self.initializer_range = initializer_range self.embeddings = ElectraEmbeddings( vocab_size, embedding_size, hidden_dropout_prob, - max_position_embeddings, type_vocab_size, layer_norm_eps) + max_position_embeddings, type_vocab_size) if embedding_size != hidden_size: self.embeddings_project = nn.Linear(embedding_size, hidden_size) @@ -592,70 +388,33 @@ class ElectraModel(ElectraPretrainedModel): def set_input_embeddings(self, value): self.embeddings.word_embeddings = value - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, ): - output_attentions = output_attentions if output_attentions is not None else self.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else - self.output_hidden_states) - return_dict = return_dict if return_dict is not None else self.return_dict - - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time" - ) - elif input_ids is not None: - input_shape = input_ids.shape - elif inputs_embeds is not None: - input_shape = inputs_embeds.shape[:-1] - else: - raise ValueError( - "You have to specify either input_ids or inputs_embeds") - - place = input_ids.place if input_ids is not None else inputs_embeds.place + def forward(self, + input_ids, + token_type_ids=None, + position_ids=None, + attention_mask=None): if attention_mask is None: - attention_mask = paddle.ones(input_shape) - if token_type_ids is None: - token_type_ids = paddle.zeros(input_shape, dtype="int64") - - extended_attention_mask = self.get_extended_attention_mask( - attention_mask, input_shape, place) - #head_mask = self.get_head_mask(head_mask, self.num_hidden_layers) + attention_mask = paddle.unsqueeze( + (input_ids == self.pad_token_id).astype("float32") * -1e9, + axis=[1, 2]) - hidden_states = self.embeddings( + embedding_output = self.embeddings( input_ids=input_ids, position_ids=position_ids, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds) + token_type_ids=token_type_ids) if hasattr(self, "embeddings_project"): - hidden_states = self.embeddings_project(hidden_states) + embedding_output = self.embeddings_project(embedding_output) - hidden_states = self.encoder( - hidden_states, - extended_attention_mask - #head_mask=head_mask, - #output_attentions=output_attentions, - #output_hidden_states=output_hidden_states, - #return_dict=return_dict, - ) + encoder_outputs = self.encoder(embedding_output, attention_mask) - return (hidden_states, ) + return encoder_outputs -class ElectraForPretraining(ElectraPretrainedModel): +class ElectraDiscriminator(ElectraPretrainedModel): def __init__(self, electra): - super(ElectraForPretraining, self).__init__() + super(ElectraDiscriminator, self).__init__() self.electra = electra self.discriminator_predictions = ElectraDiscriminatorPredictions( @@ -663,73 +422,23 @@ class ElectraForPretraining(ElectraPretrainedModel): self.electra.config["hidden_act"]) self.init_weights() - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, ): - r""" - labels (``paddle.Tensor`` dtype=ing64 of shape ``(batch_size, sequence_length)``, `optional`): - Labels for computing the ELECTRA loss. Input should be a sequence of tokens (see :obj:`input_ids` - docstring) Indices should be in ``[0, 1]``: - - 0 indicates the token is an original token, - - 1 indicates the token was replaced. - """ - return_dict = return_dict if return_dict is not None else self.return_dict - - discriminator_hidden_states = self.electra( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - inputs_embeds, - output_attentions, - output_hidden_states, - return_dict, ) - discriminator_sequence_output = discriminator_hidden_states[0] + def forward(self, + input_ids, + token_type_ids=None, + position_ids=None, + attention_mask=None): + + discriminator_sequence_output = self.electra( + input_ids, token_type_ids, position_ids, attention_mask) logits = self.discriminator_predictions(discriminator_sequence_output) - loss = None - if labels is not None: - loss_fct = nn.BCEWithLogitsLoss() - if attention_mask is not None: - active_loss = paddle.reshape( - attention_mask, - [-1, discriminator_sequence_output.shape[1]]) == 1 - active_logits = paddle.reshape( - logits, - [-1, discriminator_sequence_output.shape[1]])[active_loss] - active_labels = labels[active_loss] - loss = loss_fct(active_logits, active_labels.astype("float32")) - else: - loss = loss_fct( - paddle.reshape( - logits, [-1, discriminator_sequence_output.shape[1]]), - labels.astype("float32")) - - if not return_dict: - output = (logits, ) + discriminator_hidden_states[1:] - return ((loss, ) + output) if loss is not None else output - - return ElectraModelOutput( - loss=loss, - logits=logits, - hidden_states=discriminator_hidden_states.hidden_states, - attentions=discriminator_hidden_states.attentions, ) - - -class ElectraForMaskedLM(ElectraPretrainedModel): + return logits + + +class ElectraGenerator(ElectraPretrainedModel): def __init__(self, electra): - super(ElectraForMaskedLM, self).__init__() + super(ElectraGenerator, self).__init__() self.electra = electra self.generator_predictions = ElectraGeneratorPredictions( @@ -753,74 +462,25 @@ class ElectraForMaskedLM(ElectraPretrainedModel): def forward(self, input_ids=None, - attention_mask=None, token_type_ids=None, position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - **kwargs): - r""" - labels (:obj:`paddle.Tensor` dtype = int64 of shape :obj:`(batch_size, sequence_length)`, `optional`): - Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., - vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored - (masked), the loss is only computed for the tokens with labels in ``[0, ..., vocab_size]`` - kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`): - Used to hide legacy arguments that have been deprecated. - """ - assert (kwargs == {} - ), "Unexpected keyword arguments: {list(kwargs.keys())}." - return_dict = return_dict if return_dict is not None else self.return_dict - - generator_hidden_states = self.electra( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - inputs_embeds, - output_attentions, - output_hidden_states, - return_dict, ) - generator_sequence_output = generator_hidden_states[0] + attention_mask=None): + + generator_sequence_output = self.electra(input_ids, token_type_ids, + position_ids, attention_mask) prediction_scores = self.generator_predictions( generator_sequence_output) if not self.tie_word_embeddings: prediction_scores = self.generator_lm_head(prediction_scores) else: - prediction_scores = F.linear( + prediction_scores = paddle.add(paddle.matmul( prediction_scores, - self.get_input_embeddings().weight.transpose([1, 0]), - self.generator_lm_head_bias) - - loss = None - # Masked language modeling softmax layer - if labels is not None: - loss_fct = nn.CrossEntropyLoss( - reduction='none') # -100 index = padding token - loss = loss_fct( - paddle.reshape(prediction_scores, [-1, self.vocab_size]), - paddle.reshape(labels, [-1])) - - umask_positions = paddle.zeros_like(labels).astype("float32") - mask_positions = paddle.ones_like(labels).astype("float32") - mask_positions = paddle.where(labels == -100, umask_positions, - mask_positions) - loss = loss.sum() / mask_positions.sum() - - if not return_dict: - output = (prediction_scores, ) + generator_hidden_states[1:] - return ((loss, ) + output) if loss is not None else output - - return ElectraModelOutput( - loss=loss, - logits=prediction_scores, - hidden_states=generator_hidden_states.hidden_states, - attentions=generator_hidden_states.attentions, ) + self.get_input_embeddings().weight, + transpose_y=True), + self.generator_lm_head_bias) + + return prediction_scores # class ElectraClassificationHead and ElectraForSequenceClassification for fine-tuning @@ -837,9 +497,7 @@ class ElectraClassificationHead(nn.Layer): x = features[:, 0, :] # take token (equiv. to [CLS]) x = self.dropout(x) x = self.dense(x) - x = get_activation("gelu")( - x - ) # although BERT uses tanh here, it seems Electra authors used gelu here + x = get_activation("gelu")(x) # Electra paper used gelu here x = self.dropout(x) x = self.out_proj(x) return x @@ -856,50 +514,18 @@ class ElectraForSequenceClassification(ElectraPretrainedModel): self.init_weights() - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, ): - r""" - labels (:obj:`paddle.Tensor` dtype=int64 of shape :obj:`(batch_size,)`, `optional`): - Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., - num_labels - 1]`. If :obj:`num_labels == 1` a regression loss is computed (Mean-Square loss), - If :obj:`num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.return_dict - - discriminator_hidden_states = self.electra( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - inputs_embeds, - output_attentions, - output_hidden_states, - return_dict, ) - - sequence_output = discriminator_hidden_states[0] - logits = self.classifier(sequence_output) + def forward(self, + input_ids=None, + token_type_ids=None, + position_ids=None, + attention_mask=None): + + sequence_output = self.electra(input_ids, token_type_ids, position_ids, + attention_mask) - loss = None - if not return_dict: - output = (logits, ) + discriminator_hidden_states[1:] - return ((loss, ) + output) if loss is not None else output + logits = self.classifier(sequence_output) - return ElectraModelOutput( - loss=loss, - logits=logits, - hidden_states=discriminator_hidden_states.hidden_states, - attentions=discriminator_hidden_states.attentions, ) + return logits class ElectraForTokenClassification(ElectraPretrainedModel): @@ -912,51 +538,19 @@ class ElectraForTokenClassification(ElectraPretrainedModel): self.num_labels) self.init_weights() - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, ): - r""" - labels (:obj:`paddle.Tensor` dtype=int64 of shape :obj:`(batch_size, sequence_length)`, `optional`): - Labels for computing the token classification loss. - Indices should be in ``[0, ..., num_labels-1]``. - """ - return_dict = return_dict if return_dict is not None else self.return_dict - - discriminator_hidden_states = self.electra( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - inputs_embeds, - output_attentions, - output_hidden_states, - return_dict, ) - discriminator_sequence_output = discriminator_hidden_states[0] - - discriminator_sequence_output = self.dropout( - discriminator_sequence_output) - logits = self.classifier(discriminator_sequence_output) - - loss = None - if not return_dict: - output = (logits, ) + discriminator_hidden_states[1:] - return ((loss, ) + output) if loss is not None else output - - return ElectraModelOutput( - loss=loss, - logits=logits, - hidden_states=discriminator_hidden_states.hidden_states, - attentions=discriminator_hidden_states.attentions, ) + def forward(self, + input_ids=None, + token_type_ids=None, + position_ids=None, + attention_mask=None): + + sequence_output = self.electra(input_ids, token_type_ids, position_ids, + attention_mask) + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + return logits class ElectraForTotalPretraining(ElectraPretrainedModel): @@ -980,27 +574,26 @@ class ElectraForTotalPretraining(ElectraPretrainedModel): else: return None - def get_discriminator_inputs(self, inputs, raw_inputs, mlm_logits, masked, - use_softmax_sample): - """Sample from the generator to create corrupted input.""" + def get_discriminator_inputs(self, inputs, raw_inputs, gen_logits, + gen_labels, use_softmax_sample): + """Sample from the generator to create discriminator input.""" # get generator token result - sampled_tokens = (self.sample_from_softmax(mlm_logits, + sampled_tokens = (self.sample_from_softmax(gen_logits, use_softmax_sample)).detach() - #sampled_tokens = self.sample_from_softmax(mlm_logits) sampled_tokids = paddle.argmax(sampled_tokens, axis=-1) # update token only at mask position - # masked : [B, L], L contains -100(unmasked) or token value(masked) + # gen_labels : [B, L], L contains -100(unmasked) or token value(masked) # mask_positions : [B, L], L contains 0(unmasked) or 1(masked) - umask_positions = paddle.zeros_like(masked) - mask_positions = paddle.ones_like(masked) - mask_positions = paddle.where(masked == -100, umask_positions, + umask_positions = paddle.zeros_like(gen_labels) + mask_positions = paddle.ones_like(gen_labels) + mask_positions = paddle.where(gen_labels == -100, umask_positions, mask_positions) - updated_input = self.scatter_update(inputs, sampled_tokids, + updated_inputs = self.update_inputs(inputs, sampled_tokids, mask_positions) - # use inputs and updated_input to generate labels + # use inputs and updated_input to get discriminator labels labels = mask_positions * (paddle.ones_like(inputs) - paddle.equal( - updated_input, raw_inputs).astype("int32")) - return updated_input, labels, sampled_tokids + updated_inputs, raw_inputs).astype("int32")) + return updated_inputs, labels, sampled_tokids def sample_from_softmax(self, logits, use_softmax_sample=True): if use_softmax_sample: @@ -1010,24 +603,12 @@ class ElectraForTotalPretraining(ElectraPretrainedModel): else: gumbel_noise = paddle.zeros_like(logits) # softmax_sample equal to sampled_tokids.unsqueeze(-1) - ins_softmax = nn.Softmax(axis=-1) softmax_sample = paddle.argmax( - ins_softmax(logits + gumbel_noise), axis=-1) + F.softmax(logits + gumbel_noise), axis=-1) # one hot return F.one_hot(softmax_sample, logits.shape[-1]) - def scatter_update(self, sequence, updates, positions): - """Scatter-update a sequence. - Args: - sequence: A [batch_size, seq_len] or [batch_size, seq_len, depth] tensor - updates: A tensor of size batch_size*seq_len(*depth) - positions: A [batch_size, n_positions] tensor - Returns: A tuple of two tensors. First is a [batch_size, seq_len] or - [batch_size, seq_len, depth] tensor of "sequence" with elements at - "positions" replaced by the values at "updates." Updates to index 0 are - ignored. If there are duplicated positions the update is only applied once. - Second is a [batch_size, seq_len] mask tensor of which inputs were updated. - """ + def update_inputs(self, sequence, updates, positions): shape = sequence.shape assert (len(shape) == 2), "the dimension of inputs should be [B, L]" B, L = shape @@ -1041,68 +622,59 @@ class ElectraForTotalPretraining(ElectraPretrainedModel): return updated_sequence - def forward( - self, - input_ids=None, - raw_input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, ): - r""" - labels (``paddle.Tensor`` dtype=int64 of shape ``(batch_size, sequence_length)``, `optional`): - Labels for computing the ELECTRA loss. Input should be a sequence of tokens (see :obj:`input_ids` - docstring) Indices should be in ``[0, 1]``: - - 0 indicates the token is an original token, - - 1 indicates the token was replaced. - Returns: - """ - return_dict = return_dict if return_dict is not None else self.return_dict + def forward(self, + input_ids=None, + token_type_ids=None, + position_ids=None, + attention_mask=None, + raw_input_ids=None, + gen_labels=None): + assert ( - labels is not None - ), "labels should not be None, please check DataCollatorForLanguageModeling" - - generator_output = self.generator( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - inputs_embeds, - labels, - output_attentions, - output_hidden_states, - return_dict, ) - loss = generator_output[0] * self.gen_weight - logits = generator_output[1] - - discriminator_inputs, discriminator_labels, generator_predict_tokens = self.get_discriminator_inputs( - input_ids, raw_input_ids, logits, labels, self.use_softmax_sample) - - discriminator_output = self.discriminator( - discriminator_inputs, - attention_mask, - token_type_ids, - position_ids, - head_mask, - inputs_embeds, - discriminator_labels, - output_attentions, - output_hidden_states, - return_dict, ) - loss += discriminator_output[0] * self.disc_weight - logits = discriminator_output[1] - - if not return_dict: - return ((loss, ) + (logits, )) - - return ElectraModelOutput( - loss=loss, - logits=logits, - hidden_states=generator_output.hidden_states, - attentions=generator_output.attentions, ) + gen_labels is not None + ), "gen_labels should not be None, please check DataCollatorForLanguageModeling" + + gen_logits = self.generator(input_ids, token_type_ids, position_ids, + attention_mask) + + disc_inputs, disc_labels, generator_predict_tokens = self.get_discriminator_inputs( + input_ids, raw_input_ids, gen_logits, gen_labels, + self.use_softmax_sample) + + disc_logits = self.discriminator(disc_inputs, token_type_ids, + position_ids, attention_mask) + + return gen_logits, disc_logits, disc_labels + + +class ElectraPretrainingCriterion(paddle.nn.Layer): + def __init__(self, vocab_size, gen_weight, disc_weight): + super(ElectraPretrainingCriterion, self).__init__() + + self.vocab_size = vocab_size + self.gen_weight = gen_weight + self.disc_weight = disc_weight + self.gen_loss_fct = nn.CrossEntropyLoss(reduction='none') + self.disc_loss_fct = nn.BCEWithLogitsLoss() + + def forward(self, generator_prediction_scores, + discriminator_prediction_scores, generator_labels, + discriminator_labels): + # generator loss + gen_loss = self.gen_loss_fct( + paddle.reshape(generator_prediction_scores, [-1, self.vocab_size]), + paddle.reshape(generator_labels, [-1])) + # todo: we can remove 4 lines after when CrossEntropyLoss(reduction='mean') improved + umask_positions = paddle.zeros_like(generator_labels).astype("float32") + mask_positions = paddle.ones_like(generator_labels).astype("float32") + mask_positions = paddle.where(generator_labels == -100, umask_positions, + mask_positions) + gen_loss = gen_loss.sum() / mask_positions.sum() + + # discriminator loss + seq_length = discriminator_labels.shape[1] + disc_loss = self.disc_loss_fct( + paddle.reshape(discriminator_prediction_scores, [-1, seq_length]), + discriminator_labels.astype("float32")) + + return self.gen_weight * gen_loss + self.disc_weight * disc_loss diff --git a/PaddleNLP/paddlenlp/transformers/electra/tokenizer.py b/PaddleNLP/paddlenlp/transformers/electra/tokenizer.py index 233bed9ea26160abc8fa4e45a46de3a6360e501d..4d4f03700c6f1e2d17008b944ac41cffe2af342b 100644 --- a/PaddleNLP/paddlenlp/transformers/electra/tokenizer.py +++ b/PaddleNLP/paddlenlp/transformers/electra/tokenizer.py @@ -47,47 +47,32 @@ class ElectraTokenizer(PretrainedTokenizer): resource_files_names = {"vocab_file": "vocab.txt"} # for save_pretrained pretrained_resource_files_map = { "vocab_file": { - "electra-small-generator": - "https://paddlenlp.bj.bcebos.com/models/transformers/electra-small-generator/vocab.txt", - "electra-base-generator": - "https://paddlenlp.bj.bcebos.com/models/transformers/electra-base-generator/vocab.txt", - "electra-large-generator": - "https://paddlenlp.bj.bcebos.com/models/transformers/electra-large-generator/vocab.txt", - "electra-small-discriminator": - "https://paddlenlp.bj.bcebos.com/models/transformers/electra-small-discriminator/vocab.txt", - "electra-base-discriminator": - "https://paddlenlp.bj.bcebos.com/models/transformers/electra-base-discriminator/vocab.txt", - "electra-large-discriminator": - "https://paddlenlp.bj.bcebos.com/models/transformers/electra-large-discriminator/vocab.txt", - "chinese-electra-discriminator-base": - "http://paddlenlp.bj.bcebos.com/models/transformers/chinese-electra-discriminator-base/vocab.txt", - "chinese-electra-discriminator-small": - "http://paddlenlp.bj.bcebos.com/models/transformers/chinese-electra-discriminator-small/vocab.txt", + "electra-small": + "https://paddlenlp.bj.bcebos.com/models/transformers/electra-small-vocab.txt", + "electra-base": + "https://paddlenlp.bj.bcebos.com/models/transformers/electra-base-vocab.txt", + "electra-large": + "https://paddlenlp.bj.bcebos.com/models/transformers/electra-large-vocab.txt", + "chinese-electra-base": + "http://paddlenlp.bj.bcebos.com/models/transformers/chinese-electra-base/vocab.txt", + "chinese-electra-small": + "http://paddlenlp.bj.bcebos.com/models/transformers/chinese-electra-small/vocab.txt", } } pretrained_init_configuration = { - "electra-small-generator": { + "electra-small": { "do_lower_case": True }, - "electra-base-generator": { + "electra-base": { "do_lower_case": True }, - "electra-large-generator": { + "electra-large": { "do_lower_case": True }, - "electra-small-discriminator": { + "chinese-electra-base": { "do_lower_case": True }, - "electra-base-discriminator": { - "do_lower_case": True - }, - "electra-large-discriminator": { - "do_lower_case": True - }, - "chinese-electra-discriminator-base": { - "do_lower_case": True - }, - "chinese-electra-discriminator-small": { + "chinese-electra-small": { "do_lower_case": True } } @@ -163,15 +148,12 @@ class ElectraTokenizer(PretrainedTokenizer): def num_special_tokens_to_add(self, pair=False): """ Returns the number of added tokens when encoding a sequence with special tokens. - Note: This encodes inputs and checks the number of added tokens, and is therefore not efficient. Do not put this inside your training loop. - Args: pair: Returns the number of added tokens in the case of a sequence pair if set to True, returns the number of added tokens in the case of a single sequence if set to False. - Returns: Number of tokens added to sequences """ @@ -190,13 +172,11 @@ class ElectraTokenizer(PretrainedTokenizer): :: - single sequence: ``[CLS] X [SEP]`` - pair of sequences: ``[CLS] A [SEP] B [SEP]`` - Args: token_ids_0 (:obj:`List[int]`): List of IDs to which the special tokens will be added. token_ids_1 (:obj:`List[int]`, `optional`): Optional second list of IDs for sequence pairs. - Returns: :obj:`List[int]`: List of input_id with the appropriate special tokens. """ @@ -211,21 +191,16 @@ class ElectraTokenizer(PretrainedTokenizer): token_ids_1=None): """ Create a mask from the two sequences passed to be used in a sequence-pair classification task. - A BERT sequence pair mask has the following format: :: - 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 | first sequence | second sequence | - If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s). - Args: token_ids_0 (:obj:`List[int]`): List of IDs. token_ids_1 (:obj:`List[int]`, `optional`): Optional second list of IDs for sequence pairs. - Returns: :obj:`List[int]`: List of token_type_id according to the given sequence(s). """ @@ -251,7 +226,6 @@ class ElectraTokenizer(PretrainedTokenizer): """ Returns a dictionary containing the encoded sequence or sequence pair and additional information: the mask for sequence classification and the overflowing elements if a ``max_seq_len`` is specified. - Args: text (:obj:`str`, :obj:`List[str]` or :obj:`List[int]`): The first sequence to be encoded. This can be a string, a list of strings (tokenized string using @@ -270,7 +244,6 @@ class ElectraTokenizer(PretrainedTokenizer): model's max length. truncation_strategy (:obj:`str`, `optional`, defaults to `longest_first`): String selected in the following options: - - 'longest_first' (default) Iteratively reduce the inputs sequence until the input is under max_seq_len starting from the longest one at each token (when there is a pair of input sequences) - 'only_first': Only truncate the first sequence @@ -288,10 +261,8 @@ class ElectraTokenizer(PretrainedTokenizer): Set to True to return overflowing token information (default False). return_special_tokens_mask (:obj:`bool`, `optional`, defaults to :obj:`False`): Set to True to return special tokens mask information (default False). - Return: A Dictionary of shape:: - { input_ids: list[int], position_ids: list[int] if return_position_ids is True (default) @@ -302,9 +273,7 @@ class ElectraTokenizer(PretrainedTokenizer): num_truncated_tokens: int if a ``max_seq_len`` is specified and return_overflowing_tokens is True special_tokens_mask: list[int] if return_special_tokens_mask is True } - With the fields: - - ``input_ids``: list of token ids to be fed to a model - ``position_ids``: list of token position ids to be fed to a model - ``segment_ids``: list of token type ids to be fed to a model