From b3499dc9be9dd9ccdf82ab4456930ce01bd30ee9 Mon Sep 17 00:00:00 2001 From: Chang Xu Date: Mon, 27 Jun 2022 19:17:31 +0800 Subject: [PATCH] Add X2Paddle & NLP Demo (#1190) --- .../pytorch-huggingface/README.md | 192 ++++++++++ .../pytorch-huggingface/configs/cola.yaml | 22 ++ .../pytorch-huggingface/configs/mnli.yaml | 22 ++ .../pytorch-huggingface/configs/mrpc.yaml | 22 ++ .../pytorch-huggingface/configs/qnli.yaml | 22 ++ .../pytorch-huggingface/configs/qqp.yaml | 22 ++ .../pytorch-huggingface/configs/rte.yaml | 22 ++ .../pytorch-huggingface/configs/sst2.yaml | 22 ++ .../pytorch-huggingface/configs/stsb.yaml | 22 ++ .../pytorch-huggingface/infer.py | 335 ++++++++++++++++++ .../pytorch-huggingface/run.py | 324 +++++++++++++++++ .../pytorch-huggingface/run.sh | 1 + 12 files changed, 1028 insertions(+) create mode 100644 demo/auto_compression/pytorch-huggingface/README.md create mode 100644 demo/auto_compression/pytorch-huggingface/configs/cola.yaml create mode 100644 demo/auto_compression/pytorch-huggingface/configs/mnli.yaml create mode 100644 demo/auto_compression/pytorch-huggingface/configs/mrpc.yaml create mode 100644 demo/auto_compression/pytorch-huggingface/configs/qnli.yaml create mode 100644 demo/auto_compression/pytorch-huggingface/configs/qqp.yaml create mode 100644 demo/auto_compression/pytorch-huggingface/configs/rte.yaml create mode 100644 demo/auto_compression/pytorch-huggingface/configs/sst2.yaml create mode 100644 demo/auto_compression/pytorch-huggingface/configs/stsb.yaml create mode 100644 demo/auto_compression/pytorch-huggingface/infer.py create mode 100644 demo/auto_compression/pytorch-huggingface/run.py create mode 100644 demo/auto_compression/pytorch-huggingface/run.sh diff --git a/demo/auto_compression/pytorch-huggingface/README.md b/demo/auto_compression/pytorch-huggingface/README.md new file mode 100644 index 00000000..bc35525e --- /dev/null +++ b/demo/auto_compression/pytorch-huggingface/README.md @@ -0,0 +1,192 @@ +# HuggingFace 预训练模型压缩部署示例 +目录: +- [1. 简介](#1简介) +- [2. Benchmark](#2Benchmark) +- [3. 自动压缩流程](#自动压缩流程) + - [3.1 准备环境](#31-准备环境) + - [3.2 准备数据集](#32-准备数据集) + - [3.3 X2Paddle转换模型流程](#33-X2Paddle转换模型流程) + - [3.4 自动压缩并产出模型](#34-自动压缩并产出模型) +- [4. 压缩配置介绍](#4压缩配置介绍) +- [5. 预测部署](#5预测部署) +- [6. FAQ](6FAQ) + +## 1. 简介 +飞桨模型转换工具[X2Paddle](https://github.com/PaddlePaddle/X2Paddle)支持将```Caffe/TensorFlow/ONNX/PyTorch```的模型一键转为飞桨(PaddlePaddle)的预测模型。借助X2Paddle的能力,PaddleSlim的自动压缩功能可方便地用于各种框架的推理模型。 + + +本示例将以[Pytorch](https://github.com/pytorch/pytorch)框架的自然语言处理模型为例,介绍如何自动压缩其他框架中的自然语言处理模型。本示例会利用[huggingface](https://github.com/huggingface/transformers)开源transformers库,将Pytorch框架模型转换为Paddle框架模型,再使用ACT自动压缩功能进行自动压缩。本示例使用的自动压缩策略为剪枝蒸馏和离线量化(```Post-training quantization```)。 + + + + +## 2. 
Benchmark +[BERT](https://arxiv.org/abs/1810.04805) (```Bidirectional Encoder Representations from Transformers```)以Transformer 编码器为网络基本组件,使用掩码语言模型(```Masked Language Model```)和邻接句子预测(```Next Sentence Prediction```)两个任务在大规模无标注文本语料上进行预训练(pre-train),得到融合了双向内容的通用语义表示模型。以预训练产生的通用语义表示模型为基础,结合任务适配的简单输出层,微调(fine-tune)后即可应用到下游的NLP任务,效果通常也较直接在下游的任务上训练的模型更优。此前BERT即在[GLUE](https://gluebenchmark.com/tasks)评测任务上取得了SOTA的结果。 + +基于bert-base-cased模型,压缩前后的精度如下: +| 模型 | 策略 | CoLA | MRPC | QNLI | QQP | RTE | SST2 | AVG | +|:------:|:------:|:------:|:------:|:-----------:|:------:|:------:|:------:|:------:| +| bert-base-cased | Base模型| 60.06 | 84.31 | 90.68 | 90.84 | 63.53 | 91.63 | 80.17 | +| bert-base-cased |剪枝蒸馏+离线量化| 60.52 | 84.80 | 90.59 | 90.42 | 64.26 | 91.63 | 80.37 | + +模型在多个任务上平均精度以及加速对比如下: +| bert-base-cased | Accuracy(avg) | 时延(ms) | 加速比 | +|:-------:|:----------:|:------------:| :------:| +| 压缩前 | 80.17 | 8.18 | - | +| 压缩后 | 80.37 | 6.35 | 28.82% | + +- Nvidia GPU 测试环境: + - 硬件:NVIDIA Tesla T4 单卡 + - 软件:CUDA 11.2, cuDNN 8.0, TensorRT 8.4 + - 测试配置:batch_size: 1, seqence length: 128 + +## 3. 自动压缩流程 +#### 3.1 准备环境 +- python >= 3.6 +- PaddlePaddle >= 2.3 (可从[Paddle官网](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)下载安装) +- PaddleSlim develop版本或PaddleSlim>=2.3.0 +- X2Paddle develop版本 +- PaddleNLP >= 2.3 +- tensorflow == 1.14 (如需压缩TensorFlow模型) +- onnx >= 1.6.0 (如需压缩ONNX模型) +- torch >= 1.5.0 (如需压缩PyTorch模型) + +安装paddlepaddle: +```shell +# CPU +pip install paddlepaddle +# GPU +pip install paddlepaddle-gpu +``` + +安装paddleslim: +```shell +git clone https://github.com/PaddlePaddle/PaddleSlim.git +python setup.py install +``` + +安装X2Paddle: +``` +git clone https://github.com/PaddlePaddle/X2Paddle.git +cd X2Paddle +git checkout develop +python setup.py install +``` + +安装paddlenlp: +```shell +pip install paddlenlp +``` + +注:安装PaddleNLP的目的是为了下载PaddleNLP中的数据集和Tokenizer。 + + +#### 3.2 准备数据集 +本案例默认以GLUE数据进行自动压缩实验,PaddleNLP会自动下载对应数据集。 + + +#### 3.3 X2Paddle转换模型流程 + +**方式1: PyTorch2Paddle直接将Pytorch动态图模型转为Paddle静态图模型** + +```shell +import torch +import numpy as np +# 将PyTorch模型设置为eval模式 +torch_model.eval() +# 构建输入 +input_ids = torch.unsqueeze(torch.tensor([0] * max_length), 0) +token_type_ids = torch.unsqueeze(torch.tensor([0] * max_length),0) +attention_msk = torch.unsqueeze(torch.tensor([0] * max_length),0) +# 进行转换 +from x2paddle.convert import pytorch2paddle +pytorch2paddle(torch_model, + save_dir='./x2paddle_cola/', + jit_type="trace", + input_examples=[input_ids, attention_msk, token_type_ids]) +``` + +PyTorch2Paddle支持trace和script两种方式的转换,均是PyTorch动态图到Paddle动态图的转换,转换后的Paddle动态图运用动转静可转换为静态图模型。 +- jit_type为"trace"时,input_examples不可为None,转换后自动进行动转静,输入shape固定。 +- jit_type为"script"时,当input_examples为None时,只生成动态图代码;当input_examples不为None时,才能自动动转静。 + +注意: +- 由于自动压缩的是静态图模型,所以这里需要将```jit_type```设置为```trace```,并且注意PyTorch模型中需要设置```pad_to_max_length```,且设置的```max_length```需要和转换时构建的数据相同。 +- HuggingFace默认输入```attention_mask```,PaddleNLP默认不输入,这里需要保持一致。可以PaddleNLP中设置```return_attention_mask=True```。 +- 使用PaddleNLP的tokenizer时需要在模型保存的文件夹中加入```model_config.json, special_tokens_map.json, tokenizer_config.json, vocab.txt```这些文件。 + + +更多Pytorch2Paddle示例可参考[PyTorch模型转换文档](https://github.com/PaddlePaddle/X2Paddle/blob/develop/docs/inference_model_convertor/pytorch2paddle.md)。其他框架转换可参考[X2Paddle模型转换工具](https://github.com/PaddlePaddle/X2Paddle) + +如想快速尝试运行实验,也可以直接下载已经转换好的模型,链接如下: +| [CoLA](https://paddle-slim-models.bj.bcebos.com/act/x2paddle_cola.tar) | 
[MRPC](https://paddle-slim-models.bj.bcebos.com/act/x2paddle_mrpc.tar) | [QNLI](https://paddle-slim-models.bj.bcebos.com/act/x2paddle_qnli.tar) | [QQP](https://paddle-slim-models.bj.bcebos.com/act/x2paddle_qqp.tar) | [RTE](https://paddle-slim-models.bj.bcebos.com/act/x2paddle_rte.tar) | [SST2](https://paddle-slim-models.bj.bcebos.com/act/x2paddle_sst2.tar) | + +```shell +wget https://paddle-slim-models.bj.bcebos.com/act/x2paddle_cola.tar +tar xf x2paddle_cola.tar +``` + +**方式2: Onnx2Paddle将Pytorch动态图模型保存为Onnx格式后再转为Paddle静态图模型** + + +PyTorch 导出 ONNX 动态图模型 +```shell +torch_model.eval() +input_ids = torch.unsqueeze(torch.tensor([0] * args.max_length), 0) +token_type_ids = torch.unsqueeze(torch.tensor([0] * args.max_length), 0) +attention_mask = torch.unsqueeze(torch.tensor([0] * args.max_length), 0) +input_names = ['input_ids', 'attention_mask', 'token_type_ids'] +output_names = ['output'] +torch.onnx.export( + model, + (input_ids, attention_mask, token_type_ids), + 'model.onnx', + opset_version=11, + input_names=input_names, + output_names=output_names, + dynamic_axes={'input_ids': [0], 'attention_mask': [0], 'token_type_ids': [0]}) +``` + +通过 X2Paddle 命令导出 Paddle 模型 +```shell +x2paddle --framework=onnx --model=model.onnx --save_dir=pd_model_dynamic +``` + +在自动生成的 x2paddle_code.py 中添加如下代码: +```shell +def main(x0, x1, x2): + # x0, x1, x2 为模型输入. + paddle.disable_static() + params = paddle.load('model.pdparams') + model = BertForSequenceClassification() + model.set_dict(params) + model.eval() + ## convert to jit + sepc_list = list() + sepc_list.append( + paddle.static.InputSpec( + shape=[-1, 128], name="x0", dtype="int64"), + paddle.static.InputSpec( + shape=[-1, 128], name="x1", dtype="int64"), + paddle.static.InputSpec( + shape=[-1, 128], name="x2", dtype="int64")) + static_model = paddle.jit.to_static(model, input_spec=sepc_list) + paddle.jit.save(static_model, "./x2paddle_cola") +``` + + +#### 3.4 自动压缩并产出模型 +以“cola”任务为例,在配置文件“./config/cola.yaml”中配置推理模型路径、压缩策略参数等信息,并通过“--config_path”将配置文件传给示例脚本"run.py"。 +在“run.py”中,调用接口```paddleslim.auto_compression.AutoCompression```加载配置文件,并对推理模型进行自动压缩。 +```shell +export CUDA_VISIBLE_DEVICES=0 +python run.py --config_path=./configs/cola.yaml --save_dir='./output/cola/' +``` + +## 4. 预测部署 + +- [Paddle Inference Python部署](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.5/docs/deployment/inference/python_inference.md) +- [Paddle Inference C++部署](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.5/docs/deployment/inference/cpp_inference.md) +- [Paddle Lite部署](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.5/docs/deployment/lite/lite.md) + +## 5. 
FAQ diff --git a/demo/auto_compression/pytorch-huggingface/configs/cola.yaml b/demo/auto_compression/pytorch-huggingface/configs/cola.yaml new file mode 100644 index 00000000..d6a06e47 --- /dev/null +++ b/demo/auto_compression/pytorch-huggingface/configs/cola.yaml @@ -0,0 +1,22 @@ +Global: + input_names: ['x0', 'x1', 'x2'] + model_dir: ./x2paddle_cola + model_filename: model.pdmodel + params_filename: model.pdiparams + model_type: bert-base-cased + task_name: cola + dataset: glue + batch_size: 1 + max_seq_length: 128 + padding: max_length + return_attention_mask: True +TrainConfig: + epochs: 3 + eval_iter: 855 + learning_rate: 1.0e-6 + optimizer_builder: + optimizer: + type: AdamW + weight_decay: 0.01 + origin_metric: 0.6006 + diff --git a/demo/auto_compression/pytorch-huggingface/configs/mnli.yaml b/demo/auto_compression/pytorch-huggingface/configs/mnli.yaml new file mode 100644 index 00000000..5a1e7515 --- /dev/null +++ b/demo/auto_compression/pytorch-huggingface/configs/mnli.yaml @@ -0,0 +1,22 @@ +Global: + input_names: ['x0', 'x1', 'x2'] + model_dir: ./x2paddle_mnli + model_filename: model.pdmodel + params_filename: model.pdiparams + model_type: bert-base-cased + task_name: mnli + dataset: glue + batch_size: 1 + max_seq_length: 128 + padding: max_length + return_attention_mask: True +TrainConfig: + epochs: 3 + eval_iter: 1710 + learning_rate: 1.0e-6 + optimizer_builder: + optimizer: + type: AdamW + weight_decay: 0.01 + origin_metric: 0.8318 + diff --git a/demo/auto_compression/pytorch-huggingface/configs/mrpc.yaml b/demo/auto_compression/pytorch-huggingface/configs/mrpc.yaml new file mode 100644 index 00000000..86f997be --- /dev/null +++ b/demo/auto_compression/pytorch-huggingface/configs/mrpc.yaml @@ -0,0 +1,22 @@ +Global: + input_names: ['x0', 'x1', 'x2'] + model_dir: ./x2paddle_mrpc + model_filename: model.pdmodel + params_filename: model.pdiparams + model_type: bert-base-cased + task_name: mrpc + dataset: glue + batch_size: 1 + max_seq_length: 128 + padding: max_length + return_attention_mask: True +TrainConfig: + epochs: 3 + eval_iter: 915 + learning_rate: 1.0e-6 + optimizer_builder: + optimizer: + type: AdamW + weight_decay: 0.01 + origin_metric: 0.8431 + diff --git a/demo/auto_compression/pytorch-huggingface/configs/qnli.yaml b/demo/auto_compression/pytorch-huggingface/configs/qnli.yaml new file mode 100644 index 00000000..321a0463 --- /dev/null +++ b/demo/auto_compression/pytorch-huggingface/configs/qnli.yaml @@ -0,0 +1,22 @@ +Global: + input_names: ['x0', 'x1', 'x2'] + model_dir: ./x2paddle_qnli + model_filename: model.pdmodel + params_filename: model.pdiparams + model_type: bert-base-cased + task_name: qnli + dataset: glue + batch_size: 1 + max_seq_length: 128 + padding: max_length + return_attention_mask: True +TrainConfig: + epochs: 3 + eval_iter: 855 + learning_rate: 1.0e-6 + optimizer_builder: + optimizer: + type: AdamW + weight_decay: 0.01 + origin_metric: 0.9068 + diff --git a/demo/auto_compression/pytorch-huggingface/configs/qqp.yaml b/demo/auto_compression/pytorch-huggingface/configs/qqp.yaml new file mode 100644 index 00000000..21676e0a --- /dev/null +++ b/demo/auto_compression/pytorch-huggingface/configs/qqp.yaml @@ -0,0 +1,22 @@ +Global: + input_names: ['x0', 'x1', 'x2'] + model_dir: ./x2paddle_qqp + model_filename: model.pdmodel + params_filename: model.pdiparams + model_type: bert-base-cased + task_name: qqp + dataset: glue + batch_size: 1 + max_seq_length: 128 + padding: max_length + return_attention_mask: True +TrainConfig: + epochs: 3 + eval_iter: 855 + 
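+  # Note: origin_metric below is the dev accuracy of the uncompressed model on QQP (90.84 in the README benchmark table); learning_rate and the AdamW settings below configure the distillation fine-tuning.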
learning_rate: 1.0e-6 + optimizer_builder: + optimizer: + type: AdamW + weight_decay: 0.01 + origin_metric: 0.9084 + diff --git a/demo/auto_compression/pytorch-huggingface/configs/rte.yaml b/demo/auto_compression/pytorch-huggingface/configs/rte.yaml new file mode 100644 index 00000000..70879b5d --- /dev/null +++ b/demo/auto_compression/pytorch-huggingface/configs/rte.yaml @@ -0,0 +1,22 @@ +Global: + input_names: ['x0', 'x1', 'x2'] + model_dir: ./x2paddle_rte + model_filename: model.pdmodel + params_filename: model.pdiparams + model_type: bert-base-cased + task_name: rte + dataset: glue + batch_size: 1 + max_seq_length: 128 + padding: max_length + return_attention_mask: True +TrainConfig: + epochs: 3 + eval_iter: 1240 + learning_rate: 1.0e-6 + optimizer_builder: + optimizer: + type: AdamW + weight_decay: 0.01 + origin_metric: 0.6353 + diff --git a/demo/auto_compression/pytorch-huggingface/configs/sst2.yaml b/demo/auto_compression/pytorch-huggingface/configs/sst2.yaml new file mode 100644 index 00000000..3f9a6f53 --- /dev/null +++ b/demo/auto_compression/pytorch-huggingface/configs/sst2.yaml @@ -0,0 +1,22 @@ +Global: + input_names: ['x0', 'x1', 'x2'] + model_dir: ./x2paddle_sst2 + model_filename: model.pdmodel + params_filename: model.pdiparams + model_type: bert-base-cased + task_name: sst-2 + dataset: glue + batch_size: 1 + max_seq_length: 128 + padding: max_length + return_attention_mask: True +TrainConfig: + epochs: 3 + eval_iter: 3367 + learning_rate: 1.0e-6 + optimizer_builder: + optimizer: + type: AdamW + weight_decay: 0.01 + origin_metric: 0.9163 + diff --git a/demo/auto_compression/pytorch-huggingface/configs/stsb.yaml b/demo/auto_compression/pytorch-huggingface/configs/stsb.yaml new file mode 100644 index 00000000..2abc207b --- /dev/null +++ b/demo/auto_compression/pytorch-huggingface/configs/stsb.yaml @@ -0,0 +1,22 @@ +Global: + input_names: ['x0', 'x1', 'x2'] + model_dir: ./x2paddle_stsb + model_filename: model.pdmodel + params_filename: model.pdiparams + model_type: bert-base-cased + task_name: sts-b + dataset: glue + batch_size: 1 + max_seq_length: 128 + padding: max_length + return_attention_mask: True +TrainConfig: + epochs: 3 + eval_iter: 1710 + learning_rate: 1.0e-6 + optimizer_builder: + optimizer: + type: AdamW + weight_decay: 0.01 + origin_metric: 0.8846 + diff --git a/demo/auto_compression/pytorch-huggingface/infer.py b/demo/auto_compression/pytorch-huggingface/infer.py new file mode 100644 index 00000000..0d95fc2a --- /dev/null +++ b/demo/auto_compression/pytorch-huggingface/infer.py @@ -0,0 +1,335 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
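+
+# Evaluation / benchmark script for the exported (and optionally compressed) Paddle
+# inference model: it builds a Paddle Inference predictor from --model_path
+# (optionally with a TensorRT FP32/FP16/INT8 engine), tokenizes the GLUE dev split
+# with PaddleNLP's BertTokenizer, and reports task accuracy or inference latency.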
+ +import argparse +import os +import time +import sys +from functools import partial +import distutils.util +import numpy as np + +import paddle +from paddle import inference +from paddlenlp.datasets import load_dataset +from paddlenlp.data import Stack, Tuple, Pad +from paddle.metric import Metric, Accuracy, Precision, Recall +from paddlenlp.metrics import AccuracyAndF1, Mcc, PearsonAndSpearman +from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer + +METRIC_CLASSES = { + "cola": Mcc, + "sst-2": Accuracy, + "mrpc": AccuracyAndF1, + "sts-b": PearsonAndSpearman, + "qqp": AccuracyAndF1, + "mnli": Accuracy, + "qnli": Accuracy, + "rte": Accuracy, +} + +task_to_keys = { + "cola": ("sentence", None), + "mnli": ("sentence1", "sentence2"), + "mrpc": ("sentence1", "sentence2"), + "qnli": ("sentence1", "sentence2"), + "qqp": ("sentence1", "sentence2"), + "rte": ("sentence1", "sentence2"), + "sst-2": ("sentence", None), + "sts-b": ("sentence1", "sentence2"), + "wnli": ("sentence1", "sentence2"), +} + + +def parse_args(): + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument( + "--task_name", + default='cola', + type=str, + help="The name of the task to perform predict, selected in the list: " + + ", ".join(METRIC_CLASSES.keys()), ) + parser.add_argument( + "--model_type", + default='bert-base-cased', + type=str, + help="Model type selected in bert.") + parser.add_argument( + "--model_name_or_path", + default='bert-base-cased', + type=str, + help="The directory or name of model.", ) + parser.add_argument( + "--model_path", + default='./quant_models/model', + type=str, + required=True, + help="The path prefix of inference model to be used.", ) + parser.add_argument( + "--device", + default="gpu", + choices=["gpu", "cpu", "xpu"], + help="Device selected for inference.", ) + parser.add_argument( + "--batch_size", + default=32, + type=int, + help="Batch size for predict.", ) + parser.add_argument( + "--max_seq_length", + default=128, + type=int, + help="The maximum total input sequence length after tokenization. 
Sequences longer " + "than this will be truncated, sequences shorter will be padded.", ) + parser.add_argument( + "--padding", + default='max_length', + type=int, + help="Padding type", ) + parser.add_argument( + "--perf_warmup_steps", + default=20, + type=int, + help="Warmup steps for performance test.", ) + parser.add_argument( + "--use_trt", + action='store_true', + help="Whether to use inference engin TensorRT.", ) + parser.add_argument( + "--perf", + action='store_false', + help="Whether to test performance.", ) + parser.add_argument( + "--int8", + action='store_true', + help="Whether to use int8 inference.", ) + + parser.add_argument( + "--fp16", + action='store_true', + help="Whether to use int8 inference.", ) + args = parser.parse_args() + return args + + +@paddle.no_grad() +def evaluate(outputs, metric, data_loader): + metric.reset() + for i, batch in enumerate(data_loader): + input_ids, segment_ids, labels = batch + logits = paddle.to_tensor(outputs[i][0]) + correct = metric.compute(logits, labels) + metric.update(correct) + res = metric.accumulate() + print("acc: %s, " % res, end='') + + +def convert_example(example, + tokenizer, + label_list, + max_seq_length=512, + task_name=None, + is_test=False, + padding='max_length', + return_attention_mask=True): + if not is_test: + # `label_list == None` is for regression task + label_dtype = "int64" if label_list else "float32" + # Get the label + label = example['labels'] + label = np.array([label], dtype=label_dtype) + # Convert raw text to feature + sentence1_key, sentence2_key = task_to_keys[task_name] + texts = ((example[sentence1_key], ) if sentence2_key is None else + (example[sentence1_key], example[sentence2_key])) + example = tokenizer( + *texts, + max_seq_len=max_seq_length, + padding=padding, + return_attention_mask=return_attention_mask) + if not is_test: + if return_attention_mask: + return example['input_ids'], example['attention_mask'], example[ + 'token_type_ids'], label + else: + return example['input_ids'], example['token_type_ids'], label + else: + if return_attention_mask: + return example['input_ids'], example['attention_mask'], example[ + 'token_type_ids'] + else: + return example['input_ids'], example['token_type_ids'] + + +class Predictor(object): + def __init__(self, predictor, input_handles, output_handles): + self.predictor = predictor + self.input_handles = input_handles + self.output_handles = output_handles + + @classmethod + def create_predictor(cls, args): + config = paddle.inference.Config(args.model_path + ".pdmodel", + args.model_path + ".pdiparams") + if args.device == "gpu": + # set GPU configs accordingly + config.enable_use_gpu(100, 0) + cls.device = paddle.set_device("gpu") + elif args.device == "cpu": + # set CPU configs accordingly, + # such as enable_mkldnn, set_cpu_math_library_num_threads + config.disable_gpu() + cls.device = paddle.set_device("cpu") + elif args.device == "xpu": + # set XPU configs accordingly + config.enable_xpu(100) + if args.use_trt: + if args.int8: + config.enable_tensorrt_engine( + workspace_size=1 << 30, + precision_mode=inference.PrecisionType.Int8, + max_batch_size=args.batch_size, + min_subgraph_size=5, + use_static=False, + use_calib_mode=False) + elif args.fp16: + + config.enable_tensorrt_engine( + workspace_size=1 << 30, + precision_mode=inference.PrecisionType.Half, + max_batch_size=args.batch_size, + min_subgraph_size=5, + use_static=False, + use_calib_mode=False) + else: + config.enable_tensorrt_engine( + workspace_size=1 << 30, + 
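+                    # Default TensorRT branch: build an FP32 engine when neither --int8 nor --fp16 is set.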
precision_mode=inference.PrecisionType.Float32, + max_batch_size=args.batch_size, + min_subgraph_size=5, + use_static=False, + use_calib_mode=False) + print("Enable TensorRT is: {}".format( + config.tensorrt_engine_enabled())) + # Set min/max/opt tensor shape of each trt subgraph input according + # to dataset. + # For example, the config of TNEWS data should be 16, 32, 32, 31, 128, 32. + predictor = paddle.inference.create_predictor(config) + + input_handles = [ + predictor.get_input_handle(name) + for name in predictor.get_input_names() + ] + output_handles = [ + predictor.get_output_handle(name) + for name in predictor.get_output_names() + ] + + return cls(predictor, input_handles, output_handles) + + def predict_batch(self, data): + for input_field, input_handle in zip(data, self.input_handles): + input_handle.copy_from_cpu(input_field) + self.predictor.run() + output = [ + output_handle.copy_to_cpu() for output_handle in self.output_handles + ] + return output + + def convert_predict_batch(self, args, data, tokenizer, batchify_fn, + label_list): + examples = [] + for example in data: + example = convert_example( + example, + tokenizer, + label_list, + task_name=args.task_name, + max_seq_length=args.max_seq_length, + padding='max_length', + return_attention_mask=True) + examples.append(example) + + return examples + + def predict(self, dataset, tokenizer, batchify_fn, args): + batches = [ + dataset[idx:idx + args.batch_size] + for idx in range(0, len(dataset), args.batch_size) + ] + if args.perf: + for i, batch in enumerate(batches): + examples = self.convert_predict_batch( + args, batch, tokenizer, batchify_fn, dataset.label_list) + input_ids, atten_mask, segment_ids, label = batchify_fn( + examples) + output = self.predict_batch( + [input_ids, atten_mask, segment_ids]) + if i > args.perf_warmup_steps: + break + time1 = time.time() + for batch in batches: + examples = self.convert_predict_batch( + args, batch, tokenizer, batchify_fn, dataset.label_list) + input_ids, atten_mask, segment_ids, _ = batchify_fn(examples) + output = self.predict_batch( + [input_ids, atten_mask, segment_ids]) + + print("task name: %s, time: %s, " % + (args.task_name, time.time() - time1)) + + else: + metric = METRIC_CLASSES[args.task_name]() + metric.reset() + for i, batch in enumerate(batches): + examples = self.convert_predict_batch( + args, batch, tokenizer, batchify_fn, dataset.label_list) + input_ids, atten_mask, segment_ids, label = batchify_fn( + examples) + output = self.predict_batch( + [input_ids, atten_mask, segment_ids]) + correct = metric.compute( + paddle.to_tensor(output), paddle.to_tensor(label)) + metric.update(correct) + + res = metric.accumulate() + print("task name: %s, acc: %s, " % (args.task_name, res), end='') + + +def main(): + paddle.seed(42) + args = parse_args() + + args.task_name = args.task_name.lower() + args.model_type = args.model_type.lower() + + predictor = Predictor.create_predictor(args) + + dev_ds = load_dataset('glue', args.task_name, splits='dev') + + tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path) + + batchify_fn = lambda samples, fn=Tuple( + Pad(axis=0, pad_val=tokenizer.pad_token_id), # input + Pad(axis=0, pad_val=0), + Pad(axis=0, pad_val=tokenizer.pad_token_id), # segment + Stack(dtype="int64" if dev_ds.label_list else "float32") # label + ): fn(samples) + outputs = predictor.predict(dev_ds, tokenizer, batchify_fn, args) + + +if __name__ == "__main__": + main() diff --git a/demo/auto_compression/pytorch-huggingface/run.py 
b/demo/auto_compression/pytorch-huggingface/run.py new file mode 100644 index 00000000..8b07b408 --- /dev/null +++ b/demo/auto_compression/pytorch-huggingface/run.py @@ -0,0 +1,324 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import numpy as np +import argparse +import paddle +import paddle.nn as nn +import functools +from functools import partial +from paddle.io import Dataset, BatchSampler, DataLoader +from paddle.metric import Metric, Accuracy +from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer +from paddlenlp.datasets import load_dataset +from paddlenlp.data import Stack, Tuple, Pad +from paddlenlp.metrics import AccuracyAndF1, Mcc, PearsonAndSpearman +from paddleslim.auto_compression.config_helpers import load_config as load_slim_config +from paddleslim.auto_compression.compressor import AutoCompression + + +def argsparser(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + '--config_path', + type=str, + default=None, + help="path of compression strategy config.", + required=True) + parser.add_argument( + '--save_dir', + type=str, + default='output', + help="directory to save compressed model.") + return parser + + +def print_arguments(args): + print('----------- Running Arguments -----------') + for arg, value in sorted(vars(args).items()): + print('%s: %s' % (arg, value)) + print('------------------------------------------') + + +METRIC_CLASSES = { + "cola": Mcc, + "sst-2": Accuracy, + "mrpc": AccuracyAndF1, + "sts-b": PearsonAndSpearman, + "qqp": AccuracyAndF1, + "mnli": Accuracy, + "qnli": Accuracy, + "rte": Accuracy, +} + +task_to_keys = { + "cola": ("sentence", None), + "mnli": ("sentence1", "sentence2"), + "mrpc": ("sentence1", "sentence2"), + "qnli": ("sentence1", "sentence2"), + "qqp": ("sentence1", "sentence2"), + "rte": ("sentence1", "sentence2"), + "sst-2": ("sentence", None), + "sts-b": ("sentence1", "sentence2"), + "wnli": ("sentence1", "sentence2"), +} + + +def convert_example(example, + tokenizer, + label_list, + max_seq_length=512, + is_test=False, + padding='max_length', + return_attention_mask=True): + if not is_test: + # `label_list == None` is for regression task + label_dtype = "int64" if label_list else "float32" + # Get the label + label = example['labels'] + label = np.array([label], dtype=label_dtype) + # Convert raw text to feature + sentence1_key, sentence2_key = task_to_keys[global_config['task_name']] + texts = ((example[sentence1_key], ) if sentence2_key is None else + (example[sentence1_key], example[sentence2_key])) + example = tokenizer( + *texts, + max_seq_len=max_seq_length, + padding=padding, + return_attention_mask=return_attention_mask, + truncation='longest_first') + if not is_test: + if return_attention_mask: + return example['input_ids'], example['attention_mask'], example[ + 'token_type_ids'], label + else: + return example['input_ids'], example['token_type_ids'], label + else: + if 
return_attention_mask: + return example['input_ids'], example['attention_mask'], example[ + 'token_type_ids'] + else: + return example['input_ids'], example['token_type_ids'] + + +def create_data_holder(task_name, input_names): + """ + Define the input data holder for the glue task. + """ + inputs = [] + for name in input_names: + inputs.append( + paddle.static.data( + name=name, shape=[-1, -1], dtype="int64")) + + if task_name == "sts-b": + inputs.append( + paddle.static.data( + name="label", shape=[-1, 1], dtype="float32")) + else: + inputs.append( + paddle.static.data( + name="label", shape=[-1, 1], dtype="int64")) + + return inputs + + +def reader(): + # Create the tokenizer and dataset + tokenizer = BertTokenizer.from_pretrained(global_config['model_dir']) + train_ds = load_dataset( + global_config['dataset'], global_config['task_name'], splits="train") + + trans_func = partial( + convert_example, + tokenizer=tokenizer, + label_list=train_ds.label_list, + max_seq_length=global_config['max_seq_length'], + is_test=True, + padding=global_config['padding'], + return_attention_mask=global_config['return_attention_mask']) + + train_ds = train_ds.map(trans_func, lazy=True) + if global_config['return_attention_mask']: + batchify_fn = lambda samples, fn=Tuple( + Pad(axis=0, pad_val=tokenizer.pad_token_id), # input + Pad(axis=0, pad_val=0), # attention_mask + Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # token_type + ): fn(samples) + else: + batchify_fn = lambda samples, fn=Tuple( + Pad(axis=0, pad_val=tokenizer.pad_token_id), # input + Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # token_type + ): fn(samples) + + train_batch_sampler = paddle.io.BatchSampler( + train_ds, batch_size=global_config['batch_size'], shuffle=True) + + feed_list = create_data_holder(global_config['task_name'], + global_config['input_names']) + train_data_loader = DataLoader( + dataset=train_ds, + feed_list=feed_list[:-1], + batch_sampler=train_batch_sampler, + collate_fn=batchify_fn, + num_workers=0, + return_list=False) + + dev_trans_func = partial( + convert_example, + tokenizer=tokenizer, + label_list=train_ds.label_list, + max_seq_length=global_config['max_seq_length'], + padding=global_config['padding'], + return_attention_mask=global_config['return_attention_mask']) + + if global_config['return_attention_mask']: + dev_batchify_fn = lambda samples, fn=Tuple( + Pad(axis=0, pad_val=tokenizer.pad_token_id), # input + Pad(axis=0, pad_val=0), # attention_mask + Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # token_type + Stack(dtype="int64" if train_ds.label_list else "float32") # label + ): fn(samples) + else: + dev_batchify_fn = lambda samples, fn=Tuple( + Pad(axis=0, pad_val=tokenizer.pad_token_id), # input + Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # token_type + Stack(dtype="int64" if train_ds.label_list else "float32") # label + ): fn(samples) + + if global_config['task_name'] == "mnli": + dev_ds_matched, dev_ds_mismatched = load_dataset( + global_config['dataset'], + global_config['task_name'], + splits=["dev_matched", "dev_mismatched"]) + dev_ds_matched = dev_ds_matched.map(dev_trans_func, lazy=True) + dev_ds_mismatched = dev_ds_mismatched.map(dev_trans_func, lazy=True) + dev_batch_sampler_matched = paddle.io.BatchSampler( + dev_ds_matched, + batch_size=global_config['batch_size'], + shuffle=False) + dev_data_loader_matched = DataLoader( + dataset=dev_ds_matched, + batch_sampler=dev_batch_sampler_matched, + collate_fn=batchify_fn, + feed_list=feed_list, + num_workers=0, + 
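+            # return_list=False keeps feed-dict style batches (e.g. data[0]['x0']) as expected by the static-graph eval_function.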
return_list=False) + dev_batch_sampler_mismatched = paddle.io.BatchSampler( + dev_ds_mismatched, + batch_size=global_config['batch_size'], + shuffle=False) + dev_data_loader_mismatched = DataLoader( + dataset=dev_ds_mismatched, + batch_sampler=dev_batch_sampler_mismatched, + collate_fn=batchify_fn, + num_workers=0, + feed_list=feed_list, + return_list=False) + return train_data_loader, dev_data_loader_matched, dev_data_loader_mismatched + else: + dev_ds = load_dataset( + global_config['dataset'], global_config['task_name'], splits='dev') + dev_ds = dev_ds.map(dev_trans_func, lazy=True) + dev_batch_sampler = paddle.io.BatchSampler( + dev_ds, batch_size=global_config['batch_size'], shuffle=False) + dev_data_loader = DataLoader( + dataset=dev_ds, + batch_sampler=dev_batch_sampler, + collate_fn=dev_batchify_fn, + num_workers=0, + feed_list=feed_list, + return_list=False) + return train_data_loader, dev_data_loader + + +def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list): + metric.reset() + for data in eval_dataloader(): + logits = exe.run(compiled_test_program, + feed={ + test_feed_names[0]: data[0]['x0'], + test_feed_names[1]: data[0]['x1'], + test_feed_names[2]: data[0]['x2'] + }, + fetch_list=test_fetch_list) + paddle.disable_static() + if isinstance(metric, PearsonAndSpearman): + labels_pd = paddle.to_tensor(np.array(data[0]['label'])).reshape( + (-1, 1)) + logits_pd = paddle.to_tensor(logits[0]).reshape((-1, 1)) + metric.update((logits_pd, labels_pd)) + else: + labels_pd = paddle.to_tensor(np.array(data[0]['label']).flatten()) + logits_pd = paddle.to_tensor(logits[0]) + correct = metric.compute(logits_pd, labels_pd) + metric.update(correct) + paddle.enable_static() + res = metric.accumulate() + return res[0] if isinstance(res, list) or isinstance(res, tuple) else res + + +def apply_decay_param_fun(name): + if name.find("bias") > -1: + return True + elif name.find("b_0") > -1: + return True + elif name.find("norm") > -1: + return True + else: + return False + + +def main(): + all_config = load_slim_config(args.config_path) + + global global_config + assert "Global" in all_config, "Key Global not found in config file." + global_config = all_config["Global"] + + if 'TrainConfig' in all_config: + all_config['TrainConfig']['optimizer_builder'][ + 'apply_decay_param_fun'] = apply_decay_param_fun + + global train_dataloader, eval_dataloader + train_dataloader, eval_dataloader = reader() + + global metric + metric_class = METRIC_CLASSES[global_config['task_name']] + metric = metric_class() + + ac = AutoCompression( + model_dir=global_config['model_dir'], + model_filename=global_config['model_filename'], + params_filename=global_config['params_filename'], + save_dir=args.save_dir, + config=args.config_path, + train_dataloader=train_dataloader, + eval_callback=eval_function if + (len(list(all_config.keys())) == 2 and 'TrainConfig' in all_config) or + len(list(all_config.keys())) == 1 or + 'HyperParameterOptimization' not in all_config else eval_dataloader, + eval_dataloader=eval_dataloader) + + ac.compress() + + +if __name__ == '__main__': + paddle.enable_static() + parser = argsparser() + args = parser.parse_args() + print_arguments(args) + main() diff --git a/demo/auto_compression/pytorch-huggingface/run.sh b/demo/auto_compression/pytorch-huggingface/run.sh new file mode 100644 index 00000000..63958697 --- /dev/null +++ b/demo/auto_compression/pytorch-huggingface/run.sh @@ -0,0 +1 @@ +python run.py --config_path=./configs/cola.yaml --save_dir='./output/cola/' -- GitLab