Unverified commit f07cdf53, authored by jeff41404, committed by GitHub

add electra pretrain and modify style of electra modeling (#4990)

* add electra pretrain and modify style of electra modeling

* add electra pretrain, modify style of electra modeling and address review comments

* delete predict_classifer

* modify accu to acc

* add paddlenlp.metrics.glue
Parent 8e45228d
@@ -25,25 +25,27 @@ from functools import partial
import numpy as np
import paddle
from paddle.io import DataLoader
+from paddle.metric import Metric, Accuracy, Precision, Recall

from paddlenlp.datasets import GlueCoLA, GlueSST2, GlueMRPC, GlueSTSB, GlueQQP, GlueMNLI, GlueQNLI, GlueRTE
from paddlenlp.data import Stack, Tuple, Pad
from paddlenlp.data.sampler import SamplerHelper
from paddlenlp.transformers import ElectraForSequenceClassification, ElectraTokenizer
+from paddlenlp.metrics import AccuracyAndF1, Mcc, PearsonAndSpearman

FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)

TASK_CLASSES = {
-    "cola": (GlueCoLA, paddle.metric.Accuracy),
-    "sst-2": (GlueSST2, paddle.metric.Accuracy),
-    "mrpc": (GlueMRPC, paddle.metric.Accuracy),
-    "sts-b": (GlueSTSB, paddle.metric.Accuracy),
-    "qqp": (GlueQQP, paddle.metric.Accuracy),
-    "mnli": (GlueMNLI, paddle.metric.Accuracy),
-    "qnli": (GlueQNLI, paddle.metric.Accuracy),
-    "rte": (GlueRTE, paddle.metric.Accuracy),
+    "cola": (GlueCoLA, Mcc),
+    "sst-2": (GlueSST2, Accuracy),
+    "mrpc": (GlueMRPC, AccuracyAndF1),
+    "sts-b": (GlueSTSB, PearsonAndSpearman),
+    "qqp": (GlueQQP, AccuracyAndF1),
+    "mnli": (GlueMNLI, Accuracy),
+    "qnli": (GlueQNLI, Accuracy),
+    "rte": (GlueRTE, Accuracy),
}

MODEL_CLASSES = {
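For orientation, a two-line sketch (mine, not part of the diff) of how this table is presumably consumed further down in the script; "mrpc" is just an example key:

dataset_class, metric_class = TASK_CLASSES["mrpc"]  # e.g. (GlueMRPC, AccuracyAndF1)
metric = metric_class()  # Mcc for CoLA, PearsonAndSpearman for STS-B, Accuracy otherwise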
@@ -57,21 +59,17 @@ def set_seed(args):
    paddle.seed(args.seed + paddle.distributed.get_rank())


-def evaluate(model, loss_fct, metric, data_loader, return_dict):
+def evaluate(model, loss_fct, metric, data_loader):
    model.eval()
    metric.reset()
    for batch in data_loader:
        input_ids, segment_ids, labels = batch
-        model_output = model(input_ids=input_ids, token_type_ids=segment_ids)
-        if not return_dict:
-            logits = model_output[0]
-        else:
-            logits = model_output.logits
+        logits = model(input_ids=input_ids, token_type_ids=segment_ids)
        loss = loss_fct(logits, labels)
        correct = metric.compute(logits, labels)
        metric.update(correct)
-    accu = metric.accumulate()
-    print("eval loss: %f, accu: %f, " % (loss.numpy(), accu), end='')
+    acc = metric.accumulate()
+    print("eval loss: %f, acc: %s, " % (loss.numpy(), acc), end='')
    model.train()
@@ -218,9 +216,10 @@ def do_train(args):
        num_workers=0,
        return_list=True)

+    num_labels = 1 if train_dataset.get_labels() == None else len(
+        train_dataset.get_labels())
    model = model_class.from_pretrained(
-        args.model_name_or_path, num_labels=len(train_dataset.get_labels()))
-    return_dict = model.return_dict
+        args.model_name_or_path, num_labels=num_labels)

    if paddle.distributed.get_world_size() > 1:
        model = paddle.DataParallel(model)
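A short illustration (my own, with a hypothetical helper name) of why num_labels falls back to 1: for a regression task such as sts-b, get_labels() returns None, so the sequence-classification head needs a single output and a regression loss rather than cross entropy.

import paddle

def head_size_and_loss(label_list):
    # label_list stands in for train_dataset.get_labels(); hypothetical helper.
    if label_list is None:  # regression task, e.g. sts-b
        return 1, paddle.nn.MSELoss()
    return len(label_list), paddle.nn.CrossEntropyLoss()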
@@ -267,14 +266,14 @@ def do_train(args):
    tic_train = time.time()
    for epoch in range(args.num_train_epochs):
        for step, batch in enumerate(train_data_loader):
+            global_step += 1
            input_ids, segment_ids, labels = batch
-            model_output = model(
-                input_ids=input_ids, token_type_ids=segment_ids)
-            if not return_dict:
-                logits = model_output[0]
-            else:
-                logits = model_output.logits
+            logits = model(input_ids=input_ids, token_type_ids=segment_ids)
            loss = loss_fct(logits, labels)
+            loss.backward()
+            optimizer.step()
+            lr_scheduler.step()
+            optimizer.clear_gradients()
            if global_step % args.logging_steps == 0:
                print(
                    "global step %d/%d, epoch: %d, batch: %d, rank_id: %s, loss: %f, lr: %.10f, speed: %.4f step/s"
@@ -282,21 +281,15 @@ def do_train(args):
                       paddle.distributed.get_rank(), loss, optimizer.get_lr(),
                       args.logging_steps / (time.time() - tic_train)))
                tic_train = time.time()
-            loss.backward()
-            optimizer.step()
-            lr_scheduler.step()
-            optimizer.clear_gradients()
-            if global_step > 1 and global_step % args.save_steps == 0:
+            if global_step % args.save_steps == 0:
                tic_eval = time.time()
                if args.task_name == "mnli":
-                    evaluate(model, loss_fct, metric, dev_data_loader_matched,
-                             return_dict)
-                    evaluate(model, loss_fct, metric,
-                             dev_data_loader_mismatched, return_dict)
+                    evaluate(model, loss_fct, metric, dev_data_loader_matched)
+                    evaluate(model, loss_fct, metric,
+                             dev_data_loader_mismatched)
                    print("eval done total : %s s" % (time.time() - tic_eval))
                else:
-                    evaluate(model, loss_fct, metric, dev_data_loader,
-                             return_dict)
+                    evaluate(model, loss_fct, metric, dev_data_loader)
                    print("eval done total : %s s" % (time.time() - tic_eval))
                if (not args.n_gpu > 1) or paddle.distributed.get_rank() == 0:
                    output_dir = os.path.join(args.output_dir,

@@ -309,7 +302,6 @@ def do_train(args):
                        model, paddle.DataParallel) else model
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)
-            global_step += 1


def get_md5sum(file_path):
@@ -374,7 +366,7 @@ if __name__ == "__main__":
        "than this will be truncated, sequences shorter will be padded.", )
    parser.add_argument(
        "--learning_rate",
-        default=3e-4,
+        default=1e-4,
        type=float,
        help="The initial learning rate for Adam.")
    parser.add_argument(
This diff is collapsed.
@@ -14,4 +14,5 @@
from .perplexity import Perplexity
from .chunk import ChunkEvaluator
-from .bleu import BLEU
\ No newline at end of file
+from .bleu import BLEU
+from .glue import AccuracyAndF1, Mcc, PearsonAndSpearman
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import math
from functools import partial
import numpy as np
import paddle
from paddle.metric import Metric, Accuracy, Precision, Recall
__all__ = ['AccuracyAndF1', 'Mcc', 'PearsonAndSpearman']
class AccuracyAndF1(Metric):
"""
Encapsulates Accuracy, Precision, Recall and F1 metric logic.
"""
def __init__(self,
topk=(1, ),
pos_label=1,
name='acc_and_f1',
*args,
**kwargs):
super(AccuracyAndF1, self).__init__(*args, **kwargs)
self.topk = topk
self.pos_label = pos_label
self._name = name
self.acc = Accuracy(self.topk, *args, **kwargs)
self.precision = Precision(*args, **kwargs)
self.recall = Recall(*args, **kwargs)
self.reset()
def compute(self, pred, label, *args):
self.label = label
self.preds_pos = paddle.nn.functional.softmax(pred)[:, self.pos_label]
return self.acc.compute(pred, label)
def update(self, correct, *args):
self.acc.update(correct)
self.precision.update(self.preds_pos, self.label)
self.recall.update(self.preds_pos, self.label)
def accumulate(self):
acc = self.acc.accumulate()
precision = self.precision.accumulate()
recall = self.recall.accumulate()
if precision == 0.0 or recall == 0.0:
f1 = 0.0
else:
# 1/f1 = 1/2 * (1/precision + 1/recall)
f1 = (2 * precision * recall) / (precision + recall)
return (
acc,
precision,
recall,
f1,
(acc + f1) / 2, )
def reset(self):
self.acc.reset()
self.precision.reset()
self.recall.reset()
self.label = None
self.preds_pos = None
def name(self):
"""
Return name of metric instance.
"""
return self._name
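A minimal usage sketch for AccuracyAndF1 (the tensors are invented for illustration); it follows the same compute → update → accumulate pattern used by evaluate() in the fine-tuning script and ends with the 5-tuple returned by accumulate():

import paddle
from paddlenlp.metrics import AccuracyAndF1

metric = AccuracyAndF1()
logits = paddle.to_tensor([[0.1, 0.9], [0.8, 0.2], [0.4, 0.6]])  # fake binary logits
labels = paddle.to_tensor([[1], [0], [0]])                       # fake gold labels
correct = metric.compute(logits, labels)  # also caches softmax scores for precision/recall
metric.update(correct)
acc, precision, recall, f1, avg = metric.accumulate()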
class Mcc(Metric):
"""
Matthews correlation coefficient
https://en.wikipedia.org/wiki/Matthews_correlation_coefficient.
"""
def __init__(self, name='mcc', *args, **kwargs):
super(Mcc, self).__init__(*args, **kwargs)
self._name = name
self.tp = 0 # true positive
self.fp = 0 # false positive
self.tn = 0 # true negative
self.fn = 0 # false negative
def compute(self, pred, label, *args):
preds = paddle.argsort(pred, descending=True)[:, :1]
return (preds, label)
def update(self, preds_and_labels):
preds = preds_and_labels[0]
preds = preds.numpy()
labels = preds_and_labels[1]
labels = labels.numpy().reshape(-1, 1)
sample_num = labels.shape[0]
for i in range(sample_num):
pred = preds[i]
label = labels[i]
if pred == 1:
if pred == label:
self.tp += 1
else:
self.fp += 1
else:
if pred == label:
self.tn += 1
else:
self.fn += 1
    def accumulate(self):
        # mcc = (tp * tn - fp * fn) / sqrt((tp + fp)(tp + fn)(tn + fp)(tn + fn))
        # The denominator is zero only when a whole row or column of the
        # confusion matrix sums to zero; define mcc as 0.0 in that case.
        denominator = math.sqrt(
            (self.tp + self.fp) * (self.tp + self.fn) *
            (self.tn + self.fp) * (self.tn + self.fn))
        if denominator == 0:
            mcc = 0.0
        else:
            mcc = (self.tp * self.tn - self.fp * self.fn) / denominator
        return (mcc, )
def reset(self):
self.tp = 0 # true positive
self.fp = 0 # false positive
self.tn = 0 # true negative
self.fn = 0 # false negative
def name(self):
"""
Return name of metric instance.
"""
return self._name
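A small sanity-check sketch for Mcc with invented logits: the confusion counts come out to tp=2, fp=1, tn=2, fn=1, so accumulate() should return roughly (0.33, ):

import paddle
from paddlenlp.metrics import Mcc

metric = Mcc()
logits = paddle.to_tensor([[0.2, 0.8], [0.7, 0.3], [0.1, 0.9],
                           [0.6, 0.4], [0.3, 0.7], [0.9, 0.1]])  # fake logits
labels = paddle.to_tensor([[1], [0], [0], [1], [1], [0]])        # fake labels
metric.update(metric.compute(logits, labels))  # compute returns (pred_ids, labels)
print(metric.accumulate())  # about (0.333, ) for these made-up values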
class PearsonAndSpearman(Metric):
"""
Pearson correlation coefficient
https://en.wikipedia.org/wiki/Pearson_correlation_coefficient
Spearman's rank correlation coefficient
https://en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient.
"""
    def __init__(self, name='pearson_and_spearman', *args, **kwargs):
super(PearsonAndSpearman, self).__init__(*args, **kwargs)
self._name = name
self.preds = []
self.labels = []
def update(self, preds_and_labels):
preds = preds_and_labels[0]
preds = np.squeeze(preds.numpy().reshape(-1, 1)).tolist()
labels = preds_and_labels[1]
labels = np.squeeze(labels.numpy().reshape(-1, 1)).tolist()
self.preds.append(preds)
self.labels.append(labels)
def accumulate(self):
preds = [item for sublist in self.preds for item in sublist]
labels = [item for sublist in self.labels for item in sublist]
pearson = self.pearson(preds, labels)
spearman = self.spearman(preds, labels)
return (
pearson,
spearman,
(pearson + spearman) / 2, )
def pearson(self, preds, labels):
n = len(preds)
#simple sums
sum1 = sum(float(preds[i]) for i in range(n))
sum2 = sum(float(labels[i]) for i in range(n))
#sum up the squares
sum1_pow = sum([pow(v, 2.0) for v in preds])
sum2_pow = sum([pow(v, 2.0) for v in labels])
#sum up the products
p_sum = sum([preds[i] * labels[i] for i in range(n)])
numerator = p_sum - (sum1 * sum2 / n)
denominator = math.sqrt(
(sum1_pow - pow(sum1, 2) / n) * (sum2_pow - pow(sum2, 2) / n))
if denominator == 0:
return 0.0
return numerator / denominator
def spearman(self, preds, labels):
preds_rank = self.get_rank(preds)
labels_rank = self.get_rank(labels)
total = 0
n = len(preds)
for i in range(n):
total += pow((preds_rank[i] - labels_rank[i]), 2)
spearman = 1 - float(6 * total) / (n * (pow(n, 2) - 1))
return spearman
def get_rank(self, raw_list):
x = np.array(raw_list)
r_x = np.empty(x.shape, dtype=int)
y = np.argsort(-x)
for i, k in enumerate(y):
r_x[k] = i + 1
return r_x
def reset(self):
self.preds = []
self.labels = []
def name(self):
"""
Return name of metric instance.
"""
return self._name
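And a corresponding sketch for the regression metric (the scores are fabricated); unlike the two classifiers above, this class defines no compute(), so update() is fed the (predictions, labels) pair directly:

import paddle
from paddlenlp.metrics import PearsonAndSpearman

metric = PearsonAndSpearman()
preds = paddle.to_tensor([1.0, 2.0, 3.0, 4.0, 5.0])   # fake model scores
labels = paddle.to_tensor([1.2, 1.9, 3.3, 3.9, 5.1])  # fake gold scores
metric.update((preds, labels))
pearson, spearman, avg = metric.accumulate()  # all close to 1.0 for this toy data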
@@ -47,47 +47,32 @@ class ElectraTokenizer(PretrainedTokenizer):
    resource_files_names = {"vocab_file": "vocab.txt"}  # for save_pretrained
    pretrained_resource_files_map = {
        "vocab_file": {
-            "electra-small-generator":
-            "https://paddlenlp.bj.bcebos.com/models/transformers/electra-small-generator/vocab.txt",
-            "electra-base-generator":
-            "https://paddlenlp.bj.bcebos.com/models/transformers/electra-base-generator/vocab.txt",
-            "electra-large-generator":
-            "https://paddlenlp.bj.bcebos.com/models/transformers/electra-large-generator/vocab.txt",
-            "electra-small-discriminator":
-            "https://paddlenlp.bj.bcebos.com/models/transformers/electra-small-discriminator/vocab.txt",
-            "electra-base-discriminator":
-            "https://paddlenlp.bj.bcebos.com/models/transformers/electra-base-discriminator/vocab.txt",
-            "electra-large-discriminator":
-            "https://paddlenlp.bj.bcebos.com/models/transformers/electra-large-discriminator/vocab.txt",
-            "chinese-electra-discriminator-base":
-            "http://paddlenlp.bj.bcebos.com/models/transformers/chinese-electra-discriminator-base/vocab.txt",
-            "chinese-electra-discriminator-small":
-            "http://paddlenlp.bj.bcebos.com/models/transformers/chinese-electra-discriminator-small/vocab.txt",
+            "electra-small":
+            "https://paddlenlp.bj.bcebos.com/models/transformers/electra-small-vocab.txt",
+            "electra-base":
+            "https://paddlenlp.bj.bcebos.com/models/transformers/electra-base-vocab.txt",
+            "electra-large":
+            "https://paddlenlp.bj.bcebos.com/models/transformers/electra-large-vocab.txt",
+            "chinese-electra-base":
+            "http://paddlenlp.bj.bcebos.com/models/transformers/chinese-electra-base/vocab.txt",
+            "chinese-electra-small":
+            "http://paddlenlp.bj.bcebos.com/models/transformers/chinese-electra-small/vocab.txt",
        }
    }
    pretrained_init_configuration = {
-        "electra-small-generator": {
+        "electra-small": {
            "do_lower_case": True
        },
-        "electra-base-generator": {
+        "electra-base": {
            "do_lower_case": True
        },
-        "electra-large-generator": {
+        "electra-large": {
            "do_lower_case": True
        },
-        "electra-small-discriminator": {
+        "chinese-electra-base": {
            "do_lower_case": True
        },
-        "electra-base-discriminator": {
-            "do_lower_case": True
-        },
-        "electra-large-discriminator": {
-            "do_lower_case": True
-        },
-        "chinese-electra-discriminator-base": {
-            "do_lower_case": True
-        },
-        "chinese-electra-discriminator-small": {
+        "chinese-electra-small": {
            "do_lower_case": True
        }
    }
@@ -163,15 +148,12 @@ class ElectraTokenizer(PretrainedTokenizer):
    def num_special_tokens_to_add(self, pair=False):
        """
        Returns the number of added tokens when encoding a sequence with special tokens.
-
        Note:
            This encodes inputs and checks the number of added tokens, and is therefore not efficient. Do not put this
            inside your training loop.
-
        Args:
            pair: Returns the number of added tokens in the case of a sequence pair if set to True, returns the
                number of added tokens in the case of a single sequence if set to False.
-
        Returns:
            Number of tokens added to sequences
        """
@@ -190,13 +172,11 @@ class ElectraTokenizer(PretrainedTokenizer):
        ::
            - single sequence: ``[CLS] X [SEP]``
            - pair of sequences: ``[CLS] A [SEP] B [SEP]``
-
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.
-
        Returns:
            :obj:`List[int]`: List of input_id with the appropriate special tokens.
        """
@@ -211,21 +191,16 @@ class ElectraTokenizer(PretrainedTokenizer):
                                             token_ids_1=None):
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task.
-
        A BERT sequence pair mask has the following format:
-
        ::
-
            0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
            | first sequence | second sequence |
-
        If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s).
-
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.
        Returns:
            :obj:`List[int]`: List of token_type_id according to the given sequence(s).
        """
@@ -251,7 +226,6 @@ class ElectraTokenizer(PretrainedTokenizer):
        """
        Returns a dictionary containing the encoded sequence or sequence pair and additional information:
        the mask for sequence classification and the overflowing elements if a ``max_seq_len`` is specified.
-
        Args:
            text (:obj:`str`, :obj:`List[str]` or :obj:`List[int]`):
                The first sequence to be encoded. This can be a string, a list of strings (tokenized string using

@@ -270,7 +244,6 @@ class ElectraTokenizer(PretrainedTokenizer):
                model's max length.
            truncation_strategy (:obj:`str`, `optional`, defaults to `longest_first`):
                String selected in the following options:
-
                - 'longest_first' (default) Iteratively reduce the inputs sequence until the input is under max_seq_len
                  starting from the longest one at each token (when there is a pair of input sequences)
                - 'only_first': Only truncate the first sequence

@@ -288,10 +261,8 @@ class ElectraTokenizer(PretrainedTokenizer):
                Set to True to return overflowing token information (default False).
            return_special_tokens_mask (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Set to True to return special tokens mask information (default False).
-
        Return:
            A Dictionary of shape::
-
                {
                    input_ids: list[int],
                    position_ids: list[int] if return_position_ids is True (default)

@@ -302,9 +273,7 @@ class ElectraTokenizer(PretrainedTokenizer):
                    num_truncated_tokens: int if a ``max_seq_len`` is specified and return_overflowing_tokens is True
                    special_tokens_mask: list[int] if return_special_tokens_mask is True
                }
-
            With the fields:
-
            - ``input_ids``: list of token ids to be fed to a model
            - ``position_ids``: list of token position ids to be fed to a model
            - ``segment_ids``: list of token type ids to be fed to a model
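To tie the renamed checkpoints to the encode() dictionary described above, a hedged usage sketch (the sentence is invented, and it assumes the hosted vocab file for "electra-small" is reachable):

from paddlenlp.transformers import ElectraTokenizer

tokenizer = ElectraTokenizer.from_pretrained("electra-small")  # renamed key from this diff
encoded = tokenizer.encode("ELECTRA pretraining in PaddleNLP")
print(encoded["input_ids"])    # [CLS] ... [SEP] as vocabulary ids
print(encoded["segment_ids"])  # all 0s for a single sequence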