From 9ddc050ea15eeda4d91a7e160cb5da564b060527 Mon Sep 17 00:00:00 2001 From: wawltor Date: Tue, 5 Jan 2021 10:23:48 +0800 Subject: [PATCH] Support xpu for ernie and bert in pre-train/glue task (#5173) * add the xpu support the the paddlenlp * add the support xpu for the ernie and bert Co-authored-by: wawltor --- PaddleNLP/benchmark/bert/data.py | 4 +- PaddleNLP/benchmark/bert/run_glue.py | 22 ++- PaddleNLP/benchmark/bert/run_pretrain.py | 10 +- .../benchmark/bert/run_pretrain_single.py | 14 +- .../examples/language_model/bert/README.md | 4 +- .../examples/language_model/bert/run_glue.py | 29 +-- .../language_model/bert/run_pretrain.py | 169 +++++++++--------- .../paddlenlp/transformers/ernie/modeling.py | 75 ++++++-- 8 files changed, 199 insertions(+), 128 deletions(-) diff --git a/PaddleNLP/benchmark/bert/data.py b/PaddleNLP/benchmark/bert/data.py index 05e45c9a..206df047 100644 --- a/PaddleNLP/benchmark/bert/data.py +++ b/PaddleNLP/benchmark/bert/data.py @@ -45,7 +45,7 @@ def create_pretraining_dataset(input_file, size += 8 - (size % 8) # masked_lm_positions # Organize as a 1D tensor for gather or use gather_nd - out[3] = np.full(size, 0, dtype=np.int64) + out[3] = np.full(size, 0, dtype=np.int32) # masked_lm_labels out[4] = np.full([size, 1], -1, dtype=np.int64) mask_token_num = 0 @@ -78,7 +78,7 @@ def create_data_holder(args): input_mask = paddle.static.data( name="input_mask", shape=[-1, 1, 1, -1], dtype="float32") masked_lm_positions = paddle.static.data( - name="masked_lm_positions", shape=[-1], dtype="int64") + name="masked_lm_positions", shape=[-1], dtype="int32") masked_lm_labels = paddle.static.data( name="masked_lm_labels", shape=[-1, 1], dtype="int64") next_sentence_labels = paddle.static.data( diff --git a/PaddleNLP/benchmark/bert/run_glue.py b/PaddleNLP/benchmark/bert/run_glue.py index 3a8db53f..89211b20 100644 --- a/PaddleNLP/benchmark/bert/run_glue.py +++ b/PaddleNLP/benchmark/bert/run_glue.py @@ -28,6 +28,7 @@ from paddlenlp.datasets import GlueCoLA, GlueSST2, GlueMRPC, GlueSTSB, GlueMNLI, from paddlenlp.data import Stack, Tuple, Pad from paddlenlp.data.sampler import SamplerHelper from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer +from paddlenlp.transformers import ErnieForSequenceClassification, ErnieTokenizer from paddlenlp.metrics import Mcc, PearsonAndSpearman from paddlenlp.utils.log import logger @@ -40,18 +41,16 @@ TASK_CLASSES = { "rte": (GlueRTE, Accuracy), } -MODEL_CLASSES = {"bert": (BertForSequenceClassification, BertTokenizer), } +MODEL_CLASSES = { + "bert": (BertForSequenceClassification, BertTokenizer), + "ernie": (ErnieForSequenceClassification, ErnieTokenizer), +} def parse_args(): parser = argparse.ArgumentParser() # Required parameters - parser.add_argument( - "--select_device", - default="gpu", - type=str, - help="The device that selecting for the training, must be gpu/xpu.") parser.add_argument( "--task_name", default=None, @@ -140,6 +139,11 @@ def parse_args(): help="Save checkpoint every X updates steps.") parser.add_argument( "--seed", type=int, default=42, help="Random seed for initialization") + parser.add_argument( + "--select_device", + type=str, + default="gpu", + help="Device for selecting for the training.") args = parser.parse_args() return args @@ -157,10 +161,10 @@ def create_data_holder(task_name): return [input_ids, segment_ids, label] -def reset_program_state_dict(model, state_dict, pretrained_state_dict): +def reset_program_state_dict(args, model, state_dict, pretrained_state_dict): reset_state_dict = {} 
scale = model.initializer_range if hasattr(model, "initializer_range")\ - else model.bert.config["initializer_range"] + else getattr(model, args.model_type).config["initializer_range"] for n, p in state_dict.items(): if n not in pretrained_state_dict: dtype_str = "float32" @@ -410,7 +414,7 @@ def do_train(args): exe = paddle.static.Executor(place) exe.run(startup_program) state_dict = model.state_dict() - reset_state_dict = reset_program_state_dict(model, state_dict, + reset_state_dict = reset_program_state_dict(args, model, state_dict, pretrained_state_dict) paddle.static.set_program_state(main_program, reset_state_dict) diff --git a/PaddleNLP/benchmark/bert/run_pretrain.py b/PaddleNLP/benchmark/bert/run_pretrain.py index 1eaf5e9a..2ee194ee 100644 --- a/PaddleNLP/benchmark/bert/run_pretrain.py +++ b/PaddleNLP/benchmark/bert/run_pretrain.py @@ -38,11 +38,6 @@ MODEL_CLASSES = {"bert": (BertForPretraining, BertTokenizer)} def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument( - "--select_device", - default="gpu", - type=str, - help="The device that selecting for the training, must be gpu/xpu.") parser.add_argument( "--model_type", default=None, @@ -141,6 +136,11 @@ def parse_args(): type=float, default=1.0, help="The value of scale_loss for fp16.") + parser.add_argument( + "--select_device", + type=str, + default="gpu", + help="Device for selecting for the training.") args = parser.parse_args() return args diff --git a/PaddleNLP/benchmark/bert/run_pretrain_single.py b/PaddleNLP/benchmark/bert/run_pretrain_single.py index 9641ffbd..6faaf4bc 100644 --- a/PaddleNLP/benchmark/bert/run_pretrain_single.py +++ b/PaddleNLP/benchmark/bert/run_pretrain_single.py @@ -26,10 +26,14 @@ import distutils.util import paddle from paddle.io import DataLoader, Dataset from paddlenlp.transformers import BertForPretraining, BertModel, BertPretrainingCriterion -from paddlenlp.transformers import BertTokenizer +from paddlenlp.transformers import ErnieForPretraining, ErnieModel, ErniePretrainingCriterion +from paddlenlp.transformers import BertTokenizer, ErnieTokenizer from data import create_data_holder, create_pretraining_dataset -MODEL_CLASSES = {"bert": (BertForPretraining, BertTokenizer)} +MODEL_CLASSES = { + "bert": (BertForPretraining, BertTokenizer), + "ernie": (ErnieForPretraining, ErnieTokenizer) +} def parse_args(): @@ -141,7 +145,9 @@ def parse_args(): return args -def build_compiled_program(main_program, loss): +def build_compiled_program(args, main_program, loss): + if args.select_device == "xpu": + return main_program exec_strategy = paddle.static.ExecutionStrategy() exec_strategy.num_threads = 1 exec_strategy.num_iteration_per_drop_scope = 10000 @@ -250,7 +256,7 @@ def do_train(args): reset_state_dict = reset_program_state_dict(model, state_dict) paddle.static.set_program_state(main_program, reset_state_dict) # Construct the compiled program - main_program = build_compiled_program(main_program, loss) + main_program = build_compiled_program(args, main_program, loss) global_step = 0 tic_train = time.time() epoch = 0 diff --git a/PaddleNLP/examples/language_model/bert/README.md b/PaddleNLP/examples/language_model/bert/README.md index 660f5a47..b0ca74dd 100644 --- a/PaddleNLP/examples/language_model/bert/README.md +++ b/PaddleNLP/examples/language_model/bert/README.md @@ -74,7 +74,7 @@ python -u ./run_pretrain.py \ --logging_steps 1 \ --save_steps 20000 \ --max_steps 1000000 \ - --n_gpu 1 + --n_cards 1 ``` 其中参数释义如下: @@ -110,7 +110,7 @@ python -u ./run_glue.py \ --logging_steps 1 \ 
--save_steps 500 \ --output_dir ./tmp/ \ - --n_gpu 1 \ + --n_cards 1 ``` 其中参数释义如下: diff --git a/PaddleNLP/examples/language_model/bert/run_glue.py b/PaddleNLP/examples/language_model/bert/run_glue.py index 6917c223..b0712c63 100644 --- a/PaddleNLP/examples/language_model/bert/run_glue.py +++ b/PaddleNLP/examples/language_model/bert/run_glue.py @@ -48,7 +48,10 @@ TASK_CLASSES = { "rte": (GlueRTE, Accuracy), } -MODEL_CLASSES = {"bert": (BertForSequenceClassification, BertTokenizer)} +MODEL_CLASSES = { + "bert": (BertForSequenceClassification, BertTokenizer), + "ernie": (ErnieForSequenceClassification, ErnieTokenizer) +} def parse_args(): @@ -148,10 +151,16 @@ def parse_args(): parser.add_argument( "--seed", default=42, type=int, help="random seed for initialization") parser.add_argument( - "--n_gpu", + "--n_cards", default=1, type=int, - help="number of gpus to use, 0 for cpu.") + help="Number cards for the training, only support multi cards in the gpu." + ) + parser.add_argument( + "--select_device", + type=str, + default="gpu", + help="Device for selecting for the training.") args = parser.parse_args() return args @@ -265,7 +274,7 @@ def convert_example(example, def do_train(args): - paddle.set_device("gpu" if args.n_gpu else "cpu") + paddle.set_device(args.select_device) if paddle.distributed.get_world_size() > 1: paddle.distributed.init_parallel_env() @@ -402,10 +411,10 @@ def do_train(args): evaluate(model, loss_fct, metric, dev_data_loader) logger.info("eval done total : %s s" % (time.time() - tic_eval)) - if (not args.n_gpu > 1) or paddle.distributed.get_rank() == 0: - output_dir = os.path.join( - args.output_dir, "%s_ft_model_%d.pdparams" % - (args.task_name, global_step)) + if (not args.n_cards > 1) or paddle.distributed.get_rank() == 0: + output_dir = os.path.join(args.output_dir, + "%s_ft_model_%d.pdparams" % + (args.task_name, global_step)) if not os.path.exists(output_dir): os.makedirs(output_dir) # Need better way to get inner model of DataParallel @@ -426,7 +435,7 @@ def print_arguments(args): if __name__ == "__main__": args = parse_args() print_arguments(args) - if args.n_gpu > 1: - paddle.distributed.spawn(do_train, args=(args, ), nprocs=args.n_gpu) + if args.n_cards > 1 and args.select_device == "gpu": + paddle.distributed.spawn(do_train, args=(args, ), nprocs=args.n_cards) else: do_train(args) diff --git a/PaddleNLP/examples/language_model/bert/run_pretrain.py b/PaddleNLP/examples/language_model/bert/run_pretrain.py index 0cbc8e00..6327b2e5 100644 --- a/PaddleNLP/examples/language_model/bert/run_pretrain.py +++ b/PaddleNLP/examples/language_model/bert/run_pretrain.py @@ -31,7 +31,8 @@ from paddle.io import DataLoader, Dataset from paddlenlp.data import Stack, Tuple, Pad from paddlenlp.transformers import BertForPretraining, BertModel, BertPretrainingCriterion -from paddlenlp.transformers import BertTokenizer +from paddlenlp.transformers import ErnieForPretraining, ErnieModel, ErniePretrainingCriterion +from paddlenlp.transformers import BertTokenizer, ErnieTokenizer FORMAT = '%(asctime)s-%(levelname)s: %(message)s' logging.basicConfig(level=logging.INFO, format=FORMAT) @@ -39,6 +40,7 @@ logger = logging.getLogger(__name__) MODEL_CLASSES = { "bert": (BertForPretraining, BertTokenizer), + "ernie": (ErnieForPretraining, ErnieTokenizer) } @@ -50,8 +52,7 @@ def parse_args(): type=str, required=True, help="Model type selected in the list: " + - ", ".join(MODEL_CLASSES.keys()), - ) + ", ".join(MODEL_CLASSES.keys()), ) parser.add_argument( "--model_name_or_path", default=None, @@ 
-62,22 +63,19 @@ def parse_args(): sum([ list(classes[-1].pretrained_init_configuration.keys()) for classes in MODEL_CLASSES.values() - ], [])), - ) + ], [])), ) parser.add_argument( "--input_dir", default=None, type=str, required=True, - help="The input directory where the data will be read from.", - ) + help="The input directory where the data will be read from.", ) parser.add_argument( "--output_dir", default=None, type=str, required=True, - help= - "The output directory where the model predictions and checkpoints will be written.", + help="The output directory where the model predictions and checkpoints will be written.", ) parser.add_argument( @@ -90,58 +88,64 @@ def parse_args(): "--batch_size", default=8, type=int, - help="Batch size per GPU/CPU for training.", - ) - parser.add_argument("--learning_rate", - default=5e-5, - type=float, - help="The initial learning rate for Adam.") - parser.add_argument("--weight_decay", - default=0.0, - type=float, - help="Weight decay if we apply some.") - parser.add_argument("--adam_epsilon", - default=1e-8, - type=float, - help="Epsilon for Adam optimizer.") - parser.add_argument("--max_grad_norm", - default=1.0, - type=float, - help="Max gradient norm.") + help="Batch size per GPU/CPU for training.", ) + parser.add_argument( + "--learning_rate", + default=5e-5, + type=float, + help="The initial learning rate for Adam.") + parser.add_argument( + "--weight_decay", + default=0.0, + type=float, + help="Weight decay if we apply some.") + parser.add_argument( + "--adam_epsilon", + default=1e-8, + type=float, + help="Epsilon for Adam optimizer.") + parser.add_argument( + "--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument( "--num_train_epochs", default=3, type=int, - help="Total number of training epochs to perform.", - ) + help="Total number of training epochs to perform.", ) parser.add_argument( "--max_steps", default=-1, type=int, - help= - "If > 0: set total number of training steps to perform. Override num_train_epochs.", + help="If > 0: set total number of training steps to perform. Override num_train_epochs.", ) - parser.add_argument("--warmup_steps", - default=0, - type=int, - help="Linear warmup over warmup_steps.") - - parser.add_argument("--logging_steps", - type=int, - default=500, - help="Log every X updates steps.") - parser.add_argument("--save_steps", - type=int, - default=500, - help="Save checkpoint every X updates steps.") - parser.add_argument("--seed", - type=int, - default=42, - help="random seed for initialization") - parser.add_argument("--n_gpu", - type=int, - default=1, - help="number of gpus to use, 0 for cpu.") + parser.add_argument( + "--warmup_steps", + default=0, + type=int, + help="Linear warmup over warmup_steps.") + + parser.add_argument( + "--logging_steps", + type=int, + default=500, + help="Log every X updates steps.") + parser.add_argument( + "--save_steps", + type=int, + default=500, + help="Save checkpoint every X updates steps.") + parser.add_argument( + "--seed", type=int, default=42, help="random seed for initialization") + parser.add_argument( + "--n_cards", + default=1, + type=int, + help="Number cards for the training, only support multi cards in the gpu." 
+ ) + parser.add_argument( + "--select_device", + type=str, + default="gpu", + help="Device for selecting for the training.") args = parser.parse_args() return args @@ -163,12 +167,11 @@ class WorkerInitObj(object): def create_pretraining_dataset(input_file, max_pred_length, shared_list, args, worker_init): - train_data = PretrainingDataset(input_file=input_file, - max_pred_length=max_pred_length) + train_data = PretrainingDataset( + input_file=input_file, max_pred_length=max_pred_length) # files have been sharded, no need to dispatch again - train_batch_sampler = paddle.io.BatchSampler(train_data, - batch_size=args.batch_size, - shuffle=True) + train_batch_sampler = paddle.io.BatchSampler( + train_data, batch_size=args.batch_size, shuffle=True) # DataLoader cannot be pickled because of its place. # If it can be pickled, use global function instead of lambda and use @@ -187,7 +190,7 @@ def create_pretraining_dataset(input_file, max_pred_length, shared_list, args, size += 8 - (size % 8) # masked_lm_positions # Organize as a 1D tensor for gather or use gather_nd - out[3] = np.full(size, 0, dtype=np.int64) + out[3] = np.full(size, 0, dtype=np.int32) # masked_lm_labels out[4] = np.full([size, 1], -1, dtype=np.int64) mask_token_num = 0 @@ -200,12 +203,13 @@ def create_pretraining_dataset(input_file, max_pred_length, shared_list, args, out.append(np.asarray([mask_token_num], dtype=np.float32)) return out - train_data_loader = DataLoader(dataset=train_data, - batch_sampler=train_batch_sampler, - collate_fn=_collate_data, - num_workers=0, - worker_init_fn=worker_init, - return_list=True) + train_data_loader = DataLoader( + dataset=train_data, + batch_sampler=train_batch_sampler, + collate_fn=_collate_data, + num_workers=0, + worker_init_fn=worker_init, + return_list=True) return train_data_loader, input_file @@ -237,8 +241,8 @@ class PretrainingDataset(Dataset): ] # TODO: whether to use reversed mask by changing 1s and 0s to be # consistent with nv bert - input_mask = (1 - np.reshape(input_mask.astype(np.float32), - [1, 1, input_mask.shape[0]])) * -1e9 + input_mask = (1 - np.reshape( + input_mask.astype(np.float32), [1, 1, input_mask.shape[0]])) * -1e9 index = self.max_pred_length # store number of masked tokens in index @@ -265,7 +269,7 @@ class PretrainingDataset(Dataset): def do_train(args): - paddle.set_device("gpu" if args.n_gpu else "cpu") + paddle.set_device(args.select_device) if paddle.distributed.get_world_size() > 1: paddle.distributed.init_parallel_env() @@ -281,8 +285,8 @@ def do_train(args): BertModel(**model_class.pretrained_init_configuration[ args.model_name_or_path])) criterion = BertPretrainingCriterion( - getattr(model, - BertForPretraining.base_model_prefix).config["vocab_size"]) + getattr(model, BertForPretraining.base_model_prefix).config[ + "vocab_size"]) if paddle.distributed.get_world_size() > 1: model = paddle.DataParallel(model) @@ -316,8 +320,8 @@ def do_train(args): for epoch in range(args.num_train_epochs): files = [ os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir) - if os.path.isfile(os.path.join(args.input_dir, f)) - and "training" in f + if os.path.isfile(os.path.join(args.input_dir, f)) and "training" in + f ] files.sort() num_files = len(files) @@ -328,10 +332,10 @@ def do_train(args): if paddle.distributed.get_world_size() > num_files: remainder = paddle.distributed.get_world_size() % num_files - data_file = files[ - (f_start_id * paddle.distributed.get_world_size() + - paddle.distributed.get_rank() + remainder * f_start_id) % - num_files] + 
data_file = files[( + f_start_id * paddle.distributed.get_world_size() + + paddle.distributed.get_rank() + remainder * f_start_id) % + num_files] else: data_file = files[(f_start_id * paddle.distributed.get_world_size() + paddle.distributed.get_rank()) % num_files] @@ -349,9 +353,10 @@ def do_train(args): if not single_file and f_id == f_start_id: continue if paddle.distributed.get_world_size() > num_files: - data_file = files[(f_id * paddle.distributed.get_world_size() + - paddle.distributed.get_rank() + - remainder * f_id) % num_files] + data_file = files[( + f_id * paddle.distributed.get_world_size() + + paddle.distributed.get_rank() + remainder * f_id) % + num_files] else: data_file = files[(f_id * paddle.distributed.get_world_size() + paddle.distributed.get_rank()) % num_files] @@ -374,7 +379,7 @@ def do_train(args): masked_lm_labels, next_sentence_labels, masked_lm_scale) if global_step % args.logging_steps == 0: - if (not args.n_gpu > 1 + if (not args.n_cards > 1 ) or paddle.distributed.get_rank() == 0: logger.info( "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s" @@ -386,7 +391,7 @@ def do_train(args): lr_scheduler.step() optimizer.clear_gradients() if global_step % args.save_steps == 0: - if (not args.n_gpu > 1 + if (not args.n_cards > 1 ) or paddle.distributed.get_rank() == 0: output_dir = os.path.join(args.output_dir, "model_%d" % global_step) @@ -410,7 +415,7 @@ def do_train(args): if __name__ == "__main__": args = parse_args() - if args.n_gpu > 1: - paddle.distributed.spawn(do_train, args=(args, ), nprocs=args.n_gpu) + if args.n_cards > 1 and args.select_device == "gpu": + paddle.distributed.spawn(do_train, args=(args, ), nprocs=args.n_cards) else: do_train(args) diff --git a/PaddleNLP/paddlenlp/transformers/ernie/modeling.py b/PaddleNLP/paddlenlp/transformers/ernie/modeling.py index 14652160..2b2cfe7f 100644 --- a/PaddleNLP/paddlenlp/transformers/ernie/modeling.py +++ b/PaddleNLP/paddlenlp/transformers/ernie/modeling.py @@ -18,11 +18,9 @@ import paddle.nn as nn from .. 
import PretrainedModel, register_base_model

 __all__ = [
-    'ErnieModel',
-    'ErniePretrainedModel',
-    'ErnieForSequenceClassification',
-    'ErnieForTokenClassification',
-    'ErnieForQuestionAnswering',
+    'ErnieModel', 'ErniePretrainedModel', 'ErnieForSequenceClassification',
+    'ErnieForTokenClassification', 'ErnieForQuestionAnswering',
+    'ErnieForPretraining', 'ErniePretrainingCriterion'
 ]

@@ -50,8 +48,11 @@ class ErnieEmbeddings(nn.Layer):
     def forward(self, input_ids, token_type_ids=None, position_ids=None):
         if position_ids is None:
             # maybe need use shape op to unify static graph and dynamic graph
-            seq_length = input_ids.shape[1]
-            position_ids = paddle.arange(0, seq_length, dtype="int64")
+            #seq_length = input_ids.shape[1]
+            ones = paddle.ones_like(input_ids, dtype="int64")
+            seq_length = paddle.cumsum(ones, axis=1)
+            position_ids = seq_length - ones
+            position_ids.stop_gradient = True
         if token_type_ids is None:
             token_type_ids = paddle.zeros_like(input_ids, dtype="int64")
@@ -168,13 +169,14 @@ class ErniePretrainedModel(PretrainedModel):
         if isinstance(layer, (nn.Linear, nn.Embedding)):
             # only support dygraph, use truncated_normal and make it inplace
             # and configurable later
-            layer.weight.set_value(
-                paddle.tensor.normal(
-                    mean=0.0,
-                    std=self.initializer_range
-                    if hasattr(self, "initializer_range") else
-                    self.ernie.config["initializer_range"],
-                    shape=layer.weight.shape))
+            if isinstance(layer.weight, paddle.Tensor):
+                layer.weight.set_value(
+                    paddle.tensor.normal(
+                        mean=0.0,
+                        std=self.initializer_range
+                        if hasattr(self, "initializer_range") else
+                        self.ernie.config["initializer_range"],
+                        shape=layer.weight.shape))


 @register_base_model
@@ -320,3 +322,48 @@ class ErnieForTokenClassification(ErniePretrainedModel):
         sequence_output = self.dropout(sequence_output)
         logits = self.classifier(sequence_output)
         return logits
+
+
+class ErnieForPretraining(ErniePretrainedModel):
+    def __init__(self, ernie):
+        super(ErnieForPretraining, self).__init__()
+        self.ernie = ernie
+        self.cls = ErniePretrainingHeads(
+            self.ernie.config["hidden_size"],
+            self.ernie.config["vocab_size"],
+            self.ernie.config["hidden_act"],
+            embedding_weights=self.ernie.embeddings.word_embeddings.weight)
+
+        self.apply(self.init_weights)
+
+    def forward(self,
+                input_ids,
+                token_type_ids=None,
+                position_ids=None,
+                attention_mask=None,
+                masked_positions=None):
+        outputs = self.ernie(
+            input_ids,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            attention_mask=attention_mask)
+        sequence_output, pooled_output = outputs[:2]
+        prediction_scores, seq_relationship_score = self.cls(
+            sequence_output, pooled_output, masked_positions)
+        return prediction_scores, seq_relationship_score
+
+
+class ErniePretrainingCriterion(paddle.nn.Layer):
+    def __init__(self, vocab_size):
+        super(ErniePretrainingCriterion, self).__init__()
+        self.loss_fn = paddle.nn.loss.CrossEntropyLoss(ignore_index=-1)
+        self.vocab_size = vocab_size
+
+    def forward(self, prediction_scores, seq_relationship_score,
+                masked_lm_labels, next_sentence_labels, masked_lm_scale):
+        masked_lm_loss = paddle.nn.functional.softmax_with_cross_entropy(
+            prediction_scores, masked_lm_labels, ignore_index=-1)
+        masked_lm_loss = masked_lm_loss / masked_lm_scale
+        next_sentence_loss = paddle.nn.functional.softmax_with_cross_entropy(
+            seq_relationship_score, next_sentence_labels)
+        return paddle.sum(masked_lm_loss) + paddle.mean(next_sentence_loss)
-- 
GitLab
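
The position-id change in ErnieEmbeddings above is the piece that lets the ERNIE graph build in static mode (and hence run in the XPU static-graph scripts): instead of reading input_ids.shape[1], which may be unknown (-1) at graph-build time, positions are derived from a cumulative sum over a tensor of ones. Below is a minimal dygraph sketch of that trick; it assumes only that PaddlePaddle 2.x is installed, and the token ids are made-up values for illustration, not part of the patch.

    import paddle

    # Two illustrative "sentences" of token ids (arbitrary values).
    input_ids = paddle.to_tensor([[101, 2023, 2003, 102],
                                  [101, 7592, 102, 0]])

    # Old approach: needs the concrete sequence length when the graph is built.
    #   seq_length = input_ids.shape[1]
    #   position_ids = paddle.arange(0, seq_length, dtype="int64")

    # Approach from the patch: shape-agnostic, so it also works when the
    # static graph only knows the sequence length as -1.
    ones = paddle.ones_like(input_ids, dtype="int64")
    seq_length = paddle.cumsum(ones, axis=1)   # [[1, 2, 3, 4], [1, 2, 3, 4]]
    position_ids = seq_length - ones           # [[0, 1, 2, 3], [0, 1, 2, 3]]
    position_ids.stop_gradient = True

    print(position_ids.numpy())  # each row matches what arange(0, seq_length) produced

The same concern drives the launcher-side changes in this patch: build_compiled_program returns the plain program when --select_device xpu is chosen, and multi-card paddle.distributed.spawn is only used when --n_cards > 1 together with --select_device gpu.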