Commit b3499dc9 (unverified) authored by Chang Xu, committed by GitHub

Add X2Paddle & NLP Demo (#1190)

Parent: d69d9822
# HuggingFace Pretrained Model Compression and Deployment Example
Contents:
- [1. Introduction](#1-introduction)
- [2. Benchmark](#2-benchmark)
- [3. Auto-Compression Workflow](#3-auto-compression-workflow)
  - [3.1 Environment Setup](#31-environment-setup)
  - [3.2 Dataset Preparation](#32-dataset-preparation)
  - [3.3 Converting the Model with X2Paddle](#33-converting-the-model-with-x2paddle)
  - [3.4 Running Auto-Compression](#34-running-auto-compression)
- [4. Deployment](#4-deployment)
- [5. FAQ](#5-faq)
## 1. Introduction
The PaddlePaddle model conversion tool [X2Paddle](https://github.com/PaddlePaddle/X2Paddle) converts ```Caffe/TensorFlow/ONNX/PyTorch``` models into PaddlePaddle inference models in one step. With X2Paddle, PaddleSlim's auto-compression capability (ACT) can be applied to inference models from any of these frameworks.
This example uses an NLP model from the [PyTorch](https://github.com/pytorch/pytorch) framework to show how to auto-compress NLP models built in other frameworks. It relies on the open-source [huggingface](https://github.com/huggingface/transformers) transformers library: the PyTorch model is first converted into a Paddle model, which is then compressed with ACT. The compression strategies used here are pruning with distillation and post-training quantization (```Post-training quantization```).
## 2. Benchmark
[BERT](https://arxiv.org/abs/1810.04805) (```Bidirectional Encoder Representations from Transformers```) uses the Transformer encoder as its basic building block. It is pre-trained on large-scale unlabeled text with two objectives, masked language modeling (```Masked Language Model```) and next sentence prediction (```Next Sentence Prediction```), producing a general-purpose semantic representation that fuses bidirectional context. Starting from this pre-trained representation and adding a simple task-specific output layer, fine-tuning adapts the model to downstream NLP tasks, and usually outperforms models trained directly on the downstream task. BERT achieved SOTA results on the [GLUE](https://gluebenchmark.com/tasks) benchmark.
Accuracy of the bert-base-cased model before and after compression:
| Model | Strategy | CoLA | MRPC | QNLI | QQP | RTE | SST2 | AVG |
|:------:|:------:|:------:|:------:|:-----------:|:------:|:------:|:------:|:------:|
| bert-base-cased | Baseline | 60.06 | 84.31 | 90.68 | 90.84 | 63.53 | 91.63 | 80.17 |
| bert-base-cased | Prune-distill + post-training quantization | 60.52 | 84.80 | 90.59 | 90.42 | 64.26 | 91.63 | 80.37 |
Average accuracy across tasks and the corresponding speedup:
| bert-base-cased | Accuracy (avg) | Latency (ms) | Speedup |
|:-------:|:----------:|:------------:| :------:|
| Before compression | 80.17 | 8.18 | - |
| After compression | 80.37 | 6.35 | 28.82% |
- NVIDIA GPU test environment:
  - Hardware: a single NVIDIA Tesla T4
  - Software: CUDA 11.2, cuDNN 8.0, TensorRT 8.4
  - Test configuration: batch_size: 1, sequence length: 128
## 3. Auto-Compression Workflow
#### 3.1 Environment Setup
- python >= 3.6
- PaddlePaddle >= 2.3 (install from the [PaddlePaddle site](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html))
- PaddleSlim develop branch, or PaddleSlim >= 2.3.0
- X2Paddle develop branch
- PaddleNLP >= 2.3
- tensorflow == 1.14 (only if compressing TensorFlow models)
- onnx >= 1.6.0 (only if compressing ONNX models)
- torch >= 1.5.0 (only if compressing PyTorch models)
Install PaddlePaddle:
```shell
# CPU
pip install paddlepaddle
# GPU
pip install paddlepaddle-gpu
```
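Optionally, verify the installation with Paddle's built-in self-test (a quick check, not required by the rest of the demo):
```python
import paddle

# Runs a small sanity check and reports whether PaddlePaddle is
# installed correctly and whether GPU devices are usable.
paddle.utils.run_check()
```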
Install PaddleSlim:
```shell
git clone https://github.com/PaddlePaddle/PaddleSlim.git
cd PaddleSlim
python setup.py install
```
Install X2Paddle:
```shell
git clone https://github.com/PaddlePaddle/X2Paddle.git
cd X2Paddle
git checkout develop
python setup.py install
```
Install PaddleNLP:
```shell
pip install paddlenlp
```
Note: PaddleNLP is installed only to download the datasets and tokenizers used in this example.
#### 3.2 Dataset Preparation
This demo runs auto-compression experiments on the GLUE datasets by default; PaddleNLP downloads the required subtask automatically, as the sketch below shows.
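A minimal sketch (using the CoLA subtask, as in the rest of this demo):
```python
from paddlenlp.datasets import load_dataset

# PaddleNLP downloads and caches the GLUE subtask on first use.
train_ds, dev_ds = load_dataset('glue', 'cola', splits=['train', 'dev'])
print(len(train_ds), train_ds[0])  # sample count and one raw example
```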
#### 3.3 Converting the Model with X2Paddle
**Option 1: use PyTorch2Paddle to convert the PyTorch dynamic-graph model directly into a Paddle static-graph model**
```python
import torch
from x2paddle.convert import pytorch2paddle

# torch_model is the fine-tuned HuggingFace model; max_length must match
# the tokenizer's pad_to_max_length setting.
torch_model.eval()
# Build dummy inputs of shape [1, max_length].
input_ids = torch.unsqueeze(torch.tensor([0] * max_length), 0)
token_type_ids = torch.unsqueeze(torch.tensor([0] * max_length), 0)
attention_mask = torch.unsqueeze(torch.tensor([0] * max_length), 0)
# Trace the model and export it as a Paddle model.
pytorch2paddle(torch_model,
               save_dir='./x2paddle_cola/',
               jit_type="trace",
               input_examples=[input_ids, attention_mask, token_type_ids])
```
PyTorch2Paddle supports both trace and script conversion. Each converts a PyTorch dynamic graph into a Paddle dynamic graph, which can then be turned into a static-graph model via dynamic-to-static conversion.
- With jit_type set to "trace", input_examples must not be None; dynamic-to-static conversion runs automatically and the input shapes are fixed.
- With jit_type set to "script", only dynamic-graph code is generated when input_examples is None; dynamic-to-static conversion runs automatically only when input_examples is provided.
Notes:
- Auto-compression operates on static-graph models, so ```jit_type``` must be set to ```trace```. Also set ```pad_to_max_length``` in the PyTorch tokenizer, and make sure ```max_length``` matches the dummy inputs built for conversion (see the tokenizer sketch below).
- HuggingFace models take ```attention_mask``` as an input by default, while PaddleNLP does not; keep the two sides consistent by setting ```return_attention_mask=True``` in PaddleNLP.
- To use PaddleNLP's tokenizer, place ```model_config.json, special_tokens_map.json, tokenizer_config.json, vocab.txt``` in the folder where the model is saved.
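For reference, a tokenizer call on the PyTorch side consistent with the settings above might look like this (a sketch only; ```padding='max_length'``` is the current transformers equivalent of the deprecated ```pad_to_max_length``` flag):
```python
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
# Pad every sequence to the same max_length used when tracing the model.
encoded = tokenizer('An example sentence.',
                    padding='max_length',
                    max_length=128,
                    truncation=True,
                    return_tensors='pt')
# input_ids / attention_mask / token_type_ids line up with the conversion inputs.
print(encoded['input_ids'].shape)  # torch.Size([1, 128])
```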
More PyTorch2Paddle examples are available in the [PyTorch conversion docs](https://github.com/PaddlePaddle/X2Paddle/blob/develop/docs/inference_model_convertor/pytorch2paddle.md). For other frameworks, see the [X2Paddle model conversion tool](https://github.com/PaddlePaddle/X2Paddle).
To try the experiment quickly, you can also download models that have already been converted:
| CoLA | MRPC | QNLI | QQP | RTE | SST2 |
|:------:|:------:|:------:|:------:|:------:|:------:|
| [download](https://paddle-slim-models.bj.bcebos.com/act/x2paddle_cola.tar) | [download](https://paddle-slim-models.bj.bcebos.com/act/x2paddle_mrpc.tar) | [download](https://paddle-slim-models.bj.bcebos.com/act/x2paddle_qnli.tar) | [download](https://paddle-slim-models.bj.bcebos.com/act/x2paddle_qqp.tar) | [download](https://paddle-slim-models.bj.bcebos.com/act/x2paddle_rte.tar) | [download](https://paddle-slim-models.bj.bcebos.com/act/x2paddle_sst2.tar) |
```shell
wget https://paddle-slim-models.bj.bcebos.com/act/x2paddle_cola.tar
tar xf x2paddle_cola.tar
```
**Option 2: export the PyTorch dynamic-graph model to ONNX, then use Onnx2Paddle to convert it into a Paddle static-graph model**
Export the ONNX model from PyTorch:
```python
import torch

# torch_model is the fine-tuned HuggingFace model; args.max_length must match
# the tokenizer's pad_to_max_length setting.
torch_model.eval()
input_ids = torch.unsqueeze(torch.tensor([0] * args.max_length), 0)
token_type_ids = torch.unsqueeze(torch.tensor([0] * args.max_length), 0)
attention_mask = torch.unsqueeze(torch.tensor([0] * args.max_length), 0)
input_names = ['input_ids', 'attention_mask', 'token_type_ids']
output_names = ['output']
# Export with a dynamic batch dimension (axis 0) for every input.
torch.onnx.export(
    torch_model,
    (input_ids, attention_mask, token_type_ids),
    'model.onnx',
    opset_version=11,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes={'input_ids': [0], 'attention_mask': [0], 'token_type_ids': [0]})
```
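Optionally, sanity-check the exported file before converting it (a small sketch; uses the onnx package listed in section 3.1):
```python
import onnx

onnx_model = onnx.load('model.onnx')
onnx.checker.check_model(onnx_model)  # raises if the graph is malformed
print([inp.name for inp in onnx_model.graph.input])
```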
Export the Paddle model with the X2Paddle command-line tool:
```shell
x2paddle --framework=onnx --model=model.onnx --save_dir=pd_model_dynamic
```
Add the following code to the auto-generated x2paddle_code.py:
```python
def main(x0, x1, x2):
    # x0, x1, x2 are the model inputs.
    paddle.disable_static()
    params = paddle.load('model.pdparams')
    model = BertForSequenceClassification()
    model.set_dict(params)
    model.eval()
    # Convert to a static graph and save it.
    spec_list = [
        paddle.static.InputSpec(
            shape=[-1, 128], name="x0", dtype="int64"),
        paddle.static.InputSpec(
            shape=[-1, 128], name="x1", dtype="int64"),
        paddle.static.InputSpec(
            shape=[-1, 128], name="x2", dtype="int64")
    ]
    static_model = paddle.jit.to_static(model, input_spec=spec_list)
    paddle.jit.save(static_model, "./x2paddle_cola")
```
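Calling the added function once produces the static-graph files. A hypothetical driver, assuming the generated x2paddle_code.py is importable from the working directory:
```python
import numpy as np
from x2paddle_code import main  # generated module with the main() added above

# Dummy inputs matching the [-1, 128] InputSpec shapes.
x = np.zeros([1, 128], dtype='int64')
main(x, x, x)  # writes ./x2paddle_cola.pdmodel and ./x2paddle_cola.pdiparams
```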
#### 3.4 Running Auto-Compression
Taking the CoLA task as an example, set the inference model path, compression strategy parameters, and other options in the config file "./configs/cola.yaml", then pass the file to the demo script "run.py" through "--config_path".
"run.py" calls the ```paddleslim.auto_compression.AutoCompression``` interface to load the config file and auto-compress the inference model.
```shell
export CUDA_VISIBLE_DEVICES=0
python run.py --config_path=./configs/cola.yaml --save_dir='./output/cola/'
```
## 4. Deployment
- [Paddle Inference Python deployment](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.5/docs/deployment/inference/python_inference.md)
- [Paddle Inference C++ deployment](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.5/docs/deployment/inference/cpp_inference.md)
- [Paddle Lite deployment](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.5/docs/deployment/lite/lite.md)
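As a quick start, the compressed model can also be loaded directly with the Paddle Inference Python API. A minimal sketch, assuming the run.py command from section 3.4 saved model.pdmodel/model.pdiparams under ./output/cola/ and that real inputs come from the tokenizer:
```python
import numpy as np
from paddle import inference

# Load the compressed inference model.
config = inference.Config('./output/cola/model.pdmodel',
                          './output/cola/model.pdiparams')
predictor = inference.create_predictor(config)

# Feed dummy int64 inputs of shape [1, 128]; replace with tokenizer output.
for name in predictor.get_input_names():
    predictor.get_input_handle(name).copy_from_cpu(
        np.zeros([1, 128], dtype='int64'))
predictor.run()
out_name = predictor.get_output_names()[0]
logits = predictor.get_output_handle(out_name).copy_to_cpu()
print(logits.shape)
```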
## 5. FAQ
# configs/cola.yaml
Global:
  input_names: ['x0', 'x1', 'x2']
  model_dir: ./x2paddle_cola
  model_filename: model.pdmodel
  params_filename: model.pdiparams
  model_type: bert-base-cased
  task_name: cola
  dataset: glue
  batch_size: 1
  max_seq_length: 128
  padding: max_length
  return_attention_mask: True
TrainConfig:
  epochs: 3
  eval_iter: 855
  learning_rate: 1.0e-6
  optimizer_builder:
    optimizer:
      type: AdamW
      weight_decay: 0.01
  origin_metric: 0.6006
# configs/mnli.yaml
Global:
  input_names: ['x0', 'x1', 'x2']
  model_dir: ./x2paddle_mnli
  model_filename: model.pdmodel
  params_filename: model.pdiparams
  model_type: bert-base-cased
  task_name: mnli
  dataset: glue
  batch_size: 1
  max_seq_length: 128
  padding: max_length
  return_attention_mask: True
TrainConfig:
  epochs: 3
  eval_iter: 1710
  learning_rate: 1.0e-6
  optimizer_builder:
    optimizer:
      type: AdamW
      weight_decay: 0.01
  origin_metric: 0.8318
# configs/mrpc.yaml
Global:
  input_names: ['x0', 'x1', 'x2']
  model_dir: ./x2paddle_mrpc
  model_filename: model.pdmodel
  params_filename: model.pdiparams
  model_type: bert-base-cased
  task_name: mrpc
  dataset: glue
  batch_size: 1
  max_seq_length: 128
  padding: max_length
  return_attention_mask: True
TrainConfig:
  epochs: 3
  eval_iter: 915
  learning_rate: 1.0e-6
  optimizer_builder:
    optimizer:
      type: AdamW
      weight_decay: 0.01
  origin_metric: 0.8431
# configs/qnli.yaml
Global:
  input_names: ['x0', 'x1', 'x2']
  model_dir: ./x2paddle_qnli
  model_filename: model.pdmodel
  params_filename: model.pdiparams
  model_type: bert-base-cased
  task_name: qnli
  dataset: glue
  batch_size: 1
  max_seq_length: 128
  padding: max_length
  return_attention_mask: True
TrainConfig:
  epochs: 3
  eval_iter: 855
  learning_rate: 1.0e-6
  optimizer_builder:
    optimizer:
      type: AdamW
      weight_decay: 0.01
  origin_metric: 0.9068
# configs/qqp.yaml
Global:
  input_names: ['x0', 'x1', 'x2']
  model_dir: ./x2paddle_qqp
  model_filename: model.pdmodel
  params_filename: model.pdiparams
  model_type: bert-base-cased
  task_name: qqp
  dataset: glue
  batch_size: 1
  max_seq_length: 128
  padding: max_length
  return_attention_mask: True
TrainConfig:
  epochs: 3
  eval_iter: 855
  learning_rate: 1.0e-6
  optimizer_builder:
    optimizer:
      type: AdamW
      weight_decay: 0.01
  origin_metric: 0.9084
# configs/rte.yaml
Global:
  input_names: ['x0', 'x1', 'x2']
  model_dir: ./x2paddle_rte
  model_filename: model.pdmodel
  params_filename: model.pdiparams
  model_type: bert-base-cased
  task_name: rte
  dataset: glue
  batch_size: 1
  max_seq_length: 128
  padding: max_length
  return_attention_mask: True
TrainConfig:
  epochs: 3
  eval_iter: 1240
  learning_rate: 1.0e-6
  optimizer_builder:
    optimizer:
      type: AdamW
      weight_decay: 0.01
  origin_metric: 0.6353
# configs/sst2.yaml
Global:
  input_names: ['x0', 'x1', 'x2']
  model_dir: ./x2paddle_sst2
  model_filename: model.pdmodel
  params_filename: model.pdiparams
  model_type: bert-base-cased
  task_name: sst-2
  dataset: glue
  batch_size: 1
  max_seq_length: 128
  padding: max_length
  return_attention_mask: True
TrainConfig:
  epochs: 3
  eval_iter: 3367
  learning_rate: 1.0e-6
  optimizer_builder:
    optimizer:
      type: AdamW
      weight_decay: 0.01
  origin_metric: 0.9163
# configs/stsb.yaml
Global:
  input_names: ['x0', 'x1', 'x2']
  model_dir: ./x2paddle_stsb
  model_filename: model.pdmodel
  params_filename: model.pdiparams
  model_type: bert-base-cased
  task_name: sts-b
  dataset: glue
  batch_size: 1
  max_seq_length: 128
  padding: max_length
  return_attention_mask: True
TrainConfig:
  epochs: 3
  eval_iter: 1710
  learning_rate: 1.0e-6
  optimizer_builder:
    optimizer:
      type: AdamW
      weight_decay: 0.01
  origin_metric: 0.8846
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import time
import sys
from functools import partial
import distutils.util
import numpy as np
import paddle
from paddle import inference
from paddlenlp.datasets import load_dataset
from paddlenlp.data import Stack, Tuple, Pad
from paddle.metric import Metric, Accuracy, Precision, Recall
from paddlenlp.metrics import AccuracyAndF1, Mcc, PearsonAndSpearman
from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer
METRIC_CLASSES = {
"cola": Mcc,
"sst-2": Accuracy,
"mrpc": AccuracyAndF1,
"sts-b": PearsonAndSpearman,
"qqp": AccuracyAndF1,
"mnli": Accuracy,
"qnli": Accuracy,
"rte": Accuracy,
}
task_to_keys = {
"cola": ("sentence", None),
"mnli": ("sentence1", "sentence2"),
"mrpc": ("sentence1", "sentence2"),
"qnli": ("sentence1", "sentence2"),
"qqp": ("sentence1", "sentence2"),
"rte": ("sentence1", "sentence2"),
"sst-2": ("sentence", None),
"sts-b": ("sentence1", "sentence2"),
"wnli": ("sentence1", "sentence2"),
}
def parse_args():
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--task_name",
default='cola',
type=str,
help="The name of the task to perform predict, selected in the list: " +
", ".join(METRIC_CLASSES.keys()), )
parser.add_argument(
"--model_type",
default='bert-base-cased',
type=str,
help="Model type selected in bert.")
parser.add_argument(
"--model_name_or_path",
default='bert-base-cased',
type=str,
help="The directory or name of model.", )
parser.add_argument(
"--model_path",
default='./quant_models/model',
type=str,
required=True,
help="The path prefix of inference model to be used.", )
parser.add_argument(
"--device",
default="gpu",
choices=["gpu", "cpu", "xpu"],
help="Device selected for inference.", )
parser.add_argument(
"--batch_size",
default=32,
type=int,
help="Batch size for predict.", )
parser.add_argument(
"--max_seq_length",
default=128,
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.", )
    parser.add_argument(
        "--padding",
        default='max_length',
        type=str,
        help="Padding type.", )
parser.add_argument(
"--perf_warmup_steps",
default=20,
type=int,
help="Warmup steps for performance test.", )
    parser.add_argument(
        "--use_trt",
        action='store_true',
        help="Whether to use the TensorRT inference engine.", )
    parser.add_argument(
        "--perf",
        action='store_false',
        help="Performance test runs by default; pass this flag to disable it "
        "and run accuracy evaluation instead.", )
    parser.add_argument(
        "--int8",
        action='store_true',
        help="Whether to use int8 inference.", )
    parser.add_argument(
        "--fp16",
        action='store_true',
        help="Whether to use fp16 inference.", )
args = parser.parse_args()
return args
@paddle.no_grad()
def evaluate(outputs, metric, data_loader):
metric.reset()
for i, batch in enumerate(data_loader):
input_ids, segment_ids, labels = batch
logits = paddle.to_tensor(outputs[i][0])
correct = metric.compute(logits, labels)
metric.update(correct)
res = metric.accumulate()
print("acc: %s, " % res, end='')
def convert_example(example,
tokenizer,
label_list,
max_seq_length=512,
task_name=None,
is_test=False,
padding='max_length',
return_attention_mask=True):
if not is_test:
# `label_list == None` is for regression task
label_dtype = "int64" if label_list else "float32"
# Get the label
label = example['labels']
label = np.array([label], dtype=label_dtype)
# Convert raw text to feature
sentence1_key, sentence2_key = task_to_keys[task_name]
texts = ((example[sentence1_key], ) if sentence2_key is None else
(example[sentence1_key], example[sentence2_key]))
example = tokenizer(
*texts,
max_seq_len=max_seq_length,
padding=padding,
return_attention_mask=return_attention_mask)
if not is_test:
if return_attention_mask:
return example['input_ids'], example['attention_mask'], example[
'token_type_ids'], label
else:
return example['input_ids'], example['token_type_ids'], label
else:
if return_attention_mask:
return example['input_ids'], example['attention_mask'], example[
'token_type_ids']
else:
return example['input_ids'], example['token_type_ids']
class Predictor(object):
def __init__(self, predictor, input_handles, output_handles):
self.predictor = predictor
self.input_handles = input_handles
self.output_handles = output_handles
@classmethod
def create_predictor(cls, args):
config = paddle.inference.Config(args.model_path + ".pdmodel",
args.model_path + ".pdiparams")
if args.device == "gpu":
# set GPU configs accordingly
config.enable_use_gpu(100, 0)
cls.device = paddle.set_device("gpu")
elif args.device == "cpu":
# set CPU configs accordingly,
# such as enable_mkldnn, set_cpu_math_library_num_threads
config.disable_gpu()
cls.device = paddle.set_device("cpu")
elif args.device == "xpu":
# set XPU configs accordingly
config.enable_xpu(100)
if args.use_trt:
if args.int8:
config.enable_tensorrt_engine(
workspace_size=1 << 30,
precision_mode=inference.PrecisionType.Int8,
max_batch_size=args.batch_size,
min_subgraph_size=5,
use_static=False,
use_calib_mode=False)
elif args.fp16:
config.enable_tensorrt_engine(
workspace_size=1 << 30,
precision_mode=inference.PrecisionType.Half,
max_batch_size=args.batch_size,
min_subgraph_size=5,
use_static=False,
use_calib_mode=False)
else:
config.enable_tensorrt_engine(
workspace_size=1 << 30,
precision_mode=inference.PrecisionType.Float32,
max_batch_size=args.batch_size,
min_subgraph_size=5,
use_static=False,
use_calib_mode=False)
print("Enable TensorRT is: {}".format(
config.tensorrt_engine_enabled()))
# Set min/max/opt tensor shape of each trt subgraph input according
# to dataset.
# For example, the config of TNEWS data should be 16, 32, 32, 31, 128, 32.
predictor = paddle.inference.create_predictor(config)
input_handles = [
predictor.get_input_handle(name)
for name in predictor.get_input_names()
]
output_handles = [
predictor.get_output_handle(name)
for name in predictor.get_output_names()
]
return cls(predictor, input_handles, output_handles)
def predict_batch(self, data):
for input_field, input_handle in zip(data, self.input_handles):
input_handle.copy_from_cpu(input_field)
self.predictor.run()
output = [
output_handle.copy_to_cpu() for output_handle in self.output_handles
]
return output
def convert_predict_batch(self, args, data, tokenizer, batchify_fn,
label_list):
examples = []
for example in data:
example = convert_example(
example,
tokenizer,
label_list,
task_name=args.task_name,
max_seq_length=args.max_seq_length,
padding='max_length',
return_attention_mask=True)
examples.append(example)
return examples
def predict(self, dataset, tokenizer, batchify_fn, args):
batches = [
dataset[idx:idx + args.batch_size]
for idx in range(0, len(dataset), args.batch_size)
]
if args.perf:
for i, batch in enumerate(batches):
examples = self.convert_predict_batch(
args, batch, tokenizer, batchify_fn, dataset.label_list)
input_ids, atten_mask, segment_ids, label = batchify_fn(
examples)
output = self.predict_batch(
[input_ids, atten_mask, segment_ids])
if i > args.perf_warmup_steps:
break
time1 = time.time()
for batch in batches:
examples = self.convert_predict_batch(
args, batch, tokenizer, batchify_fn, dataset.label_list)
input_ids, atten_mask, segment_ids, _ = batchify_fn(examples)
output = self.predict_batch(
[input_ids, atten_mask, segment_ids])
print("task name: %s, time: %s, " %
(args.task_name, time.time() - time1))
else:
metric = METRIC_CLASSES[args.task_name]()
metric.reset()
for i, batch in enumerate(batches):
examples = self.convert_predict_batch(
args, batch, tokenizer, batchify_fn, dataset.label_list)
input_ids, atten_mask, segment_ids, label = batchify_fn(
examples)
output = self.predict_batch(
[input_ids, atten_mask, segment_ids])
correct = metric.compute(
paddle.to_tensor(output), paddle.to_tensor(label))
metric.update(correct)
res = metric.accumulate()
print("task name: %s, acc: %s, " % (args.task_name, res), end='')
def main():
paddle.seed(42)
args = parse_args()
args.task_name = args.task_name.lower()
args.model_type = args.model_type.lower()
predictor = Predictor.create_predictor(args)
dev_ds = load_dataset('glue', args.task_name, splits='dev')
tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path)
    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input_ids
        Pad(axis=0, pad_val=0),  # attention_mask
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # token_type_ids
        Stack(dtype="int64" if dev_ds.label_list else "float32")  # label
    ): fn(samples)
outputs = predictor.predict(dev_ds, tokenizer, batchify_fn, args)
if __name__ == "__main__":
main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import numpy as np
import argparse
import paddle
import paddle.nn as nn
import functools
from functools import partial
from paddle.io import Dataset, BatchSampler, DataLoader
from paddle.metric import Metric, Accuracy
from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer
from paddlenlp.datasets import load_dataset
from paddlenlp.data import Stack, Tuple, Pad
from paddlenlp.metrics import AccuracyAndF1, Mcc, PearsonAndSpearman
from paddleslim.auto_compression.config_helpers import load_config as load_slim_config
from paddleslim.auto_compression.compressor import AutoCompression
def argsparser():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--config_path',
type=str,
default=None,
help="path of compression strategy config.",
required=True)
parser.add_argument(
'--save_dir',
type=str,
default='output',
help="directory to save compressed model.")
return parser
def print_arguments(args):
print('----------- Running Arguments -----------')
for arg, value in sorted(vars(args).items()):
print('%s: %s' % (arg, value))
print('------------------------------------------')
METRIC_CLASSES = {
"cola": Mcc,
"sst-2": Accuracy,
"mrpc": AccuracyAndF1,
"sts-b": PearsonAndSpearman,
"qqp": AccuracyAndF1,
"mnli": Accuracy,
"qnli": Accuracy,
"rte": Accuracy,
}
task_to_keys = {
"cola": ("sentence", None),
"mnli": ("sentence1", "sentence2"),
"mrpc": ("sentence1", "sentence2"),
"qnli": ("sentence1", "sentence2"),
"qqp": ("sentence1", "sentence2"),
"rte": ("sentence1", "sentence2"),
"sst-2": ("sentence", None),
"sts-b": ("sentence1", "sentence2"),
"wnli": ("sentence1", "sentence2"),
}
def convert_example(example,
tokenizer,
label_list,
max_seq_length=512,
is_test=False,
padding='max_length',
return_attention_mask=True):
if not is_test:
# `label_list == None` is for regression task
label_dtype = "int64" if label_list else "float32"
# Get the label
label = example['labels']
label = np.array([label], dtype=label_dtype)
# Convert raw text to feature
sentence1_key, sentence2_key = task_to_keys[global_config['task_name']]
texts = ((example[sentence1_key], ) if sentence2_key is None else
(example[sentence1_key], example[sentence2_key]))
example = tokenizer(
*texts,
max_seq_len=max_seq_length,
padding=padding,
return_attention_mask=return_attention_mask,
truncation='longest_first')
if not is_test:
if return_attention_mask:
return example['input_ids'], example['attention_mask'], example[
'token_type_ids'], label
else:
return example['input_ids'], example['token_type_ids'], label
else:
if return_attention_mask:
return example['input_ids'], example['attention_mask'], example[
'token_type_ids']
else:
return example['input_ids'], example['token_type_ids']
def create_data_holder(task_name, input_names):
"""
Define the input data holder for the glue task.
"""
inputs = []
for name in input_names:
inputs.append(
paddle.static.data(
name=name, shape=[-1, -1], dtype="int64"))
if task_name == "sts-b":
inputs.append(
paddle.static.data(
name="label", shape=[-1, 1], dtype="float32"))
else:
inputs.append(
paddle.static.data(
name="label", shape=[-1, 1], dtype="int64"))
return inputs
def reader():
# Create the tokenizer and dataset
tokenizer = BertTokenizer.from_pretrained(global_config['model_dir'])
train_ds = load_dataset(
global_config['dataset'], global_config['task_name'], splits="train")
trans_func = partial(
convert_example,
tokenizer=tokenizer,
label_list=train_ds.label_list,
max_seq_length=global_config['max_seq_length'],
is_test=True,
padding=global_config['padding'],
return_attention_mask=global_config['return_attention_mask'])
train_ds = train_ds.map(trans_func, lazy=True)
if global_config['return_attention_mask']:
batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=tokenizer.pad_token_id), # input
Pad(axis=0, pad_val=0), # attention_mask
Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # token_type
): fn(samples)
else:
batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=tokenizer.pad_token_id), # input
Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # token_type
): fn(samples)
train_batch_sampler = paddle.io.BatchSampler(
train_ds, batch_size=global_config['batch_size'], shuffle=True)
feed_list = create_data_holder(global_config['task_name'],
global_config['input_names'])
train_data_loader = DataLoader(
dataset=train_ds,
feed_list=feed_list[:-1],
batch_sampler=train_batch_sampler,
collate_fn=batchify_fn,
num_workers=0,
return_list=False)
dev_trans_func = partial(
convert_example,
tokenizer=tokenizer,
label_list=train_ds.label_list,
max_seq_length=global_config['max_seq_length'],
padding=global_config['padding'],
return_attention_mask=global_config['return_attention_mask'])
if global_config['return_attention_mask']:
dev_batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=tokenizer.pad_token_id), # input
Pad(axis=0, pad_val=0), # attention_mask
Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # token_type
Stack(dtype="int64" if train_ds.label_list else "float32") # label
): fn(samples)
else:
dev_batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=tokenizer.pad_token_id), # input
Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # token_type
Stack(dtype="int64" if train_ds.label_list else "float32") # label
): fn(samples)
if global_config['task_name'] == "mnli":
dev_ds_matched, dev_ds_mismatched = load_dataset(
global_config['dataset'],
global_config['task_name'],
splits=["dev_matched", "dev_mismatched"])
dev_ds_matched = dev_ds_matched.map(dev_trans_func, lazy=True)
dev_ds_mismatched = dev_ds_mismatched.map(dev_trans_func, lazy=True)
dev_batch_sampler_matched = paddle.io.BatchSampler(
dev_ds_matched,
batch_size=global_config['batch_size'],
shuffle=False)
dev_data_loader_matched = DataLoader(
dataset=dev_ds_matched,
batch_sampler=dev_batch_sampler_matched,
collate_fn=batchify_fn,
feed_list=feed_list,
num_workers=0,
return_list=False)
dev_batch_sampler_mismatched = paddle.io.BatchSampler(
dev_ds_mismatched,
batch_size=global_config['batch_size'],
shuffle=False)
dev_data_loader_mismatched = DataLoader(
dataset=dev_ds_mismatched,
batch_sampler=dev_batch_sampler_mismatched,
collate_fn=batchify_fn,
num_workers=0,
feed_list=feed_list,
return_list=False)
return train_data_loader, dev_data_loader_matched, dev_data_loader_mismatched
else:
dev_ds = load_dataset(
global_config['dataset'], global_config['task_name'], splits='dev')
dev_ds = dev_ds.map(dev_trans_func, lazy=True)
dev_batch_sampler = paddle.io.BatchSampler(
dev_ds, batch_size=global_config['batch_size'], shuffle=False)
dev_data_loader = DataLoader(
dataset=dev_ds,
batch_sampler=dev_batch_sampler,
collate_fn=dev_batchify_fn,
num_workers=0,
feed_list=feed_list,
return_list=False)
return train_data_loader, dev_data_loader
def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
metric.reset()
for data in eval_dataloader():
logits = exe.run(compiled_test_program,
feed={
test_feed_names[0]: data[0]['x0'],
test_feed_names[1]: data[0]['x1'],
test_feed_names[2]: data[0]['x2']
},
fetch_list=test_fetch_list)
paddle.disable_static()
if isinstance(metric, PearsonAndSpearman):
labels_pd = paddle.to_tensor(np.array(data[0]['label'])).reshape(
(-1, 1))
logits_pd = paddle.to_tensor(logits[0]).reshape((-1, 1))
metric.update((logits_pd, labels_pd))
else:
labels_pd = paddle.to_tensor(np.array(data[0]['label']).flatten())
logits_pd = paddle.to_tensor(logits[0])
correct = metric.compute(logits_pd, labels_pd)
metric.update(correct)
paddle.enable_static()
res = metric.accumulate()
return res[0] if isinstance(res, list) or isinstance(res, tuple) else res
def apply_decay_param_fun(name):
if name.find("bias") > -1:
return True
elif name.find("b_0") > -1:
return True
elif name.find("norm") > -1:
return True
else:
return False
def main():
all_config = load_slim_config(args.config_path)
global global_config
assert "Global" in all_config, "Key Global not found in config file."
global_config = all_config["Global"]
if 'TrainConfig' in all_config:
all_config['TrainConfig']['optimizer_builder'][
'apply_decay_param_fun'] = apply_decay_param_fun
global train_dataloader, eval_dataloader
train_dataloader, eval_dataloader = reader()
global metric
metric_class = METRIC_CLASSES[global_config['task_name']]
metric = metric_class()
ac = AutoCompression(
model_dir=global_config['model_dir'],
model_filename=global_config['model_filename'],
params_filename=global_config['params_filename'],
save_dir=args.save_dir,
config=args.config_path,
train_dataloader=train_dataloader,
eval_callback=eval_function if
(len(list(all_config.keys())) == 2 and 'TrainConfig' in all_config) or
len(list(all_config.keys())) == 1 or
'HyperParameterOptimization' not in all_config else eval_dataloader,
eval_dataloader=eval_dataloader)
ac.compress()
if __name__ == '__main__':
paddle.enable_static()
parser = argsparser()
args = parser.parse_args()
print_arguments(args)
main()