未验证 提交 c29f5ca1 编写于 作者: J jeff41404 提交者: GitHub

add grad clip to pretrain, modify glue readme and improve glue score … (#5093)

* add grad clip to pretrain, modify glue readme and improve glue score print

* add ernie-2.0 in run_glue.py and readme
上级 6ef01ddb
......@@ -480,11 +480,13 @@ def do_train(args):
float(num_training_steps - current_step) / float(
max(1, num_training_steps - num_warmup_steps))))
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
optimizer = paddle.optimizer.AdamW(
learning_rate=lr_scheduler,
epsilon=args.adam_epsilon,
parameters=model.parameters(),
weight_decay=args.weight_decay,
grad_clip=clip,
apply_decay_param_fun=lambda x: x in [
p.name for n, p in model.named_parameters()
if not any(nd in n for nd in ["bias", "norm"])
......
# GLUE with PaddleNLP
[GLUE](https://gluebenchmark.com/)是当今使用最为普遍的自然语言理解评测基准数据集,评测数据涵盖新闻、电影、百科等许多领域,其中有简单的句子,也有困难的句子。其目的是通过公开的得分榜,促进自然语言理解系统的发展。详细参考 [GLUE论文](https://openreview.net/pdf?id=rJ4km2R5t7)
[GLUE](https://gluebenchmark.com/)是当今使用最为普遍的自然语言理解评测基准数据集,评测数据涵盖新闻、电影、百科等许多领域,其中有简单的句子,也有困难的句子。其目的是通过公开的得分榜,促进自然语言理解系统的发展。详细参考 [GLUE论文](https://openreview.net/pdf?id=rJ4km2R5t7)
本项目是 GLUE评测任务 在 Paddle 2.0上的开源实现。
## 发布要点
## 1. 发布要点
1. 支持CoLA、SST-2、MRPC、STS-B、QQP、MNLI、QNLI、RTE 8个GLUE评测任务的Fine-tuning。
2. 支持 BERT、ELECTRA 等预训练模型运行这些GLUE评测任务。
## NLP 任务的 Fine-tuning
运行Fine-tuning有两种方式:
1. 使用已有的预训练模型运行 Fine-tuning。
2. 运行特定模型(如BERT、ELECTRA等)的预训练后,使用预训练模型运行 Fine-tuning(需要很多资源)。
## 2. 快速开始
以下例子基于方式1。
### 2.1 环境配置
- Python >= 3.6
- paddlepaddle >= 2.0.0rc1,安装方式请参考 [快速安装](https://www.paddlepaddle.org.cn/install/quick)
- paddlenlp >= 2.0.0b, 安装方式:`pip install paddlenlp>=2.0.0b`
### 语句和句对分类任务
以 GLUE/SST-2 任务为例,启动 Fine-tuning 的方式如下(`paddlenlp` 要已经安装或能在 `PYTHONPATH` 中找到):
### 2.2 启动GLUE任务
以 GLUE/SST-2 任务为例,启动GLUE任务进行Fine-tuning 的方式如下:
```shell
export CUDA_VISIBLE_DEVICES=0,1
......@@ -33,15 +32,30 @@ python -u ./run_glue.py \
--learning_rate 1e-4 \
--num_train_epochs 3 \
--logging_steps 1 \
--save_steps 500 \
--save_steps 100 \
--output_dir ./tmp/$TASK_NAME/ \
--n_gpu 1 \
```
其中参数释义如下:
- `model_type` 指示了模型类型,当前支持BERT、ELECTRA模型。
- `model_name_or_path` 指示了使用哪种预训练模型,对应有其预训练模型和预训练时使用的 tokenizer,当前支持bert-base-uncased、bert-large-uncased、bert-base-cased、bert-large-cased、bert-base-multilingual-uncased、bert-base-multilingual-cased、bert-base-chinese、bert-wwm-chinese、bert-wwm-ext-chinese、electra-small、electra-base、electra-large、chinese-electra-base、chinese-electra-small等模型。若模型相关内容保存在本地,这里也可以提供相应目录地址。
- `model_type` 指示了Fine-tuning使用的预训练模型类型,如:bert、electra、ernie等,因不同类型的预训练模型可能有不同的 Fine-tuning layer 和 tokenizer。
- `model_name_or_path` 指示了Fine-tuning使用的具体预训练模型,可以是PaddleNLP提供的预训练模型 或者 本地的预训练模型。如果使用本地的预训练模型,可以配置本地模型的目录地址,例如: /home/xx_model/,目录中需包含paddle预训练模型model_state.pdparams。如果使用PaddleNLP提供的预训练模型,可以选择下面某个,但是注意这里选择的模型要和上面配置的模型类型匹配,如:model_type 配置的是bert,则model_name_or_path只能选择bert相关的模型(下表中bert开头的9个)
| PaddleNLP提供的预训练模型 |
|---------------------------------|
| ernie-2.0-en |
| ernie-2.0-large-en |
| bert-base-uncased |
| bert-large-uncased |
| bert-base-cased |
| bert-large-cased |
| bert-base-multilingual-uncased |
| bert-base-multilingual-cased |
| electra-small |
| electra-base |
| electra-large |
- `task_name` 表示 Fine-tuning 的任务,当前支持CoLA、SST-2、MRPC、STS-B、QQP、MNLI、QNLI、RTE。
- `max_seq_length` 表示最大句子长度,超过该长度将被截断。
- `batch_size` 表示每次迭代**每张卡**上的样本数目。
......@@ -64,17 +78,24 @@ global step 6315/6315, epoch: 2, batch: 2104, rank_id: 0, loss: 0.046043, lr: 0.
eval loss: 0.549763, acc: 0.9151376146788991, eval done total : 1.8206987380981445 s
```
使用electra-small预训练模型进行单卡 Fine-tuning ,在验证集上有如下结果:
| Task | Metric | Result |
|-------|------------------------------|-------------|
| CoLA | Matthews corr | 58.22 |
| SST-2 | acc. | 91.85 |
| MRPC | acc./F1 | 88.24 |
| STS-B | Pearson/Spearman corr | 87.24 |
| QQP | acc./F1 | 88.83 |
| MNLI | matched acc./mismatched acc. | 82.45 |
| QNLI | acc. | 88.61 |
| RTE | acc. | 66.78 |
注:acc.是Accuracy的简称,表中Metric字段名词取自[GLUE论文](https://openreview.net/pdf?id=rJ4km2R5t7)
使用各种预训练模型进行 Fine-tuning ,在GLUE验证集上有如下结果:
| Model GLUE Score | CoLA | SST-2 | MRPC | STS-B | QQP | MNLI | QNLI | RTE |
|--------------------|-------|--------|--------|--------|--------|--------|--------|--------|
| electra-small | 58.22 | 91.85 | 88.24 | 87.24 | 88.83 | 82.45 | 88.61 | 66.78 |
| ernie-2.0-large-en | 65.4 | 96.0 | 88.7 | 92.3 | 92.5 | 89.1 | 94.3 | 85.2 |
关于GLUE Score的说明:
1. 因Fine-tuning过程中有dropout等随机因素影响,同样预训练模型每次运行的GLUE Score会有较小差异,上表中的GLUE Score是运行多次取eval最好值的得分。
2. 不同GLUE任务判定得分所使用的评价指标有些差异,简单如下表,详细说明可参考[GLUE论文](https://openreview.net/pdf?id=rJ4km2R5t7)
| GLUE Task | Metric |
|------------|------------------------------|
| CoLA | Matthews corr |
| SST-2 | acc. |
| MRPC | acc./F1 |
| STS-B | Pearson/Spearman corr |
| QQP | acc./F1 |
| MNLI | matched acc./mismatched acc. |
| QNLI | acc. |
| RTE | acc. |
......@@ -31,6 +31,7 @@ from paddlenlp.data import Stack, Tuple, Pad
from paddlenlp.data.sampler import SamplerHelper
from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer
from paddlenlp.transformers import ElectraForSequenceClassification, ElectraTokenizer
from paddlenlp.transformers import ErnieForSequenceClassification, ErnieTokenizer
from paddlenlp.metrics import AccuracyAndF1, Mcc, PearsonAndSpearman
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
......@@ -51,6 +52,7 @@ TASK_CLASSES = {
MODEL_CLASSES = {
"bert": (BertForSequenceClassification, BertTokenizer),
"electra": (ElectraForSequenceClassification, ElectraTokenizer),
"ernie": (ErnieForSequenceClassification, ErnieTokenizer),
}
......@@ -174,8 +176,27 @@ def evaluate(model, loss_fct, metric, data_loader):
loss = loss_fct(logits, labels)
correct = metric.compute(logits, labels)
metric.update(correct)
acc = metric.accumulate()
print("eval loss: %f, acc: %s, " % (loss.numpy(), acc), end='')
res = metric.accumulate()
if isinstance(metric, AccuracyAndF1):
print(
"eval loss: %f, acc: %s, precision: %s, recall: %s, f1: %s, acc and f1: %s, "
% (
loss.numpy(),
res[0],
res[1],
res[2],
res[3],
res[4], ),
end='')
elif isinstance(metric, Mcc):
print("eval loss: %f, mcc: %s, " % (loss.numpy(), res[0]), end='')
elif isinstance(metric, PearsonAndSpearman):
print(
"eval loss: %f, pearson: %s, spearman: %s, pearson and spearman: %s, "
% (loss.numpy(), res[0], res[1], res[2]),
end='')
else:
print("eval loss: %f, acc: %s, " % (loss.numpy(), res), end='')
model.train()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册