未验证 提交 f8365ada 编写于 作者: C Chang Xu 提交者: GitHub

Add Eval for NLP/HuggingFace Demo (#1204)

上级 508a3afb
......@@ -103,7 +103,7 @@ tar -zxvf afqmc.tar
#### 3.4 自动压缩并产出模型
自动压缩示例通过run.py脚本启动,会使用接口```paddleslim.auto_compression.AutoCompression```对模型进行自动压缩。配置config文件中训练部分的参数,将任务名称、模型类型、数据集名称、压缩参数传入,配置完成后便可对模型进行剪枝、蒸馏训练和离线量化。
数据集为CLUE,不同任务名称代表CLUE上不同的任务,可选择的任务名称有:afqmc, tnews, iflytek, ocnli, cmnli, cluewsc2020, csl。具体运行命令为
数据集为CLUE,不同任务名称代表CLUE上不同的任务,可选择的任务名称有:```afqmc, tnews, iflytek, ocnli, cmnli, cluewsc2020, csl```。具体运行命令为:
```shell
export CUDA_VISIBLE_DEVICES=0
......@@ -119,6 +119,8 @@ python run.py \
--task_name='afqmc' \
--config_path='./configs/pp-minilm/auto/afqmc.yaml'
```
如仅需验证模型精度,在启动```run.py```脚本时,命令加上```--eval=True```即可。
## 4. 压缩配置介绍
自动压缩需要准备config文件,并传入```config_path```字段,configs文件夹下可查看不同任务的配置文件,以下示例以afqmc数据集为例介绍。训练参数需要自行配置。蒸馏、剪枝和离线量化的相关配置,自动压缩策略可以自动获取得到,也可以自行配置。PaddleNLP模型的自动压缩实验默认使用剪枝、蒸馏和离线量化的策略。
......@@ -199,8 +201,7 @@ Quantization:
## 5. 预测部署
- [Paddle Inference Python部署](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.5/docs/deployment/inference/python_inference.md)
- [Paddle Inference C++部署](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.5/docs/deployment/inference/cpp_inference.md)
- [Paddle Lite部署](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.5/docs/deployment/lite/lite.md)
- [PP-MiniLM Paddle Inference Python部署](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/model_compression/pp-minilm)
- [ERNIE-3.0 Paddle Inference Python部署](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/model_zoo/ernie-3.0)
## 6. FAQ
......@@ -33,8 +33,10 @@ add_arg('dataset', str, None, "datset name.")
add_arg('save_dir', str, None, "directory to save compressed model.")
add_arg('max_seq_length', int, 128, "max sequence length after tokenization.")
add_arg('batch_size', int, 1, "train batch size.")
add_arg('task_name', str, 'sst-2', "task name in glue.")
add_arg('task_name', str, 'sst-2', "task name in glue.")
add_arg('config_path', str, None, "path of compression strategy config.")
add_arg('eval', bool, False, "whether validate the model only.")
# yapf: enable
METRIC_CLASSES = {
......@@ -226,6 +228,35 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
return res
def eval():
devices = paddle.device.get_device().split(':')[0]
places = paddle.device._convert_to_place(devices)
exe = paddle.static.Executor(places)
val_program, feed_target_names, fetch_targets = paddle.static.load_inference_model(
args.model_dir,
exe,
model_filename=args.model_filename,
params_filename=args.params_filename)
print('Loaded model from: {}'.format(args.model_dir))
metric.reset()
print('Evaluating...')
for data in eval_dataloader():
logits = exe.run(val_program,
feed={
feed_target_names[0]: data[0]['input_ids'],
feed_target_names[1]: data[0]['token_type_ids']
},
fetch_list=fetch_targets)
paddle.disable_static()
labels_pd = paddle.to_tensor(np.array(data[0]['label']).flatten())
logits_pd = paddle.to_tensor(logits[0])
correct = metric.compute(logits_pd, labels_pd)
metric.update(correct)
paddle.enable_static()
res = metric.accumulate()
return res
def apply_decay_param_fun(name):
if name.find("bias") > -1:
return True
......@@ -250,6 +281,11 @@ if __name__ == '__main__':
metric_class = METRIC_CLASSES[args.task_name]
metric = metric_class()
if args.eval:
result = eval()
print('Eval metric:', result)
sys.exit(0)
ac = AutoCompression(
model_dir=args.model_dir,
model_filename=args.model_filename,
......
export CUDA_VISIBLE_DEVICES=0
export FLAGS_cudnn_deterministic=True
python run.py \
--model_type='ppminilm' \
......
......@@ -15,8 +15,7 @@
飞桨模型转换工具[X2Paddle](https://github.com/PaddlePaddle/X2Paddle)支持将```Caffe/TensorFlow/ONNX/PyTorch```的模型一键转为飞桨(PaddlePaddle)的预测模型。借助X2Paddle的能力,PaddleSlim的自动压缩功能可方便地用于各种框架的推理模型。
本示例将以[PyTorch](https://github.com/pytorch/pytorch)框架的自然语言处理模型为例,介绍如何自动压缩其他框架中的自然语言处理模型。本示例会利用[huggingface](https://github.com/huggingface/transformers)开源transformers库,将PyTorch框架模型转换为Paddle框架模型,再使用ACT自动压缩功能进行自动压缩。本示例使用的自动压缩策略为剪枝蒸馏和离线量化(```Post-training quantization```)。
本示例将以[Pytorch](https://github.com/pytorch/pytorch)框架的自然语言处理模型为例,介绍如何自动压缩其他框架中的自然语言处理模型。本示例会利用[huggingface](https://github.com/huggingface/transformers)开源transformers库,将Pytorch框架模型转换为Paddle框架模型,再使用ACT自动压缩功能进行自动压缩。本示例使用的自动压缩策略为剪枝蒸馏和离线量化(```Post-training quantization```)。
......@@ -25,16 +24,16 @@
[BERT](https://arxiv.org/abs/1810.04805)```Bidirectional Encoder Representations from Transformers```)以Transformer 编码器为网络基本组件,使用掩码语言模型(```Masked Language Model```)和邻接句子预测(```Next Sentence Prediction```)两个任务在大规模无标注文本语料上进行预训练(pre-train),得到融合了双向内容的通用语义表示模型。以预训练产生的通用语义表示模型为基础,结合任务适配的简单输出层,微调(fine-tune)后即可应用到下游的NLP任务,效果通常也较直接在下游的任务上训练的模型更优。此前BERT即在[GLUE](https://gluebenchmark.com/tasks)评测任务上取得了SOTA的结果。
基于bert-base-cased模型,压缩前后的精度如下:
| 模型 | 策略 | CoLA | MRPC | QNLI | QQP | RTE | SST2 | AVG |
|:------:|:------:|:------:|:------:|:-----------:|:------:|:------:|:------:|:------:|
| bert-base-cased | Base模型| 60.06 | 84.31 | 90.68 | 90.84 | 63.53 | 91.63 | 80.17 |
| bert-base-cased |剪枝蒸馏+离线量化| 60.52 | 84.80 | 90.59 | 90.42 | 64.26 | 91.63 | 80.37 |
| 模型 | 策略 | CoLA | MRPC | QNLI | QQP | RTE | SST2 | STSB | AVG |
|:------:|:------:|:------:|:------:|:-----------:|:------:|:------:|:------:|:------:|:------:|
| bert-base-cased | Base模型| 60.06 | 84.31 | 90.68 | 90.84 | 63.53 | 91.63 | 88.46 | 81.35 |
| bert-base-cased |剪枝蒸馏+离线量化| 60.52 | 84.80 | 90.59 | 90.42 | 64.26 | 91.63 | 88.51 | 81.53 |
模型在多个任务上平均精度以及加速对比如下:
| bert-base-cased | Accuracy(avg) | 时延(ms) | 加速比 |
|:-------:|:----------:|:------------:| :------:|
| 压缩前 | 80.17 | 8.18 | - |
| 压缩后 | 80.37 | 6.35 | 28.82% |
| 压缩前 | 81.35 | 8.18 | - |
| 压缩后 | 81.53 | 6.35 | 1.29 |
- Nvidia GPU 测试环境:
- 硬件:NVIDIA Tesla T4 单卡
......@@ -88,8 +87,7 @@ pip install paddlenlp
#### 3.3 X2Paddle转换模型流程
**方式1: PyTorch2Paddle直接将PyTorch动态图模型转为Paddle静态图模型**
**方式1: PyTorch2Paddle直接将Pytorch动态图模型转为Paddle静态图模型**
```shell
import torch
......@@ -98,8 +96,8 @@ import numpy as np
torch_model.eval()
# 构建输入
input_ids = torch.unsqueeze(torch.tensor([0] * max_length), 0)
token_type_ids = torch.unsqueeze(torch.tensor([0] * max_length),0)
attention_msk = torch.unsqueeze(torch.tensor([0] * max_length),0)
token_type_ids = torch.unsqueeze(torch.tensor([0] * max_length), 0)
attention_msk = torch.unsqueeze(torch.tensor([0] * max_length), 0)
# 进行转换
from x2paddle.convert import pytorch2paddle
pytorch2paddle(torch_model,
......@@ -118,18 +116,17 @@ PyTorch2Paddle支持trace和script两种方式的转换,均是PyTorch动态图
- 使用PaddleNLP的tokenizer时需要在模型保存的文件夹中加入```model_config.json, special_tokens_map.json, tokenizer_config.json, vocab.txt```这些文件。
更多PyTorch2Paddle示例可参考[PyTorch模型转换文档](https://github.com/PaddlePaddle/X2Paddle/blob/develop/docs/inference_model_convertor/pytorch2paddle.md)。其他框架转换可参考[X2Paddle模型转换工具](https://github.com/PaddlePaddle/X2Paddle)
更多Pytorch2Paddle示例可参考[PyTorch模型转换文档](https://github.com/PaddlePaddle/X2Paddle/blob/develop/docs/inference_model_convertor/pytorch2paddle.md)。其他框架转换可参考[X2Paddle模型转换工具](https://github.com/PaddlePaddle/X2Paddle)
如想快速尝试运行实验,也可以直接下载已经转换好的模型,链接如下:
| [CoLA](https://paddle-slim-models.bj.bcebos.com/act/x2paddle_cola.tar) | [MRPC](https://paddle-slim-models.bj.bcebos.com/act/x2paddle_mrpc.tar) | [QNLI](https://paddle-slim-models.bj.bcebos.com/act/x2paddle_qnli.tar) | [QQP](https://paddle-slim-models.bj.bcebos.com/act/x2paddle_qqp.tar) | [RTE](https://paddle-slim-models.bj.bcebos.com/act/x2paddle_rte.tar) | [SST2](https://paddle-slim-models.bj.bcebos.com/act/x2paddle_sst2.tar) |
| [CoLA](https://paddle-slim-models.bj.bcebos.com/act/x2paddle_cola.tar) | [MRPC](https://paddle-slim-models.bj.bcebos.com/act/x2paddle_mrpc.tar) | [QNLI](https://paddle-slim-models.bj.bcebos.com/act/x2paddle_qnli.tar) | [QQP](https://paddle-slim-models.bj.bcebos.com/act/x2paddle_qqp.tar) | [RTE](https://paddle-slim-models.bj.bcebos.com/act/x2paddle_rte.tar) | [SST2](https://paddle-slim-models.bj.bcebos.com/act/x2paddle_sst2.tar) | [STSB](https://paddle-slim-models.bj.bcebos.com/act/x2paddle_stsb.tar) |
```shell
wget https://paddle-slim-models.bj.bcebos.com/act/x2paddle_cola.tar
tar xf x2paddle_cola.tar
```
**方式2: Onnx2Paddle将PyTorch动态图模型保存为Onnx格式后再转为Paddle静态图模型**
**方式2: Onnx2Paddle将Pytorch动态图模型保存为Onnx格式后再转为Paddle静态图模型**
PyTorch 导出 ONNX 动态图模型
......@@ -179,17 +176,39 @@ def main(x0, x1, x2):
#### 3.4 自动压缩并产出模型
以“cola”任务为例,在配置文件“./config/cola.yaml”中配置推理模型路径、压缩策略参数等信息,并通过“--config_path”将配置文件传给示例脚本"run.py"。
在“run.py”中,调用接口```paddleslim.auto_compression.AutoCompression```加载配置文件,并对推理模型进行自动压缩。
```cola```任务为例,在配置文件```./config/cola.yaml```中配置推理模型路径、压缩策略参数等信息,并通过```--config_path```将配置文件传给示例脚本```run.py```
```run.py```中,调用接口```paddleslim.auto_compression.AutoCompression```加载配置文件,使用以下命令对推理模型进行自动压缩:
```shell
export CUDA_VISIBLE_DEVICES=0
python run.py --config_path=./configs/cola.yaml --save_dir='./output/cola/'
```
如仅需验证模型精度,在启动```run.py```脚本时,命令加上```--eval True```即可:
```shell
export CUDA_VISIBLE_DEVICES=0
python run.py --config_path=./configs/cola.yaml --save_dir='./output/cola/' --eval True
```
## 4. 预测部署
- [Paddle Inference Python部署](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.5/docs/deployment/inference/python_inference.md)
- [Paddle Inference C++部署](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.5/docs/deployment/inference/cpp_inference.md)
- [Paddle Lite部署](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.5/docs/deployment/lite/lite.md)
准备好inference模型后,可以使用```infer.py```进行预测,比如:
```shell
python -u ./infer.py \
--task_name cola \
--model_name_or_path bert-base-cased \
--model_path ./x2paddle_cola/model \
--batch_size 1 \
--max_seq_length 128 \
--device gpu \
--use_trt \
```
除需传入```task_name```任务名称,```model_name_or_path```模型名称,```model_path```保存inference模型的路径等基本参数外,还需根据预测环境传入预测参数:
- ```device```:默认为gpu,可选为gpu, cpu, xpu
- ```use_trt```:是否使用 TesorRT 预测引擎
- ```int8```:是否启用```INT8```
- ```fp16```:是否启用```FP16```
## 5. FAQ
......@@ -43,8 +43,14 @@ def argsparser():
type=str,
default='output',
help="directory to save compressed model.")
parser.add_argument(
'--eval',
type=bool,
default=False,
help="whether validate the model only.")
return parser
METRIC_CLASSES = {
"cola": Mcc,
"sst-2": Accuracy,
......@@ -263,6 +269,42 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
return res[0] if isinstance(res, list) or isinstance(res, tuple) else res
def eval():
devices = paddle.device.get_device().split(':')[0]
places = paddle.device._convert_to_place(devices)
exe = paddle.static.Executor(places)
val_program, feed_target_names, fetch_targets = paddle.static.load_inference_model(
global_config["model_dir"],
exe,
model_filename=global_config["model_filename"],
params_filename=global_config["params_filename"])
print('Loaded model from: {}'.format(global_config["model_dir"]))
metric.reset()
print('Evaluating...')
for data in eval_dataloader():
logits = exe.run(val_program,
feed={
feed_target_names[0]: data[0]['x0'],
feed_target_names[1]: data[0]['x1'],
feed_target_names[2]: data[0]['x2']
},
fetch_list=fetch_targets)
paddle.disable_static()
if isinstance(metric, PearsonAndSpearman):
labels_pd = paddle.to_tensor(np.array(data[0]['label'])).reshape(
(-1, 1))
logits_pd = paddle.to_tensor(logits[0]).reshape((-1, 1))
metric.update((logits_pd, labels_pd))
else:
labels_pd = paddle.to_tensor(np.array(data[0]['label']).flatten())
logits_pd = paddle.to_tensor(logits[0])
correct = metric.compute(logits_pd, labels_pd)
metric.update(correct)
paddle.enable_static()
res = metric.accumulate()
return res[0] if isinstance(res, list) or isinstance(res, tuple) else res
def apply_decay_param_fun(name):
if name.find("bias") > -1:
return True
......@@ -292,6 +334,11 @@ def main():
metric_class = METRIC_CLASSES[global_config['task_name']]
metric = metric_class()
if args.eval:
result = eval()
print('Eval metric:', result)
sys.exit(0)
ac = AutoCompression(
model_dir=global_config['model_dir'],
model_filename=global_config['model_filename'],
......
export CUDA_VISIBLE_DEVICES=0
python run.py --config_path=./configs/cola.yaml --save_dir='./output/cola/'
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册