Unverified commit 2a4270b0, authored by ceci3, committed by GitHub

[cherry pick] Add docs for ofa demo (#615)

* OFA demo for ernie (#493)
Parent 6181003b
# Compressing the PaddleNLP BERT Model with OFA
BERT-base is a general-purpose semantic representation model with strong transfer ability, but it also contains some redundant parameters. This tutorial shows how to use PaddleSlim to compress the BERT-base model in [PaddleNLP](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/).
## 1. Compression Results
The `bert-base-uncased` model is first fine-tuned on the GLUE datasets to obtain the model to be compressed, and compression then starts from that model. Compression reduces the parameter size by 26% (from 110M to 81M). The table below compares the accuracy of the compressed and uncompressed models on the GLUE dev sets:
| Task | Metric | Baseline | Result with PaddleSlim |
|:-----:|:----------------------------:|:-----------------:|:----------------------:|
| SST-2 | Accuracy | 0.93005 | 0.931193 |
| QNLI | Accuracy | 0.91781 | 0.920740 |
| CoLA  | Matthew's corr               | 0.59557           | 0.601244               |
| MRPC | F1/Accuracy | 0.91667/0.88235 | 0.91740/0.88480 |
| STS-B | Pearson/Spearman corr        | 0.88847/0.88350   | 0.89271/0.88958        |
| QQP | Accuracy/F1 | 0.90581/0.87347 | 0.90994/0.87947 |
| MNLI | Matched acc/MisMatched acc | 0.84422/0.84825 | 0.84687/0.85242 |
| RTE | Accuracy | 0.711191 | 0.718412 |
<p align="center">
<strong>Table 1-1: Accuracy comparison on the GLUE dev sets</strong>
</p>
The latency of the model before and after compression is shown in the following table:

| Device | Batch Size | Model | TRT(FP16) | Latency(ms) |
|:------:|:----------:|:---------------:|:---------:|:-----------:|
| T4 | 16 | BERT | N | 110.71 |
| T4 | 16 | BERT | Y | 22.0 |
| T4 | 16 | Compressed BERT | N | 69.62 |
| T4 | 16 | Compressed BERT | Y | 14.93 |
| V100 | 16 | BERT | N | 33.28 |
| V100 | 16 | Compressed BERT | N | 21.83 |
| Intel(R) Xeon(R) Gold 5117 CPU @ 2.00GHz | 16 | BERT | N | 10831.73 |
| Intel(R) Xeon(R) Gold 5117 CPU @ 2.00GHz | 16 | Compressed BERT | N | 7682.93 |

<p align="center">
<strong>Table 1-2: Latency comparison</strong>
</p>
On a T4 GPU, the compressed model is 59% faster than the original in FP32 (110.71 ms vs. 69.62 ms) and 47.3% faster with TensorRT FP16.
On a V100 GPU, it is 52.5% faster in FP32.
On the Intel(R) Xeon(R) Gold 5117 CPU, it is 41% faster in FP32.
## 2. Quick Start
This tutorial uses the GLUE/SST-2 dataset as an example.
### 2.1 Install PaddleNLP and Paddle
This tutorial compresses the BERT model from PaddleNLP and requires PaddleNLP 2.0beta or later and Paddle 2.0rc1 or later.
```shell
pip install paddlenlp
pip install "paddlepaddle_gpu>=2.0rc1"
```
### 2.2 Fine-tuning
First, fine-tune the pretrained model on the actual downstream task to obtain the model to be compressed. The procedure is described in the [Fine-tuning tutorial](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/examples/bert).
The dev-set results of fine-tuning are shown in the "Baseline" column of Table 1-1.
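For reference, a typical single-GPU run might look like the sketch below. This is only a sketch: the script name `run_glue.py` and its flags are assumptions based on the PaddleNLP BERT example (they mirror `run_glue_ofa.py` in the next section) and may differ across versions, so follow the linked tutorial for the authoritative interface.
```shell
# Hypothetical fine-tuning invocation modeled on the PaddleNLP BERT example;
# flag names mirror run_glue_ofa.py below and may differ in your version.
export TASK_NAME=SST-2
python -u ./run_glue.py --model_type bert \
    --model_name_or_path bert-base-uncased \
    --task_name $TASK_NAME --max_seq_length 128 \
    --batch_size 32 \
    --learning_rate 2e-5 \
    --num_train_epochs 3 \
    --logging_steps 10 \
    --save_steps 100 \
    --output_dir ./tmp/$TASK_NAME \
    --n_gpu 1
```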
### 2.3 Compression Training
```shell
python -u ./run_glue_ofa.py --model_type bert \
--model_name_or_path ${task_pretrained_model_dir} \
--task_name $TASK_NAME --max_seq_length 128 \
--batch_size 32 \
--learning_rate 2e-5 \
--num_train_epochs 6 \
--logging_steps 10 \
--save_steps 100 \
--output_dir ./tmp/$TASK_NAME \
--n_gpu 1 \
--width_mult_list 1.0 0.8333333333333334 0.6666666666666666 0.5
```
The parameters are as follows:
- `model_type`: the model type; currently only BERT is supported.
- `model_name_or_path`: a specific model configuration, together with its pretrained weights and the tokenizer used during pretraining. A local directory can also be given if the model files are stored locally.
- `task_name`: the fine-tuning task.
- `max_seq_length`: the maximum sentence length; longer inputs are truncated.
- `batch_size`: the number of samples per iteration **per card**.
- `learning_rate`: the base learning rate; it is multiplied by the value produced by the learning rate scheduler to obtain the current learning rate.
- `num_train_epochs`: the number of training epochs.
- `logging_steps`: the logging interval, in steps.
- `save_steps`: the model saving and evaluation interval, in steps.
- `output_dir`: the directory where models are saved.
- `n_gpu`: the number of GPUs to use. Set it to the desired count for multi-card training, or to 0 to train on CPU.
- `width_mult_list`: the set of width multipliers each Transformer block can choose from during compression training, as illustrated in the sketch after this list.
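To make `width_mult_list` concrete: OFA shrinks a Transformer layer's width by keeping only a fraction of it, and the patched ERNIE attention included later in this commit applies the sampled multiplier to the head count via `n_head = int(self.n_head * expand_ratio)`. The toy sketch below reproduces that rule; the function name and the 12-head figure (BERT-base) are illustrative only.
```python
def heads_at_width(n_head, width_mult):
    """Heads kept when a layer runs at `width_mult`, mirroring
    `n_head = int(self.n_head * expand_ratio)` in the patched attention."""
    return int(n_head * width_mult)

# BERT-base uses 12 heads per layer; the multipliers match --width_mult_list.
for w in [1.0, 0.8333333333333334, 0.6666666666666666, 0.5]:
    print(w, heads_at_width(12, w))  # keeps 12, 10, 8, 6 heads respectively
```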
The dev-set results after compression training are shown in the "Result with PaddleSlim" column of Table 1-1, and the latency in Table 1-2.
## 3. Introduction to the OFA API
TODO
# Compressing the ERNIE Model with OFA
ERNIE is Baidu's pioneering continual-learning framework for semantic understanding based on knowledge enhancement. It combines large-scale pretraining data with rich multi-source knowledge and, through continual learning, keeps absorbing lexical, structural, and semantic knowledge from massive text corpora, so the model keeps improving. This tutorial shows how to use PaddleSlim to compress the [ERNIE](https://github.com/PaddlePaddle/ERNIE) model.
With the compression algorithm in this tutorial, the original Tiny-ERNIE model can be accelerated by 40% with no loss of accuracy.
## 1. Quick Start
This tutorial uses the CLUE/XNLI dataset as an example.
### 1.1 Install Dependencies
Because the dygraph models in the ERNIE repo were developed against Paddle 1.8.5, this tutorial requires Paddle 1.8.5 and paddle-ernie 0.0.4.dev1.
```shell
pip install paddle-ernie==0.0.4.dev1
pip install paddlepaddle_gpu==1.8.5.post97
```
propeller is a high-level framework inside ERNIE that assists model training and bundles common NLP pre- and post-processing pipelines. You can import propeller by putting the root of the ERNIE repo on PYTHONPATH:
```shell
git clone https://github.com/PaddlePaddle/ERNIE
cd ERNIE
export PYTHONPATH=$PWD:$PYTHONPATH
```
### 1.2 Fine-tuning
First, fine-tune the pretrained model on the actual downstream task to obtain the model to be compressed. Follow the [Fine-tuning tutorial](https://github.com/PaddlePaddle/ERNIE/tree/v2.4.0#%E6%94%AF%E6%8C%81%E7%9A%84nlp%E4%BB%BB%E5%8A%A1) to obtain a Tiny-ERNIE model fine-tuned on the XNLI dataset.
### 1.3 Compression Training
```shell
python ./ofa_ernie.py \
--from_pretrained ernie-tiny \
--data_dir ./data/xnli \
--width_mult_list 1.0 0.75 0.5 0.25 \
--depth_mult_list 1.0 0.75
```
The parameters are as follows:
- `from_pretrained`: a specific model configuration, together with its pretrained weights and the tokenizer used during pretraining. A local directory can also be given if the model files are stored locally.
- `data_dir`: the directory where the data is stored.
- `width_mult_list`: the set of width multipliers each Transformer block can choose from during compression training.
- `depth_mult_list`: the set of depth multipliers, i.e. the choices for how many Transformer blocks the model keeps during compression training; see the sketch below.
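The depth multiplier determines which encoder layers survive: the patched `_ernie_block_stack_forward` later in this commit keeps `round(num_layers * depth_mult)` layers and picks their indices with `math.floor(i / depth_mult) - 1`. The sketch below reproduces that computation; the 12-layer figure is illustrative.
```python
import math

def kept_layers(num_layers, depth_mult):
    """Indices of the encoder blocks kept at `depth_mult`, mirroring the
    kept_layers_index loop in the patched _ernie_block_stack_forward."""
    depth = round(num_layers * depth_mult)
    return [math.floor(i / depth_mult) - 1 for i in range(1, depth + 1)]

print(kept_layers(12, 1.0))   # [0, 1, 2, ..., 11] -- all layers kept
print(kept_layers(12, 0.75))  # [0, 1, 3, 4, 5, 7, 8, 9, 11] -- drops 2, 6, 10
```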
## 2. Introduction to the OFA API
TODO
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
from ernie_supernet.modeling_ernie_supernet import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
import sys
import os
import math
import argparse
import json
import logging
from functools import partial
import six
import paddle.fluid.dygraph as D
import paddle.fluid as F
import paddle.fluid.layers as L
from ernie.file_utils import _fetch_from_remote
from ernie.modeling_ernie import AttentionLayer, ErnieBlock, ErnieModel, ErnieEncoderStack, ErnieModelForSequenceClassification
log = logging.getLogger(__name__)
def append_name(name, postfix):
if name is None:
return None
elif name == '':
return postfix
else:
return '%s_%s' % (name, postfix)
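# Patched attention forward: the same computation as ERNIE's original
# AttentionLayer.forward, except that (1) the number of heads shrinks with the
# sampled width multiplier ('expand_ratio') and (2) an optional head_mask can
# zero out individual heads.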
def _attn_forward(self,
queries,
keys,
values,
attn_bias,
past_cache,
head_mask=None):
assert len(queries.shape) == len(keys.shape) == len(values.shape) == 3
q = self.q(queries)
k = self.k(keys)
v = self.v(values)
cache = (k, v)
if past_cache is not None:
cached_k, cached_v = past_cache
k = L.concat([cached_k, k], 1)
v = L.concat([cached_v, v], 1)
    if hasattr(self.q, 'fn') and self.q.fn.cur_config['expand_ratio'] is not None:
n_head = int(self.n_head * self.q.fn.cur_config['expand_ratio'])
else:
n_head = self.n_head
q = L.transpose(
L.reshape(q, [0, 0, n_head, q.shape[-1] // n_head]),
[0, 2, 1, 3]) #[batch, head, seq, dim]
k = L.transpose(
L.reshape(k, [0, 0, n_head, k.shape[-1] // n_head]),
[0, 2, 1, 3]) #[batch, head, seq, dim]
v = L.transpose(
L.reshape(v, [0, 0, n_head, v.shape[-1] // n_head]),
[0, 2, 1, 3]) #[batch, head, seq, dim]
q = L.scale(q, scale=self.d_key**-0.5)
score = L.matmul(q, k, transpose_y=True)
if attn_bias is not None:
score += attn_bias
score = L.softmax(score, use_cudnn=True)
score = self.dropout(score)
if head_mask is not None:
score = score * head_mask
out = L.matmul(score, v)
out = L.transpose(out, [0, 2, 1, 3])
out = L.reshape(out, [0, 0, out.shape[2] * out.shape[3]])
out = self.o(out)
return out, cache
AttentionLayer.forward = _attn_forward
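# Patched encoder-stack forward: rather than running every block, it keeps
# round(num_layers * depth_mult) layers (indices in kept_layers_index) and
# threads the per-layer head_mask through each kept block.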
def _ernie_block_stack_forward(self,
inputs,
attn_bias=None,
past_cache=None,
num_layers=12,
depth_mult=1.,
head_mask=None):
if past_cache is not None:
assert isinstance(
past_cache, tuple
), 'unknown type of `past_cache`, expect tuple or list. got %s' % repr(
type(past_cache))
past_cache = list(zip(*past_cache))
else:
past_cache = [None] * len(self.block)
cache_list_k, cache_list_v, hidden_list = [], [], [inputs]
depth = round(num_layers * depth_mult)
kept_layers_index = []
for i in range(1, depth + 1):
kept_layers_index.append(math.floor(i / depth_mult) - 1)
for i in kept_layers_index:
b = self.block[i]
p = past_cache[i]
inputs, cache = b(inputs,
attn_bias=attn_bias,
past_cache=p,
head_mask=head_mask[i])
cache_k, cache_v = cache
cache_list_k.append(cache_k)
cache_list_v.append(cache_v)
hidden_list.append(inputs)
return inputs, hidden_list, (cache_list_k, cache_list_v)
ErnieEncoderStack.forward = _ernie_block_stack_forward
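# Patched transformer-block forward: the standard post-LN block (self-attention,
# residual + ln1, FFN, residual + ln2), extended to pass head_mask to attention.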
def _ernie_block_forward(self,
inputs,
attn_bias=None,
past_cache=None,
head_mask=None):
attn_out, cache = self.attn(
inputs,
inputs,
inputs,
attn_bias,
past_cache=past_cache,
head_mask=head_mask) #self attn
attn_out = self.dropout(attn_out)
hidden = attn_out + inputs
hidden = self.ln1(hidden) # dropout/ add/ norm
ffn_out = self.ffn(hidden)
ffn_out = self.dropout(ffn_out)
hidden = ffn_out + hidden
hidden = self.ln2(hidden)
return hidden, cache
ErnieBlock.forward = _ernie_block_forward
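# Patched model forward: builds embeddings and the attention bias as in the
# original ErnieModel, then forwards num_layers/depth/head_mask so the encoder
# stack can run at a reduced depth during OFA training.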
def _ernie_model_forward(self,
src_ids,
sent_ids=None,
pos_ids=None,
input_mask=None,
attn_bias=None,
past_cache=None,
use_causal_mask=False,
num_layers=12,
depth=1.,
head_mask=None):
    assert len(src_ids.shape) == 2, \
        'expect src_ids.shape = [batch, sequence], got %s' % repr(src_ids.shape)
assert attn_bias is not None if past_cache else True, 'if `past_cache` is specified; attn_bias should not be None'
d_batch = L.shape(src_ids)[0]
d_seqlen = L.shape(src_ids)[1]
if pos_ids is None:
pos_ids = L.reshape(L.range(0, d_seqlen, 1, dtype='int32'), [1, -1])
pos_ids = L.cast(pos_ids, 'int64')
if attn_bias is None:
if input_mask is None:
input_mask = L.cast(src_ids != 0, 'float32')
assert len(input_mask.shape) == 2
input_mask = L.unsqueeze(input_mask, axes=[-1])
attn_bias = L.matmul(input_mask, input_mask, transpose_y=True)
if use_causal_mask:
sequence = L.reshape(
L.range(
0, d_seqlen, 1, dtype='float32') + 1., [1, 1, -1, 1])
causal_mask = L.cast(
(L.matmul(
sequence, 1. / sequence, transpose_y=True) >= 1.),
'float32')
attn_bias *= causal_mask
else:
assert len(
attn_bias.shape
) == 3, 'expect attn_bias tobe rank 3, got %r' % attn_bias.shape
attn_bias = (1. - attn_bias) * -10000.0
attn_bias = L.unsqueeze(attn_bias, [1])
attn_bias.stop_gradient = True
if sent_ids is None:
sent_ids = L.zeros_like(src_ids)
if head_mask is not None:
if len(head_mask.shape) == 1:
head_mask = L.unsqueeze(
L.unsqueeze(L.unsqueeze(L.unsqueeze(head_mask, 0), 0), -1), -1)
head_mask = L.expand(
head_mask, expand_times=[num_layers, 1, 1, 1, 1])
elif len(head_mask.shape) == 2:
head_mask = L.unsqueeze(
L.unsqueeze(L.unsqueeze(head_mask, 1), -1), -1)
else:
head_mask = [None] * num_layers
src_embedded = self.word_emb(src_ids)
pos_embedded = self.pos_emb(pos_ids)
sent_embedded = self.sent_emb(sent_ids)
embedded = src_embedded + pos_embedded + sent_embedded
embedded = self.dropout(self.ln(embedded))
encoded, hidden_list, cache_list = self.encoder_stack(
embedded,
attn_bias,
past_cache=past_cache,
num_layers=num_layers,
depth_mult=depth,
head_mask=head_mask)
if self.pooler is not None:
pooled = self.pooler(encoded[:, 0, :])
else:
pooled = None
additional_info = {
'hiddens': hidden_list,
'caches': cache_list,
}
if self.return_additional_info:
return pooled, encoded, additional_info
else:
return pooled, encoded
ErnieModel.forward = _ernie_model_forward
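# Patched classification forward: returns additional_info (per-layer hiddens
# and caches) alongside loss and logits, e.g. for distillation during OFA
# training.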
def _sequence_forward(self, *args, **kwargs):
labels = kwargs.pop('labels', None)
pooled, encoded, additional_info = super(
ErnieModelForSequenceClassification, self).forward(*args, **kwargs)
hidden = self.dropout(pooled)
logits = self.classifier(hidden)
if labels is not None:
if len(labels.shape) == 1:
labels = L.reshape(labels, [-1, 1])
loss = L.softmax_with_cross_entropy(logits, labels)
loss = L.reduce_mean(loss)
else:
loss = None
return loss, logits, additional_info
ErnieModelForSequenceClassification.forward = _sequence_forward
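# Downloads the pretrained archive for a known model name and returns the
# parsed ernie_config.json as a dict.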
def get_config(pretrain_dir_or_url):
bce = 'https://ernie-github.cdn.bcebos.com/'
resource_map = {
'ernie-1.0': bce + 'model-ernie1.0.1.tar.gz',
'ernie-2.0-en': bce + 'model-ernie2.0-en.1.tar.gz',
'ernie-2.0-large-en': bce + 'model-ernie2.0-large-en.1.tar.gz',
'ernie-tiny': bce + 'model-ernie_tiny.1.tar.gz',
}
url = resource_map[pretrain_dir_or_url]
pretrain_dir = _fetch_from_remote(url, False)
config_path = os.path.join(pretrain_dir, 'ernie_config.json')
if not os.path.exists(config_path):
raise ValueError('config path not found: %s' % config_path)
    with open(config_path) as f:
        cfg_dict = json.load(f)
return cfg_dict
@@ -925,7 +925,6 @@ class SuperBatchNorm(fluid.dygraph.BatchNorm):
         mean_out = mean
         variance_out = variance
         attrs = ("momentum", self._momentum, "epsilon", self._epsilon,
                  "is_test", not self.training, "data_layout", self._data_layout,
                  "use_mkldnn", False, "fuse_with_relu", self._fuse_with_relu,
...