Unverified commit 032d5860, authored by ceci3 and committed by GitHub

OFA demo for ernie (#493)

* add ofa_ernie
* add ofa_bert
* add bert and update ernie
Parent e9b3a650
# OFA compression of the PaddleNLP BERT model
BERT-base is a general-purpose semantic representation model with strong transfer ability, but it also carries some parameter redundancy. This tutorial shows how to use PaddleSlim to compress the BERT-base model from [PaddleNLP](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/).
## 1. Compression results
The `bert-base-uncased` model is first fine-tuned on the GLUE datasets to obtain the model to be compressed; compression then starts from that model. Compression reduces the parameter size by 26% (from 110M to 81M). The accuracy of the compressed model on the GLUE dev sets, compared with that of the original model, is shown below:
| Task | Metric | Baseline | Result with PaddleSlim |
|:-----:|:----------------------------:|:-----------------:|:----------------------:|
| SST-2 | Accuracy | 0.93005 | 0.931193 |
| QNLI | Accuracy | 0.91781 | 0.920740 |
| CoLA  | Matthew's corr               | 0.59557           | 0.601244               |
| MRPC | F1/Accuracy | 0.91667/0.88235 | 0.91740/0.88480 |
| STS-B | Pearson/Spearman corr        | 0.88847/0.88350   | 0.89271/0.88958        |
| QQP | Accuracy/F1 | 0.90581/0.87347 | 0.90994/0.87947 |
| MNLI  | Matched acc/Mismatched acc   | 0.84422/0.84825   | 0.84687/0.85242        |
| RTE | Accuracy | 0.711191 | 0.718412 |
<p align="center">
<strong>Table 1-1: Accuracy comparison on the GLUE dev sets</strong>
</p>
The latency of the model before and after compression is shown below:
| Device | Batch Size | Model | TRT (FP16) | Latency (ms) |
|:----------------------------------------:|:----------:|:---------------:|:----------:|:------------:|
| T4                                       | 16         | BERT            | N          | 110.71       |
| T4                                       | 16         | BERT            | Y          | 22.0         |
| T4                                       | 16         | Compressed BERT | N          | 69.62        |
| T4                                       | 16         | Compressed BERT | Y          | 14.93        |
| V100                                     | 16         | BERT            | N          | 33.28        |
| V100                                     | 16         | Compressed BERT | N          | 21.83        |
| Intel(R) Xeon(R) Gold 5117 CPU @ 2.00GHz | 16         | BERT            | N          | 10831.73     |
| Intel(R) Xeon(R) Gold 5117 CPU @ 2.00GHz | 16         | Compressed BERT | N          | 7682.93      |
<p align="center">
<strong>Table 1-2: Latency comparison</strong>
</p>
On a T4, the compressed model is 59% faster than the original model in FP32 (110.71 ms / 69.62 ms ≈ 1.59×) and 47.3% faster with TensorRT FP16 (22.0 ms / 14.93 ms ≈ 1.47×).
On a V100, the compressed model is 52.5% faster than the original model in FP32.
On the Intel(R) Xeon(R) Gold 5117 CPU, the compressed model is 41% faster than the original model in FP32.
## 2. Quick start
This tutorial uses the GLUE/SST-2 dataset as an example.
### 2.1 Install PaddleNLP and Paddle
This tutorial compresses the BERT model from PaddleNLP; it requires PaddleNLP 2.0-beta or later and Paddle 2.0-rc1 or later.
```shell
pip install paddlenlp
pip install "paddlepaddle_gpu>=2.0rc1"
```
### 2.2 Fine-tuning
First, fine-tune the pretrained model on the actual downstream task to obtain the model to be compressed. For the fine-tuning procedure, see the [fine-tuning tutorial](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/examples/bert).
The fine-tuning results on the dev sets are shown in the "Baseline" column of Table 1-1.
### 2.3 Compression training
```shell
python -u ./run_glue_ofa.py --model_type bert \
--model_name_or_path ${task_pretrained_model_dir} \
--task_name $TASK_NAME --max_seq_length 128 \
--batch_size 32 \
--learning_rate 2e-5 \
--num_train_epochs 6 \
--logging_steps 10 \
--save_steps 100 \
--output_dir ./tmp/$TASK_NAME \
--n_gpu 1 \
--width_mult_list 1.0 0.8333333333333334 0.6666666666666666 0.5
```
The parameters are explained below:
- `model_type`: the model type; currently only BERT is supported.
- `model_name_or_path`: a model of a specific configuration, together with its pretrained weights and the tokenizer used during pre-training. If the model files are stored locally, a directory path can be given here instead.
- `task_name`: the fine-tuning task.
- `max_seq_length`: the maximum sentence length; longer inputs are truncated.
- `batch_size`: the number of samples **per card** per iteration.
- `learning_rate`: the base learning rate; it is multiplied by the value produced by the learning rate scheduler to obtain the current learning rate.
- `num_train_epochs`: the number of training epochs.
- `logging_steps`: the logging interval, in steps.
- `save_steps`: the interval, in steps, at which the model is saved and evaluated.
- `output_dir`: the directory where the model is saved.
- `n_gpu`: the number of GPU cards to use. Set it to the desired number for multi-card training; 0 means CPU.
- `width_mult_list`: the set of width choices for each Transformer block during compression training (see the sketch below).
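The following sketch is illustrative only and is not part of the demo scripts; the helper `sub_block_shape` is hypothetical, and the `bert-base` defaults of 12 attention heads and 3072 FFN hidden units per Transformer block are assumed. It shows roughly what each value in `width_mult_list` means, while the actual pruning is handled by PaddleSlim's OFA:
```python
# Illustrative sketch: a width multiplier keeps that fraction of the
# attention heads and FFN hidden units of each Transformer block.
def sub_block_shape(width_mult, n_heads=12, ffn_dim=3072):
    """Return (attention heads kept, FFN units kept) at a width multiplier."""
    return round(n_heads * width_mult), round(ffn_dim * width_mult)

for w in [1.0, 0.8333333333333334, 0.6666666666666666, 0.5]:
    print(w, sub_block_shape(w))
# -> (12, 3072), (10, 2560), (8, 2048), (6, 1536)
```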
The dev results after compression training are shown in the "Result with PaddleSlim" column of Table 1-1; the latency is shown in Table 1-2.
## 3. OFA API introduction
TODO
# OFA compression of the ERNIE model
ERNIE is Baidu's knowledge-enhanced framework for continual semantic understanding. It combines large-scale pre-training data with rich multi-source knowledge and, through continual learning, keeps absorbing lexical, structural, and semantic knowledge from massive text corpora, so the model keeps improving. This tutorial shows how to use PaddleSlim to compress an [ERNIE](https://github.com/PaddlePaddle/ERNIE) model.
With the compression algorithm in this tutorial, the original ERNIE-tiny model can be accelerated by 40% with no loss of accuracy.
## 1. Quick start
This tutorial uses the CLUE/XNLI dataset as an example.
### 1.1 Install dependencies
The dygraph models in the ERNIE repo are developed against Paddle 1.8.5, so this tutorial requires Paddle 1.8.5 and paddle-ernie 0.0.4.dev1.
```shell
pip install paddle-ernie==0.0.4.dev1
pip install paddlepaddle_gpu==1.8.5.post97
```
propeller is a high-level framework in the ERNIE repo that assists model training and bundles common NLP pre- and post-processing pipelines. You can make propeller importable by adding the root of the ERNIE repo to PYTHONPATH:
```shell
git clone https://github.com/PaddlePaddle/ERNIE
cd ERNIE
export PYTHONPATH=$PWD:$PYTHONPATH
```
### 1.2 Fine-tuning
First, fine-tune the pretrained model on the actual downstream task to obtain the model to be compressed. Follow the [fine-tuning tutorial](https://github.com/PaddlePaddle/ERNIE/tree/v2.4.0#%E6%94%AF%E6%8C%81%E7%9A%84nlp%E4%BB%BB%E5%8A%A1) to obtain an ERNIE-tiny model fine-tuned on the XNLI dataset.
### 1.3 Compression training
```shell
python ./ofa_ernie.py \
--from_pretrained ernie-tiny \
--data_dir ./data/xnli \
--width_mult_list 1.0 0.75 0.5 0.25 \
--depth_mult_list 1.0 0.75
```
The parameters are explained below:
- `from_pretrained`: a model of a specific configuration, together with its pretrained weights and the tokenizer used during pre-training. If the model files are stored locally, a directory path can be given here instead.
- `data_dir`: the directory where the data is stored.
- `width_mult_list`: the set of width choices for each Transformer block during compression training.
- `depth_mult_list`: the set of choices for the number of Transformer blocks during compression training (see the sketch below).
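As a sketch of elastic depth, the snippet below mirrors the `kept_layers_index` computation used by the supernet forward pass in `ernie_supernet/modeling_ernie_supernet.py`; the helper name `kept_layers` is ours. A depth multiplier decides how many Transformer blocks run and which evenly spaced block indices are kept:
```python
import math

# Mirror of the kept_layers_index logic from the supernet forward pass:
# compute how many blocks to keep, then pick evenly spaced block indices.
def kept_layers(num_layers, depth_mult):
    depth = round(num_layers * depth_mult)
    return [math.floor(i / depth_mult) - 1 for i in range(1, depth + 1)]

print(kept_layers(12, 1.0))   # all 12 blocks: [0, 1, ..., 11]
print(kept_layers(12, 0.75))  # 9 blocks kept: [0, 1, 3, 4, 5, 7, 8, 9, 11]
```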
## 2. OFA API introduction
TODO
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
from ernie_supernet.modeling_ernie_supernet import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import paddle
import paddle.fluid as F
import paddle.fluid.dygraph as FD
import paddle.fluid.layers as L
def compute_neuron_head_importance(args, model, tokenizer, dev_ds, place,
model_cfg):
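    """Estimate the importance of every attention head and FFN neuron on the
    dev set: head importance accumulates |d(loss)/d(head_mask)|, and neuron
    importance accumulates |weight * gradient| of the FFN projections."""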
n_layers, n_heads = model_cfg['num_hidden_layers'], model_cfg[
'num_attention_heads']
head_importance = L.zeros(shape=[n_layers, n_heads], dtype='float32')
head_mask = L.ones(shape=[n_layers, n_heads], dtype='float32')
head_mask.stop_gradient = False
intermediate_weight = []
intermediate_bias = []
output_weight = []
for name, w in model.named_parameters():
if 'ffn.i' in name:
if len(w.shape) > 1:
intermediate_weight.append(w)
else:
intermediate_bias.append(w)
if 'ffn.o' in name:
if len(w.shape) > 1:
output_weight.append(w)
neuron_importance = []
for w in intermediate_weight:
neuron_importance.append(np.zeros(shape=[w.shape[1]], dtype='float32'))
eval_task_names = ('mnli', 'mnli-mm') if args.task == 'mnli' else (
args.task, )
for eval_task in eval_task_names:
for batch in dev_ds.start(place):
ids, sids, label = batch
loss, _, _ = model(
ids,
sids,
labels=label,
head_mask=head_mask,
num_layers=model_cfg['num_hidden_layers'])
loss.backward()
head_importance += L.abs(FD.to_variable(head_mask.gradient()))
for w1, b1, w2, current_importance in zip(
intermediate_weight, intermediate_bias, output_weight,
neuron_importance):
current_importance += np.abs(
(np.sum(w1.numpy() * w1.gradient(), axis=0) + b1.numpy() *
b1.gradient()))
current_importance += np.abs(
np.sum(w2.numpy() * w2.gradient(), axis=1))
return head_importance, neuron_importance
def reorder_neuron_head(model, head_importance, neuron_importance):
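    """Physically reorder attention heads and FFN neurons in every block so
    that the most important units come first and survive width pruning."""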
# reorder heads and ffn neurons
for layer, current_importance in enumerate(neuron_importance):
# reorder heads
idx = L.argsort(head_importance[layer], descending=True)[-1]
#model.encoder_stack.block[layer].attn.reorder_heads(idx)
reorder_head(model.encoder_stack.block[layer].attn, idx)
# reorder neurons
idx = L.argsort(FD.to_variable(current_importance), descending=True)[-1]
#model.encoder_stack.block[layer].ffn.reorder_neurons(idx)
reorder_neuron(model.encoder_stack.block[layer].ffn, idx)
def reorder_head(layer, idx):
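    """Permute the head dimension of one attention layer's q/k/v/o projections
    according to the head order `idx`."""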
n, a = layer.n_head, layer.d_key
index = L.reshape(
L.index_select(
L.reshape(
L.arange(
0, n * a, dtype='int64'), shape=[n, a]),
idx,
dim=0),
shape=[-1])
def reorder_head_matrix(linearLayer, index, dim=1):
W = L.index_select(linearLayer.weight, index, dim=dim).detach()
if linearLayer.bias is not None:
if dim == 0:
b = L.assign(linearLayer.bias).detach()
else:
b = L.assign(L.index_select(
linearLayer.bias, index, dim=0)).detach()
linearLayer.weight.stop_gradient = True
linearLayer.weight.set_value(W)
linearLayer.weight.stop_gradient = False
if linearLayer.bias is not None:
linearLayer.bias.stop_gradient = True
linearLayer.bias.set_value(b)
linearLayer.bias.stop_gradient = False
reorder_head_matrix(
layer.q.fn if hasattr(layer.q, 'fn') else layer.q, index)
reorder_head_matrix(
layer.k.fn if hasattr(layer.k, 'fn') else layer.k, index)
reorder_head_matrix(
layer.v.fn if hasattr(layer.v, 'fn') else layer.v, index)
reorder_head_matrix(
layer.o.fn if hasattr(layer.o, 'fn') else layer.o, index, dim=0)
def reorder_neuron(layer, index, dim=0):
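    """Permute the FFN hidden units: the input projection's columns and the
    output projection's rows follow the neuron order `index`."""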
def reorder_neurons_matrix(linearLayer, index, dim):
W = L.index_select(linearLayer.weight, index, dim=dim).detach()
if linearLayer.bias is not None:
if dim == 0:
b = L.assign(linearLayer.bias).detach()
else:
b = L.assign(L.index_select(
linearLayer.bias, index, dim=0)).detach()
linearLayer.weight.stop_gradient = True
linearLayer.weight.set_value(W)
linearLayer.weight.stop_gradient = False
if linearLayer.bias is not None:
linearLayer.bias.stop_gradient = True
linearLayer.bias.set_value(b)
linearLayer.bias.stop_gradient = False
reorder_neurons_matrix(
layer.i.fn if hasattr(layer.i, 'fn') else layer.i, index, dim=1)
reorder_neurons_matrix(
layer.o.fn if hasattr(layer.o, 'fn') else layer.o, index, dim=0)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
import sys
import os
import math
import argparse
import json
import logging
from functools import partial
import six
import paddle.fluid.dygraph as D
import paddle.fluid as F
import paddle.fluid.layers as L
from ernie.file_utils import _fetch_from_remote
from ernie.modeling_ernie import AttentionLayer, ErnieBlock, ErnieModel, ErnieEncoderStack, ErnieModelForSequenceClassification
log = logging.getLogger(__name__)
def append_name(name, postfix):
if name is None:
return None
elif name == '':
return postfix
else:
return '%s_%s' % (name, postfix)
def _attn_forward(self,
queries,
keys,
values,
attn_bias,
past_cache,
head_mask=None):
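    """Attention forward patched for the supernet: honors an optional per-head
    mask and an elastic number of heads taken from the layer's current
    expand_ratio."""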
assert len(queries.shape) == len(keys.shape) == len(values.shape) == 3
q = self.q(queries)
k = self.k(keys)
v = self.v(values)
cache = (k, v)
if past_cache is not None:
cached_k, cached_v = past_cache
k = L.concat([cached_k, k], 1)
v = L.concat([cached_v, v], 1)
    if hasattr(self.q, 'fn') and self.q.fn.cur_config['expand_ratio'] is not None:
n_head = int(self.n_head * self.q.fn.cur_config['expand_ratio'])
else:
n_head = self.n_head
q = L.transpose(
L.reshape(q, [0, 0, n_head, q.shape[-1] // n_head]),
[0, 2, 1, 3]) #[batch, head, seq, dim]
k = L.transpose(
L.reshape(k, [0, 0, n_head, k.shape[-1] // n_head]),
[0, 2, 1, 3]) #[batch, head, seq, dim]
v = L.transpose(
L.reshape(v, [0, 0, n_head, v.shape[-1] // n_head]),
[0, 2, 1, 3]) #[batch, head, seq, dim]
q = L.scale(q, scale=self.d_key**-0.5)
score = L.matmul(q, k, transpose_y=True)
if attn_bias is not None:
score += attn_bias
score = L.softmax(score, use_cudnn=True)
score = self.dropout(score)
if head_mask is not None:
score = score * head_mask
out = L.matmul(score, v)
out = L.transpose(out, [0, 2, 1, 3])
out = L.reshape(out, [0, 0, out.shape[2] * out.shape[3]])
out = self.o(out)
return out, cache
AttentionLayer.forward = _attn_forward
def _ernie_block_stack_forward(self,
inputs,
attn_bias=None,
past_cache=None,
num_layers=12,
depth_mult=1.,
head_mask=None):
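    """Encoder-stack forward patched for elastic depth: only the blocks whose
    indices appear in kept_layers_index (derived from depth_mult) run."""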
if past_cache is not None:
assert isinstance(
past_cache, tuple
), 'unknown type of `past_cache`, expect tuple or list. got %s' % repr(
type(past_cache))
past_cache = list(zip(*past_cache))
else:
past_cache = [None] * len(self.block)
cache_list_k, cache_list_v, hidden_list = [], [], [inputs]
depth = round(num_layers * depth_mult)
kept_layers_index = []
for i in range(1, depth + 1):
kept_layers_index.append(math.floor(i / depth_mult) - 1)
for i in kept_layers_index:
b = self.block[i]
p = past_cache[i]
inputs, cache = b(inputs,
attn_bias=attn_bias,
past_cache=p,
head_mask=head_mask[i])
cache_k, cache_v = cache
cache_list_k.append(cache_k)
cache_list_v.append(cache_v)
hidden_list.append(inputs)
return inputs, hidden_list, (cache_list_k, cache_list_v)
ErnieEncoderStack.forward = _ernie_block_stack_forward
def _ernie_block_forward(self,
inputs,
attn_bias=None,
past_cache=None,
head_mask=None):
attn_out, cache = self.attn(
inputs,
inputs,
inputs,
attn_bias,
past_cache=past_cache,
head_mask=head_mask) #self attn
attn_out = self.dropout(attn_out)
hidden = attn_out + inputs
hidden = self.ln1(hidden) # dropout/ add/ norm
ffn_out = self.ffn(hidden)
ffn_out = self.dropout(ffn_out)
hidden = ffn_out + hidden
hidden = self.ln2(hidden)
return hidden, cache
ErnieBlock.forward = _ernie_block_forward
def _ernie_model_forward(self,
src_ids,
sent_ids=None,
pos_ids=None,
input_mask=None,
attn_bias=None,
past_cache=None,
use_causal_mask=False,
num_layers=12,
depth=1.,
head_mask=None):
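    """ErnieModel forward patched for the supernet: builds the attention bias
    as before, then threads num_layers, the elastic depth, and per-layer head
    masks through the encoder stack."""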
    assert len(src_ids.shape
               ) == 2, 'expect src_ids.shape = [batch, sequence], got %s' % (
                   repr(src_ids.shape))
    assert attn_bias is not None if past_cache else True, 'if `past_cache` is specified, attn_bias should not be None'
d_batch = L.shape(src_ids)[0]
d_seqlen = L.shape(src_ids)[1]
if pos_ids is None:
pos_ids = L.reshape(L.range(0, d_seqlen, 1, dtype='int32'), [1, -1])
pos_ids = L.cast(pos_ids, 'int64')
if attn_bias is None:
if input_mask is None:
input_mask = L.cast(src_ids != 0, 'float32')
assert len(input_mask.shape) == 2
input_mask = L.unsqueeze(input_mask, axes=[-1])
attn_bias = L.matmul(input_mask, input_mask, transpose_y=True)
if use_causal_mask:
sequence = L.reshape(
L.range(
0, d_seqlen, 1, dtype='float32') + 1., [1, 1, -1, 1])
causal_mask = L.cast(
(L.matmul(
sequence, 1. / sequence, transpose_y=True) >= 1.),
'float32')
attn_bias *= causal_mask
else:
        assert len(
            attn_bias.shape
        ) == 3, 'expect attn_bias to be rank 3, got %r' % attn_bias.shape
attn_bias = (1. - attn_bias) * -10000.0
attn_bias = L.unsqueeze(attn_bias, [1])
attn_bias.stop_gradient = True
if sent_ids is None:
sent_ids = L.zeros_like(src_ids)
if head_mask is not None:
if len(head_mask.shape) == 1:
head_mask = L.unsqueeze(
L.unsqueeze(L.unsqueeze(L.unsqueeze(head_mask, 0), 0), -1), -1)
head_mask = L.expand(
head_mask, expand_times=[num_layers, 1, 1, 1, 1])
elif len(head_mask.shape) == 2:
head_mask = L.unsqueeze(
L.unsqueeze(L.unsqueeze(head_mask, 1), -1), -1)
else:
head_mask = [None] * num_layers
src_embedded = self.word_emb(src_ids)
pos_embedded = self.pos_emb(pos_ids)
sent_embedded = self.sent_emb(sent_ids)
embedded = src_embedded + pos_embedded + sent_embedded
embedded = self.dropout(self.ln(embedded))
encoded, hidden_list, cache_list = self.encoder_stack(
embedded,
attn_bias,
past_cache=past_cache,
num_layers=num_layers,
depth_mult=depth,
head_mask=head_mask)
if self.pooler is not None:
pooled = self.pooler(encoded[:, 0, :])
else:
pooled = None
additional_info = {
'hiddens': hidden_list,
'caches': cache_list,
}
if self.return_additional_info:
return pooled, encoded, additional_info
else:
return pooled, encoded
ErnieModel.forward = _ernie_model_forward
def _sequence_forward(self, *args, **kwargs):
labels = kwargs.pop('labels', None)
pooled, encoded, additional_info = super(
ErnieModelForSequenceClassification, self).forward(*args, **kwargs)
hidden = self.dropout(pooled)
logits = self.classifier(hidden)
if labels is not None:
if len(labels.shape) == 1:
labels = L.reshape(labels, [-1, 1])
loss = L.softmax_with_cross_entropy(logits, labels)
loss = L.reduce_mean(loss)
else:
loss = None
return loss, logits, additional_info
ErnieModelForSequenceClassification.forward = _sequence_forward
def get_config(pretrain_dir_or_url):
bce = 'https://ernie-github.cdn.bcebos.com/'
resource_map = {
'ernie-1.0': bce + 'model-ernie1.0.1.tar.gz',
'ernie-2.0-en': bce + 'model-ernie2.0-en.1.tar.gz',
'ernie-2.0-large-en': bce + 'model-ernie2.0-large-en.1.tar.gz',
'ernie-tiny': bce + 'model-ernie_tiny.1.tar.gz',
}
url = resource_map[pretrain_dir_or_url]
pretrain_dir = _fetch_from_remote(url, False)
config_path = os.path.join(pretrain_dir, 'ernie_config.json')
if not os.path.exists(config_path):
raise ValueError('config path not found: %s' % config_path)
    with open(config_path) as f:
        cfg_dict = json.load(f)
return cfg_dict
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import re

import paddle.fluid as F
import paddle.fluid.layers as L
import paddle.fluid.dygraph as D
class AdamW(F.optimizer.AdamOptimizer):
"""AdamW object for dygraph"""
def __init__(self, *args, **kwargs):
weight_decay = kwargs.pop('weight_decay', None)
var_name_to_exclude = kwargs.pop(
'var_name_to_exclude', '.*layer_norm_scale|.*layer_norm_bias|.*b_0')
super(AdamW, self).__init__(*args, **kwargs)
self.wd = weight_decay
self.pat = re.compile(var_name_to_exclude)
    def apply_optimize(self, loss, startup_program, params_grads):
        super(AdamW, self).apply_optimize(loss, startup_program, params_grads)
        for p, g in params_grads:
            if not self.pat.match(p.name):
                with D.no_grad():
                    # decoupled weight decay: p <- p * (1 - wd * lr)
                    L.assign(p * (1. - self.wd * self.current_step_lr()), p)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import time
import json
from random import random
from tqdm import tqdm
from functools import reduce, partial
import numpy as np
import math
import logging
import argparse
import paddle
import paddle.fluid as F
import paddle.fluid.dygraph as FD
import paddle.fluid.layers as L
from paddleslim.nas.ofa import OFA, RunConfig, DistillConfig, utils
from propeller import log
import propeller.paddle as propeller
from ernie.modeling_ernie import ErnieModelForSequenceClassification
from ernie.tokenizing_ernie import ErnieTokenizer, ErnieTinyTokenizer
from ernie.optimization import LinearDecay
from ernie_supernet.importance import compute_neuron_head_importance, reorder_neuron_head
from ernie_supernet.optimization import AdamW
from ernie_supernet.modeling_ernie_supernet import get_config
from paddleslim.nas.ofa.convert_super import Convert, supernet
def soft_cross_entropy(inp, target):
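    """Soft-label cross entropy between student logits (`inp`) and teacher
    logits (`target`); used as the logit distillation loss."""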
inp_likelihood = L.log_softmax(inp, axis=-1)
target_prob = L.softmax(target, axis=-1)
return -1. * L.mean(L.reduce_sum(inp_likelihood * target_prob, dim=-1))
### get certain config
def apply_config(model, width_mult, depth_mult):
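    """Build the OFA sub-network config: pick an expand_ratio (width) for each
    convertible block and set the depth multiplier for the whole stack;
    embedding blocks, the second-to-last block, and certain fixed sub-layers
    always keep expand_ratio 1.0."""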
new_config = dict()
    def fix_exp(idx):
        # every block contributes 6 convertible sub-layers; the ones at
        # offsets 3 and 5 (presumably the projections whose output width must
        # stay at the hidden size) keep a fixed expand_ratio of 1.0
        if (idx - 3) % 6 == 0 or (idx - 5) % 6 == 0:
            return True
        return False
for idx, (block_k, block_v) in enumerate(model.layers.items()):
if isinstance(block_v, dict) and len(block_v.keys()) != 0:
name, name_idx = block_k.split('_'), int(block_k.split('_')[1])
if fix_exp(name_idx) or 'emb' in block_k or idx == (
len(model.layers.items()) - 2):
block_v['expand_ratio'] = 1.0
else:
block_v['expand_ratio'] = width_mult
if block_k == 'depth':
block_v = depth_mult
new_config[block_k] = block_v
return new_config
if __name__ == '__main__':
parser = argparse.ArgumentParser('classify model with ERNIE')
parser.add_argument(
'--from_pretrained',
type=str,
required=True,
help='pretrained model directory or tag')
parser.add_argument(
'--max_seqlen',
type=int,
default=128,
        help='max sentence length, should not be greater than 512')
parser.add_argument('--bsz', type=int, default=32, help='batchsize')
parser.add_argument('--epoch', type=int, default=3, help='epoch')
parser.add_argument(
'--data_dir',
type=str,
required=True,
help='data directory includes train / develop data')
parser.add_argument('--task', type=str, default='mnli', help='task name')
parser.add_argument(
'--use_lr_decay',
action='store_true',
help='if set, learning rate will decay to zero at `max_steps`')
parser.add_argument(
'--warmup_proportion',
type=float,
default=0.1,
help='if use_lr_decay is set, '
'learning rate will raise to `lr` at `warmup_proportion` * `max_steps` and decay to 0. at `max_steps`'
)
parser.add_argument('--lr', type=float, default=5e-5, help='learning rate')
parser.add_argument(
'--inference_model_dir',
type=str,
default='ofa_ernie_inf',
help='inference model output directory')
parser.add_argument(
'--save_dir',
type=str,
default='ofa_ernie_save',
help='model output directory')
parser.add_argument(
'--max_steps',
type=int,
default=None,
help='max_train_steps, set this to EPOCH * NUM_SAMPLES / BATCH_SIZE')
parser.add_argument(
'--wd',
type=float,
default=0.01,
help='weight decay, aka L2 regularizer')
parser.add_argument(
'--width_lambda1',
type=float,
default=1.0,
help='scale for logit loss in elastic width')
parser.add_argument(
'--width_lambda2',
type=float,
default=0.1,
help='scale for rep loss in elastic width')
parser.add_argument(
'--depth_lambda1',
type=float,
default=1.0,
help='scale for logit loss in elastic depth')
parser.add_argument(
'--depth_lambda2',
type=float,
default=1.0,
help='scale for rep loss in elastic depth')
parser.add_argument(
'--reorder_weight',
action='store_false',
        help='weight reordering is enabled by default; pass this flag to disable it')
parser.add_argument(
'--init_checkpoint',
type=str,
default=None,
help='checkpoint to warm start from')
parser.add_argument(
'--width_mult_list',
nargs='+',
type=float,
        default=[1.0, 0.75, 0.5, 0.25],
        help="width multiplier choices during compression")
parser.add_argument(
'--depth_mult_list',
nargs='+',
type=float,
default=[1.0, 2 / 3],
help="depth mult in compress")
args = parser.parse_args()
if args.task == 'sts-b':
mode = 'regression'
else:
mode = 'classification'
tokenizer = ErnieTinyTokenizer.from_pretrained(args.from_pretrained)
feature_column = propeller.data.FeatureColumns([
propeller.data.TextColumn(
'seg_a',
unk_id=tokenizer.unk_id,
vocab_dict=tokenizer.vocab,
tokenizer=tokenizer.tokenize),
propeller.data.TextColumn(
'seg_b',
unk_id=tokenizer.unk_id,
vocab_dict=tokenizer.vocab,
tokenizer=tokenizer.tokenize),
propeller.data.LabelColumn(
'label',
vocab_dict={
b"contradictory": 0,
b"contradiction": 0,
b"entailment": 1,
b"neutral": 2,
}),
])
def map_fn(seg_a, seg_b, label):
seg_a, seg_b = tokenizer.truncate(seg_a, seg_b, seqlen=args.max_seqlen)
sentence, segments = tokenizer.build_for_ernie(seg_a, seg_b)
return sentence, segments, label
train_ds = feature_column.build_dataset('train', data_dir=os.path.join(args.data_dir, 'train'), shuffle=True, repeat=False, use_gz=False) \
.map(map_fn) \
.padded_batch(args.bsz, (0, 0, 0))
dev_ds = feature_column.build_dataset('dev', data_dir=os.path.join(args.data_dir, 'dev'), shuffle=False, repeat=False, use_gz=False) \
.map(map_fn) \
.padded_batch(args.bsz, (0, 0, 0))
shapes = ([-1, args.max_seqlen], [-1, args.max_seqlen], [-1])
types = ('int64', 'int64', 'int64')
train_ds.data_shapes = shapes
train_ds.data_types = types
dev_ds.data_shapes = shapes
dev_ds.data_types = types
place = F.CUDAPlace(0)
with FD.guard(place):
model = ErnieModelForSequenceClassification.from_pretrained(
args.from_pretrained, num_labels=3, name='')
setattr(model, 'return_additional_info', True)
origin_weights = {}
for name, param in model.named_parameters():
origin_weights[name] = param
sp_config = supernet(expand_ratio=args.width_mult_list)
model = Convert(sp_config).convert(model)
utils.set_state_dict(model, origin_weights)
del origin_weights
teacher_model = ErnieModelForSequenceClassification.from_pretrained(
args.from_pretrained, num_labels=3, name='teacher')
setattr(teacher_model, 'return_additional_info', True)
default_run_config = {
'n_epochs': [[4 * args.epoch], [6 * args.epoch]],
'init_learning_rate': [[args.lr], [args.lr]],
'elastic_depth': args.depth_mult_list,
'dynamic_batch_size': [[1, 1], [1, 1]]
}
run_config = RunConfig(**default_run_config)
model_cfg = get_config(args.from_pretrained)
default_distill_config = {'teacher_model': teacher_model}
distill_config = DistillConfig(**default_distill_config)
ofa_model = OFA(model,
run_config,
distill_config=distill_config,
elastic_order=['width', 'depth'])
        ### reorder weights before training; this assumes elastic width is trained first
if args.reorder_weight:
head_importance, neuron_importance = compute_neuron_head_importance(
args, ofa_model.model, tokenizer, dev_ds, place, model_cfg)
reorder_neuron_head(ofa_model.model, head_importance,
neuron_importance)
#################
if args.init_checkpoint is not None:
log.info('loading checkpoint from %s' % args.init_checkpoint)
sd, _ = FD.load_dygraph(args.init_checkpoint)
ofa_model.model.set_dict(sd)
g_clip = F.clip.GradientClipByGlobalNorm(1.0) #experimental
if args.use_lr_decay:
opt = AdamW(
learning_rate=LinearDecay(args.lr,
int(args.warmup_proportion *
args.max_steps), args.max_steps),
parameter_list=ofa_model.model.parameters(),
weight_decay=args.wd,
grad_clip=g_clip)
else:
opt = AdamW(
args.lr,
parameter_list=ofa_model.model.parameters(),
weight_decay=args.wd,
grad_clip=g_clip)
for epoch in range(max(run_config.n_epochs[-1])):
ofa_model.set_epoch(epoch)
if epoch <= int(max(run_config.n_epochs[0])):
ofa_model.set_task('width')
depth_mult_list = [1.0]
else:
ofa_model.set_task('depth')
depth_mult_list = run_config.elastic_depth
for step, d in enumerate(
tqdm(
train_ds.start(place), desc='training')):
ids, sids, label = d
accumulate_gradients = dict()
for param in opt._parameter_list:
accumulate_gradients[param.name] = 0.0
for depth_mult in depth_mult_list:
for width_mult in args.width_mult_list:
net_config = apply_config(
ofa_model, width_mult, depth_mult=depth_mult)
ofa_model.set_net_config(net_config)
student_output, teacher_output = ofa_model(
ids,
sids,
labels=label,
num_layers=model_cfg['num_hidden_layers'])
loss, student_logit, student_reps = student_output[
0], student_output[1], student_output[2]['hiddens']
teacher_logit, teacher_reps = teacher_output[
1], teacher_output[2]['hiddens']
if ofa_model.task == 'depth':
depth_mult = ofa_model.current_config['depth']
depth = round(model_cfg['num_hidden_layers'] *
depth_mult)
kept_layers_index = []
for i in range(1, depth + 1):
kept_layers_index.append(
math.floor(i / depth_mult) - 1)
if mode == 'classification':
logit_loss = soft_cross_entropy(
student_logit, teacher_logit.detach())
else:
logit_loss = 0.0
### hidden_states distillation loss
rep_loss = 0.0
for stu_rep, tea_rep in zip(
student_reps,
list(teacher_reps[i]
for i in kept_layers_index)):
tmp_loss = L.mse_loss(stu_rep, tea_rep.detach())
rep_loss += tmp_loss
loss = args.width_lambda1 * logit_loss + args.width_lambda2 * rep_loss
else:
### logit distillation loss
if mode == 'classification':
logit_loss = soft_cross_entropy(
student_logit, teacher_logit.detach())
else:
logit_loss = 0.0
### hidden_states distillation loss
rep_loss = 0.0
for stu_rep, tea_rep in zip(student_reps,
teacher_reps):
tmp_loss = L.mse_loss(stu_rep, tea_rep.detach())
rep_loss += tmp_loss
loss = args.width_lambda1 * logit_loss + args.width_lambda2 * rep_loss
if step % 10 == 0:
print('train loss %.5f lr %.3e' %
(loss.numpy(), opt.current_step_lr()))
loss.backward()
param_grads = opt.backward(loss)
for param in opt._parameter_list:
accumulate_gradients[param.name] += param.gradient()
for k, v in param_grads:
assert k.name in accumulate_gradients.keys(
), "{} not in accumulate_gradients".format(k.name)
v.set_value(accumulate_gradients[k.name])
opt.apply_optimize(
loss, startup_program=None, params_grads=param_grads)
ofa_model.model.clear_gradients()
if step % 100 == 0:
for depth_mult in depth_mult_list:
for width_mult in args.width_mult_list:
net_config = apply_config(
ofa_model, width_mult, depth_mult=depth_mult)
ofa_model.set_net_config(net_config)
acc = []
tea_acc = []
with FD.base._switch_tracer_mode_guard_(
is_train=False):
ofa_model.model.eval()
for step, d in enumerate(
tqdm(
dev_ds.start(place),
desc='evaluating %d' % epoch)):
ids, sids, label = d
[loss, logits,
_], [_, tea_logits, _] = ofa_model(
ids,
sids,
labels=label,
num_layers=model_cfg[
'num_hidden_layers'])
a = L.argmax(logits, -1) == label
acc.append(a.numpy())
ta = L.argmax(tea_logits, -1) == label
tea_acc.append(ta.numpy())
ofa_model.model.train()
print(
'width_mult: %f, depth_mult: %f: acc %.5f, teacher acc %.5f'
% (width_mult, depth_mult,
np.concatenate(acc).mean(),
np.concatenate(tea_acc).mean()))
if args.save_dir is not None:
if not os.path.exists(args.save_dir):
os.makedirs(args.save_dir)
F.save_dygraph(ofa_model.model.state_dict(), args.save_dir)
# PaddleNLP-BERT model compression tutorial
1. For the fine-tuned model, compute each parameter's importance as the product of the parameter and its gradient, then reorder the model parameters by importance (see the sketch after this list).
2. The largest sub-network of the supernet uses the same architecture as the BERT-base network. The smaller sub-networks are obtained by choosing different widths of the largest network, where choosing a width means pruning parameters; all sub-networks share parameters throughout training.
3. Use the reordered model parameters to initialize the supernet.
4. Use the fine-tuned model as the teacher network and the supernet as the student network, and perform knowledge distillation.
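As a toy illustration of step 1 (hypothetical values, not from the repo), attention heads can be scored by the absolute gradient of a per-head mask and then reordered so the most important heads come first:
```python
import numpy as np

np.random.seed(0)
head_grad = np.random.randn(12)   # toy gradient of a per-head mask
importance = np.abs(head_grad)    # head importance scores
order = np.argsort(-importance)   # head indices, most important first
print(order)  # this permutation is applied to the q/k/v/o projection weights
```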
<p align="center">
<img src="../../images/algo/ofa_bert.jpg" width="950"/><br />
Overall workflow
</p>