Unverified commit d9e9a7b6, authored by ceci3, committed by GitHub

[cherry pick] ofa docs and bug fix (#612)

* cherry pick

* fix bug when paddle upgrade (#606)

* add jpg
Parent e3226b49
This diff is collapsed.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import paddle
import paddle.fluid as F
import paddle.fluid.dygraph as FD
import paddle.fluid.layers as L
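# DynaBERT-style importance estimation for ERNIE: head importance is
# accumulated from the gradient of a trainable all-ones head mask on the dev
# set, and FFN neuron importance from |w * grad(w)| of the intermediate and
# output Linear weights (and biases).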
def compute_neuron_head_importance(args, model, dev_ds, place, model_cfg):
n_layers, n_heads = model_cfg['num_hidden_layers'], model_cfg[
'num_attention_heads']
head_importance = L.zeros(shape=[n_layers, n_heads], dtype='float32')
head_mask = L.ones(shape=[n_layers, n_heads], dtype='float32')
head_mask.stop_gradient = False
intermediate_weight = []
intermediate_bias = []
output_weight = []
for name, w in model.named_parameters():
if 'ffn.i' in name:
if len(w.shape) > 1:
intermediate_weight.append(w)
else:
intermediate_bias.append(w)
if 'ffn.o' in name:
if len(w.shape) > 1:
output_weight.append(w)
neuron_importance = []
for w in intermediate_weight:
neuron_importance.append(np.zeros(shape=[w.shape[1]], dtype='float32'))
eval_task_names = ('mnli', 'mnli-mm') if args.task == 'mnli' else (
args.task, )
for eval_task in eval_task_names:
for batch in dev_ds.start(place):
ids, sids, label = batch
loss, _, _ = model(
ids,
sids,
labels=label,
head_mask=head_mask,
num_layers=model_cfg['num_hidden_layers'])
loss.backward()
head_importance += L.abs(FD.to_variable(head_mask.gradient()))
for w1, b1, w2, current_importance in zip(
intermediate_weight, intermediate_bias, output_weight,
neuron_importance):
current_importance += np.abs(
(np.sum(w1.numpy() * w1.gradient(), axis=0) + b1.numpy() *
b1.gradient()))
current_importance += np.abs(
np.sum(w2.numpy() * w2.gradient(), axis=1))
return head_importance, neuron_importance
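# Permute attention heads and FFN neurons in place so that the most important
# ones come first; width pruning then keeps the leading heads/neurons.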
def reorder_neuron_head(model, head_importance, neuron_importance):
# reorder heads and ffn neurons
for layer, current_importance in enumerate(neuron_importance):
# reorder heads
idx = L.argsort(head_importance[layer], descending=True)[-1]
#model.encoder_stack.block[layer].attn.reorder_heads(idx)
reorder_head(model.encoder_stack.block[layer].attn, idx)
# reorder neurons
idx = L.argsort(FD.to_variable(current_importance), descending=True)[-1]
#model.encoder_stack.block[layer].ffn.reorder_neurons(idx)
reorder_neuron(model.encoder_stack.block[layer].ffn, idx)
def reorder_head(layer, idx):
n, a = layer.n_head, layer.d_key
index = L.reshape(
L.index_select(
L.reshape(
L.arange(
0, n * a, dtype='int64'), shape=[n, a]),
idx,
dim=0),
shape=[-1])
def reorder_head_matrix(linearLayer, index, dim=1):
W = L.index_select(linearLayer.weight, index, dim=dim).detach()
if linearLayer.bias is not None:
if dim == 0:
b = L.assign(linearLayer.bias).detach()
else:
b = L.assign(L.index_select(
linearLayer.bias, index, dim=0)).detach()
linearLayer.weight.stop_gradient = True
linearLayer.weight.set_value(W)
linearLayer.weight.stop_gradient = False
if linearLayer.bias is not None:
linearLayer.bias.stop_gradient = True
linearLayer.bias.set_value(b)
linearLayer.bias.stop_gradient = False
reorder_head_matrix(
layer.q.fn if hasattr(layer.q, 'fn') else layer.q, index)
reorder_head_matrix(
layer.k.fn if hasattr(layer.k, 'fn') else layer.k, index)
reorder_head_matrix(
layer.v.fn if hasattr(layer.v, 'fn') else layer.v, index)
reorder_head_matrix(
layer.o.fn if hasattr(layer.o, 'fn') else layer.o, index, dim=0)
def reorder_neuron(layer, index, dim=0):
def reorder_neurons_matrix(linearLayer, index, dim):
W = L.index_select(linearLayer.weight, index, dim=dim).detach()
if linearLayer.bias is not None:
if dim == 0:
b = L.assign(linearLayer.bias).detach()
else:
b = L.assign(L.index_select(
linearLayer.bias, index, dim=0)).detach()
linearLayer.weight.stop_gradient = True
linearLayer.weight.set_value(W)
linearLayer.weight.stop_gradient = False
if linearLayer.bias is not None:
linearLayer.bias.stop_gradient = True
linearLayer.bias.set_value(b)
linearLayer.bias.stop_gradient = False
reorder_neurons_matrix(
layer.i.fn if hasattr(layer.i, 'fn') else layer.i, index, dim=1)
reorder_neurons_matrix(
layer.o.fn if hasattr(layer.o, 'fn') else layer.o, index, dim=0)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
import re
import paddle.fluid as F
import paddle.fluid.layers as L
import paddle.fluid.dygraph as D
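# AdamW with decoupled weight decay for dygraph: after the regular Adam step,
# every parameter whose name does not match `var_name_to_exclude` (LayerNorm
# scales/biases and bias terms by default) is scaled by (1 - lr * weight_decay).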
class AdamW(F.optimizer.AdamOptimizer):
"""AdamW object for dygraph"""
def __init__(self, *args, **kwargs):
weight_decay = kwargs.pop('weight_decay', None)
var_name_to_exclude = kwargs.pop(
'var_name_to_exclude', '.*layer_norm_scale|.*layer_norm_bias|.*b_0')
super(AdamW, self).__init__(*args, **kwargs)
self.wd = weight_decay
self.pat = re.compile(var_name_to_exclude)
def apply_optimize(self, loss, startup_program, params_grads):
super(AdamW, self).apply_optimize(loss, startup_program, params_grads)
for p, g in params_grads:
if not self.pat.match(p.name):
with D.no_grad():
L.assign(p * (1. - self.wd * self.current_step_lr()), p)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import time
import json
from random import random
from tqdm import tqdm
from functools import reduce, partial
import numpy as np
import math
import logging
import argparse
import paddle
import paddle.fluid as F
import paddle.fluid.dygraph as FD
import paddle.fluid.layers as L
from paddleslim.nas.ofa import OFA, RunConfig, DistillConfig, utils
from propeller import log
import propeller.paddle as propeller
from ernie.modeling_ernie import ErnieModelForSequenceClassification
from ernie.tokenizing_ernie import ErnieTokenizer, ErnieTinyTokenizer
from ernie.optimization import LinearDecay
from ernie_supernet.importance import compute_neuron_head_importance, reorder_neuron_head
from ernie_supernet.optimization import AdamW
from ernie_supernet.modeling_ernie_supernet import get_config
from paddleslim.nas.ofa.convert_super import Convert, supernet
def soft_cross_entropy(inp, target):
inp_likelihood = L.log_softmax(inp, axis=-1)
target_prob = L.softmax(target, axis=-1)
return -1. * L.mean(L.reduce_sum(inp_likelihood * target_prob, dim=-1))
if __name__ == '__main__':
parser = argparse.ArgumentParser('classify model with ERNIE')
parser.add_argument(
'--from_pretrained',
type=str,
required=True,
help='pretrained model directory or tag')
parser.add_argument(
'--max_seqlen',
type=int,
default=128,
help='max sentence length, should not greater than 512')
parser.add_argument('--bsz', type=int, default=32, help='batchsize')
parser.add_argument('--epoch', type=int, default=3, help='epoch')
parser.add_argument(
'--data_dir',
type=str,
required=True,
help='data directory includes train / develop data')
parser.add_argument('--task', type=str, default='xnli', help='task name')
parser.add_argument(
'--use_lr_decay',
action='store_true',
help='if set, learning rate will decay to zero at `max_steps`')
parser.add_argument(
'--warmup_proportion',
type=float,
default=0.1,
help='if use_lr_decay is set, '
'learning rate will raise to `lr` at `warmup_proportion` * `max_steps` and decay to 0. at `max_steps`'
)
parser.add_argument('--lr', type=float, default=5e-5, help='learning rate')
parser.add_argument(
'--inference_model_dir',
type=str,
default='ofa_ernie_inf',
help='inference model output directory')
parser.add_argument(
'--save_dir',
type=str,
default='ofa_ernie_save',
help='model output directory')
parser.add_argument(
'--max_steps',
type=int,
default=None,
help='max_train_steps, set this to EPOCH * NUM_SAMPLES / BATCH_SIZE')
parser.add_argument(
'--wd',
type=float,
default=0.01,
help='weight decay, aka L2 regularizer')
parser.add_argument(
'--width_lambda1',
type=float,
default=1.0,
help='scale for logit loss in elastic width')
parser.add_argument(
'--width_lambda2',
type=float,
default=0.1,
help='scale for rep loss in elastic width')
parser.add_argument(
'--depth_lambda1',
type=float,
default=1.0,
help='scale for logit loss in elastic depth')
parser.add_argument(
'--depth_lambda2',
type=float,
default=1.0,
help='scale for rep loss in elastic depth')
parser.add_argument(
'--reorder_weight',
action='store_false',
help='Whether to reorder weight')
parser.add_argument(
'--init_checkpoint',
type=str,
default=None,
help='checkpoint to warm start from')
parser.add_argument(
'--width_mult_list',
nargs='+',
type=float,
default=[1.0, 0.75, 0.5, 0.25],
help="width mult in compress")
parser.add_argument(
'--depth_mult_list',
nargs='+',
type=float,
default=[1.0, 2 / 3],
help="depth mult in compress")
args = parser.parse_args()
if args.task == 'sts-b':
mode = 'regression'
else:
mode = 'classification'
tokenizer = ErnieTinyTokenizer.from_pretrained(args.from_pretrained)
feature_column = propeller.data.FeatureColumns([
propeller.data.TextColumn(
'seg_a',
unk_id=tokenizer.unk_id,
vocab_dict=tokenizer.vocab,
tokenizer=tokenizer.tokenize),
propeller.data.TextColumn(
'seg_b',
unk_id=tokenizer.unk_id,
vocab_dict=tokenizer.vocab,
tokenizer=tokenizer.tokenize),
propeller.data.LabelColumn(
'label',
vocab_dict={
b"contradictory": 0,
b"contradiction": 0,
b"entailment": 1,
b"neutral": 2,
}),
])
def map_fn(seg_a, seg_b, label):
seg_a, seg_b = tokenizer.truncate(seg_a, seg_b, seqlen=args.max_seqlen)
sentence, segments = tokenizer.build_for_ernie(seg_a, seg_b)
return sentence, segments, label
train_ds = feature_column.build_dataset('train', data_dir=os.path.join(args.data_dir, 'train'), shuffle=True, repeat=False, use_gz=False) \
.map(map_fn) \
.padded_batch(args.bsz, (0, 0, 0))
dev_ds = feature_column.build_dataset('dev', data_dir=os.path.join(args.data_dir, 'dev'), shuffle=False, repeat=False, use_gz=False) \
.map(map_fn) \
.padded_batch(args.bsz, (0, 0, 0))
shapes = ([-1, args.max_seqlen], [-1, args.max_seqlen], [-1])
types = ('int64', 'int64', 'int64')
train_ds.data_shapes = shapes
train_ds.data_types = types
dev_ds.data_shapes = shapes
dev_ds.data_types = types
place = F.CUDAPlace(0)
with FD.guard(place):
model = ErnieModelForSequenceClassification.from_pretrained(
args.from_pretrained, num_labels=3, name='')
setattr(model, 'return_additional_info', True)
origin_weights = {}
for name, param in model.named_parameters():
origin_weights[name] = param
sp_config = supernet(expand_ratio=args.width_mult_list)
model = Convert(sp_config).convert(model)
utils.set_state_dict(model, origin_weights)
del origin_weights
teacher_model = ErnieModelForSequenceClassification.from_pretrained(
args.from_pretrained, num_labels=3, name='teacher')
setattr(teacher_model, 'return_additional_info', True)
default_run_config = {
'n_epochs': [[4 * args.epoch], [6 * args.epoch]],
'init_learning_rate': [[args.lr], [args.lr]],
'elastic_depth': args.depth_mult_list,
'dynamic_batch_size': [[1, 1], [1, 1]]
}
run_config = RunConfig(**default_run_config)
model_cfg = get_config(args.from_pretrained)
default_distill_config = {'teacher_model': teacher_model}
distill_config = DistillConfig(**default_distill_config)
ofa_model = OFA(model,
run_config,
distill_config=distill_config,
elastic_order=['width', 'depth'])
### suppose elastic width first
if args.reorder_weight:
head_importance, neuron_importance = compute_neuron_head_importance(
args, ofa_model.model, dev_ds, place, model_cfg)
reorder_neuron_head(ofa_model.model, head_importance,
neuron_importance)
#################
if args.init_checkpoint is not None:
log.info('loading checkpoint from %s' % args.init_checkpoint)
sd, _ = FD.load_dygraph(args.init_checkpoint)
ofa_model.model.set_dict(sd)
g_clip = F.clip.GradientClipByGlobalNorm(1.0) #experimental
if args.use_lr_decay:
opt = AdamW(
learning_rate=LinearDecay(args.lr,
int(args.warmup_proportion *
args.max_steps), args.max_steps),
parameter_list=ofa_model.model.parameters(),
weight_decay=args.wd,
grad_clip=g_clip)
else:
opt = AdamW(
args.lr,
parameter_list=ofa_model.model.parameters(),
weight_decay=args.wd,
grad_clip=g_clip)
for epoch in range(max(run_config.n_epochs[-1])):
ofa_model.set_epoch(epoch)
if epoch <= int(max(run_config.n_epochs[0])):
ofa_model.set_task('width')
depth_mult_list = [1.0]
else:
ofa_model.set_task('depth')
depth_mult_list = run_config.elastic_depth
for step, d in enumerate(
tqdm(
train_ds.start(place), desc='training')):
ids, sids, label = d
accumulate_gradients = dict()
for param in opt._parameter_list:
accumulate_gradients[param.name] = 0.0
for depth_mult in depth_mult_list:
for width_mult in args.width_mult_list:
net_config = utils.dynabert_config(
ofa_model, width_mult, depth_mult=depth_mult)
ofa_model.set_net_config(net_config)
student_output, teacher_output = ofa_model(
ids,
sids,
labels=label,
num_layers=model_cfg['num_hidden_layers'])
loss, student_logit, student_reps = student_output[
0], student_output[1], student_output[2]['hiddens']
teacher_logit, teacher_reps = teacher_output[
1], teacher_output[2]['hiddens']
if ofa_model.task == 'depth':
depth_mult = ofa_model.current_config['depth']
depth = round(model_cfg['num_hidden_layers'] *
depth_mult)
kept_layers_index = []
for i in range(1, depth + 1):
kept_layers_index.append(
math.floor(i / depth_mult) - 1)
if mode == 'classification':
logit_loss = soft_cross_entropy(
student_logit, teacher_logit.detach())
else:
logit_loss = 0.0
### hidden_states distillation loss
rep_loss = 0.0
for stu_rep, tea_rep in zip(
student_reps,
list(teacher_reps[i]
for i in kept_layers_index)):
tmp_loss = L.mse_loss(stu_rep, tea_rep.detach())
rep_loss += tmp_loss
loss = args.width_lambda1 * logit_loss + args.width_lambda2 * rep_loss
else:
### logit distillation loss
if mode == 'classification':
logit_loss = soft_cross_entropy(
student_logit, teacher_logit.detach())
else:
logit_loss = 0.0
### hidden_states distillation loss
rep_loss = 0.0
for stu_rep, tea_rep in zip(student_reps,
teacher_reps):
tmp_loss = L.mse_loss(stu_rep, tea_rep.detach())
rep_loss += tmp_loss
loss = args.width_lambda1 * logit_loss + args.width_lambda2 * rep_loss
if step % 10 == 0:
print('train loss %.5f lr %.3e' %
(loss.numpy(), opt.current_step_lr()))
loss.backward()
param_grads = opt.backward(loss)
for param in opt._parameter_list:
accumulate_gradients[param.name] += param.gradient()
for k, v in param_grads:
assert k.name in accumulate_gradients.keys(
), "{} not in accumulate_gradients".format(k.name)
v.set_value(accumulate_gradients[k.name])
opt.apply_optimize(
loss, startup_program=None, params_grads=param_grads)
ofa_model.model.clear_gradients()
if step % 100 == 0:
for depth_mult in depth_mult_list:
for width_mult in args.width_mult_list:
net_config = utils.dynabert_config(
ofa_model, width_mult, depth_mult=depth_mult)
ofa_model.set_net_config(net_config)
acc = []
tea_acc = []
with FD.base._switch_tracer_mode_guard_(
is_train=False):
ofa_model.model.eval()
for step, d in enumerate(
tqdm(
dev_ds.start(place),
desc='evaluating %d' % epoch)):
ids, sids, label = d
[loss, logits,
_], [_, tea_logits, _] = ofa_model(
ids,
sids,
labels=label,
num_layers=model_cfg[
'num_hidden_layers'])
a = L.argmax(logits, -1) == label
acc.append(a.numpy())
ta = L.argmax(tea_logits, -1) == label
tea_acc.append(ta.numpy())
ofa_model.model.train()
print(
'width_mult: %f, depth_mult: %f: acc %.5f, teacher acc %.5f'
% (width_mult, depth_mult,
np.concatenate(acc).mean(),
np.concatenate(tea_acc).mean()))
if args.save_dir is not None:
if not os.path.exists(args.save_dir):
os.makedirs(args.save_dir)
F.save_dygraph(ofa_model.model.state_dict(), args.save_dir)
Convert SuperNet
================
Before Once-For-All training, an ordinary model must first be converted into a supernet built from dynamic OPs. While converting the ordinary network into a supernet, the conversion also turns the largest sub-network of the supernet into the largest network in the search space.
.. note::
  - If the kernel_size of an original convolution is 1, its kernel_size is left unchanged.
..
API Description
------------------
.. py:class:: paddleslim.nas.ofa.supernet(kernel_size=None, expand_ratio=None, channel=None)
`源代码 <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/nas/ofa/convert_super.py#L643>`_
Pass in the search space as key-value pairs.
**Parameters:**
- **kernel_size(list|tuple, optional):** search space for the kernel_size of the Conv2D layers in the network.
- **expand_ratio(list|tuple, optional):** search space for the number of channels of Conv2D and for the output dimensions of Embedding and Linear. The channel number of each OP in the converted supernet is obtained by scaling the channels of the corresponding OP in the original model by these ratios, so only a single list of ratios is needed. Set either this parameter or ``channel``, not both. (A short sketch contrasting the two options follows below.)
- **channel(list(list)|tuple(tuple), optional):** search space for the number of channels of Conv2D and for the output dimensions of Embedding and Linear. This parameter sets the channel number of each OP in the supernet directly, so its length must equal the total number of Conv2D, Embedding and Linear layers in the network. Set either this parameter or ``expand_ratio``, not both.
**Returns:**
The supernet configuration.
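For intuition, here is a minimal sketch contrasting the two options (the concrete values below are invented for a network with two prunable layers):

.. code-block:: python

    from paddleslim.nas.ofa.convert_super import supernet

    # every Conv2D/Embedding/Linear keeps 50% or 100% of its original width
    config_a = supernet(expand_ratio=[0.5, 1.0])
    # or list the candidate widths explicitly, one tuple per prunable layer
    config_b = supernet(channel=((4, 8), (16, 32)))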
.. py:class:: paddleslim.nas.ofa.Convert(context)
`源代码 <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/nas/ofa/convert_super.py#L45>`_
Convert an ordinary network into a supernet according to the user-defined search space.
**Returns:**
A Convert instance.
**Parameters:**
- **context(paddleslim.nas.ofa.supernet):** the search space defined by the user.
.. py:method:: convert(network)
Perform the actual supernet conversion.
**Parameters:**
- **network(paddle.nn.Layer):** the instance of the original model to be converted into a supernet.
**Returns:**
The converted supernet instance.
PaddleSlim provides three ways to construct a supernet, described below.
Option 1
------------------
Call the search-space definition API and the supernet conversion API directly. The advantage of this approach is that the network does not need to be redefined, because an already initialized network instance is converted directly; the drawback is that only the whole network can be converted, not just part of it.
**Example:**
.. code-block:: python
    from paddle.vision.models import mobilenet_v1
    from paddleslim.nas.ofa.convert_super import Convert, supernet

    model = mobilenet_v1()
    sp_net_config = supernet(kernel_size=(3, 5, 7), expand_ratio=[1, 2, 4])
    sp_model = Convert(sp_net_config).convert(model)
Option 2
------------------
Convert the supernet with a context manager. The advantage is that only part of the network needs to be converted, or different parts of the network can be converted with different search spaces; the drawback is that the original network definition must be available and modified.
**Example:**
.. code-block:: python
    import paddle
    import paddle.nn as nn
    from paddleslim.nas.ofa.convert_super import supernet

    class Net(nn.Layer):
        def __init__(self):
            super(Net, self).__init__()
            models = []
            with supernet(kernel_size=(3, 5, 7), expand_ratio=(1, 2, 4)) as ofa_super:
                models += [nn.Conv2D(3, 4, 3, padding=1)]
                models += [nn.InstanceNorm2D(4)]
                models = ofa_super.convert(models)
            models += [nn.Conv2D(4, 4, 3, groups=4)]
            self.models = paddle.nn.Sequential(*models)

        def forward(self, inputs):
            return self.models(inputs)
Option 3
------------------
Build the network directly from the dynamic OPs, in the same way as an ordinary model; see `dynamic OPs <>`_ for the dynamic OPs supported by PaddleSlim. The advantage of this approach is that the network can be built more freely; the drawback is that it is more complex to use.
.. note::
  - The dynamic OPs in the paddleslim.nas.ofa.layers module are implemented for Paddle 2.0beta and later; the dynamic OPs in paddleslim.nas.ofa.layers_old are implemented for versions before Paddle 2.0beta.
  - The Block interface adds the search space of the wrapped dynamic OP to the search space used during OFA training. Because the output part of the parameters of Conv2D, Embedding and Linear can be changed freely, the dynamic OPs corresponding to these three OPs must be wrapped in a Block; the Norm-related dynamic OPs do not need a Block wrapper because their parameter sizes depend on the input size.
..
**Example:**
.. code-block:: python
    import paddle
    import paddle.nn as nn
    from paddleslim.nas.ofa.layers import Block, SuperConv2D, SuperBatchNorm2D

    class Net(nn.Layer):
        def __init__(self):
            super(Net, self).__init__()
            models = [Block(SuperConv2D(3, 4, 3, candidate_config={'kernel_size': (3, 5, 7), 'channel': (4, 8, 16)}))]
            models += [SuperBatchNorm2D(16)]
            self.models = paddle.nn.Sequential(*models)

        def forward(self, inputs):
            return self.models(inputs)
Once-For-All
============
Before Once-For-All training, an ordinary model must first be converted into a supernet built from dynamic OPs; see `Convert SuperNet <>`_ for how to do the conversion.
Once-For-All Training Configuration
-----------------------------------
RunConfig
>>>>>>>>>
The configuration and hyperparameters needed to actually run the supernet, passed as a dict. This argument is required if you want to train the supernet with the default ``Progressive Shrinking`` strategy from the paper; otherwise the current stage of supernet training can be set manually with ``paddleslim.nas.ofa.OFA().set_epoch(epoch)`` and ``paddleslim.nas.ofa.OFA().set_task(task, phase=None)``. Default: None.
**Parameters:**
- **train_batch_size(int, optional):** batch size during training, used to compute the number of iterations per epoch. Default: None.
- **n_epochs(list, optional):** the number of epochs each stage runs, used to determine which stage of supernet training the current epoch belongs to. Default: None.
- **total_images(int, optional):** number of images in the training set, used to compute the number of iterations per epoch. Default: None.
- **elastic_depth(list/tuple, optional):** if None, depth is not part of the search; otherwise the sampled configs contain a depth. Changing the model depth must be handled accordingly in the model's forward; see the `example <>`_. Default: None.
- **dynamic_batch_size(list, optional):** how many sub-networks each mini-batch of data is used to train in each stage; its shape must match that of n_epochs. Default: None.
**Returns:**
The training configuration.
**Example:**
.. code-block:: python
    from paddleslim.nas.ofa import RunConfig
    default_run_config = {
        'train_batch_size': 1,
        'n_epochs': [[1], [2, 3], [4, 5]],
        'total_images': 12,
        'elastic_depth': (5, 15, 24),
        'dynamic_batch_size': [1, 1, 1],
    }
    run_config = RunConfig(**default_run_config)
DistillConfig
>>>>>>>>>>>>>
The configuration and hyperparameters of distillation, passed as a dict, used when distillation is added during training. Default: None.
**Parameters:**
- **lambda_distill(float, optional):** scaling factor of the distillation loss. Default: None.
- **teacher_model(instance of paddle.nn.Layer, optional):** instance of the teacher network. Default: None.
- **mapping_layers(list[str], optional):** names of the intermediate layers to distill; required if intermediate-layer distillation is wanted. Default: None.
- **teacher_model_path(str, optional):** path of the teacher's pretrained model. Default: None.
- **distill_fn(instance of paddle.nn.Layer, optional):** a custom distillation loss instance; if None, mse_loss is used as the distillation loss. Default: None.
- **mapping_op(str, optional):** if the shapes of the teacher's and student's intermediate layers differ when intermediate-layer distillation is added, an extra op is attached to the student's intermediate layer so that the shapes match when the distillation loss is computed. Choices are ``["conv", "linear", None]``: 'conv' adds a Conv2D, 'linear' adds a Linear, and None adds nothing. If extra ops are added this way, their parameters can be obtained through ``paddleslim.nas.ofa.OFA().netAs_param`` and should be appended to the optimizer's parameter list (a short sketch follows the example below). Default: None.
**Returns:**
The distillation configuration.
**Example:**
.. code-block:: python
    from paddleslim.nas.ofa import DistillConfig
    default_distill_config = {
        'lambda_distill': 0.01,
        'teacher_model': teacher_model,
        'mapping_layers': ['models.0.fn'],
        'teacher_model_path': None,
        'distill_fn': None,
        'mapping_op': 'conv'
    }
    distill_config = DistillConfig(**default_distill_config)
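If ``mapping_op`` adds extra layers for shape matching, their parameters should also be optimized. A rough sketch (the optimizer choice here is only illustrative):

.. code-block:: python

    params = ofa_model.model.parameters() + ofa_model.netAs_param
    optimizer = paddle.optimizer.AdamW(learning_rate=2e-5, parameters=params)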
OFA
------------------
Turn supernet training into Once-For-All training. The `Once-For-All paper <>`_ proposes the ``Progressive Shrinking`` training strategy: training proceeds in stages in the order ``elastic kernel_size``, ``elastic width``, ``elastic depth``, and the search space is enlarged gradually during training. For example, with the search space ``kernel_size=(3,5,7), expand_ratio=(0.5, 1.0, 2.0), depth=(0.5, 0.75, 1.0)``, the kernel size is trained dynamically first, in two phases: in the first phase the kernel_size search space is ``[5, 7]``, and in the second phase it is ``[3, 5, 7]``. Then dynamic training of expand_ratio is added to supernet training in the same way, again in two phases: the expand_ratio search space is ``[1.0, 2.0]`` in the first phase and ``[0.5, 1.0, 2.0]`` in the second. Finally depth is trained dynamically, with the same phasing as kernel_size.
.. py:class:: paddleslim.nas.ofa.OFA(model, run_config=None, distill_config=None, elastic_order=None, train_full=False)
`源代码 <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/nas/ofa/ofa.py#L91>`_
**Parameters:**
- **model(paddle.nn.Layer):** the supernet instance; its training will be converted to the scheme recommended in the Once-For-All paper by default.
- **run_config(paddleslim.ofa.RunConfig, optional):** configuration of the training run. Default: None.
- **distill_config(paddleslim.ofa.DistillConfig, optional):** distillation-related configuration if distillation is used; see `DistillConfig <>`_ for the available options. If None, no distillation is added. Default: None.
- **elastic_order(list, optional):** the training order. If None, the default ``Progressive Shrinking`` order is used. Default: None.
- **train_full(bool, optional):** whether to train only the largest sub-network of the supernet. Default: False.
**Returns:**
An OFA instance.
**Example:**
.. code-block:: python
    from paddleslim.nas.ofa import OFA
    ofa_model = OFA(model)
..
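To make the staged schedule described above concrete, here is a rough sketch of spelling out the training order explicitly (the epoch counts and batch sizes are invented for illustration):

.. code-block:: python

    from paddleslim.nas.ofa import OFA, RunConfig

    # hypothetical three-task schedule: kernel_size, then width, then depth
    run_config = RunConfig(**{
        'n_epochs': [[1, 2], [3, 4], [5, 6]],
        'dynamic_batch_size': [[1, 1], [1, 1], [1, 1]],
    })
    ofa_model = OFA(model, run_config=run_config,
                    elastic_order=['kernel_size', 'width', 'depth'])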
.. py:method:: set_epoch(epoch)
Manually set the epoch that OFA training is currently in.
**Parameters:**
- **epoch(int):** the current epoch of model training.
**Returns:**
None
**Example:**
.. code-block:: python
ofa_model.set_epoch(3)
.. py:method:: set_task(task, phase=None)
Manually set the stage of OFA supernet training.
**Parameters:**
- **task(list(str)|str):** the name of the task currently being trained; one of ``"kernel_size", "width", "depth"``.
- **phase(int, optional):** the phase of the current task. A phase refers to the progressively enlarged search space of a task in ``Progressive Shrinking``; different phases correspond to search spaces of different sizes. If None, the whole search space is used for the current task. Default: None. (A short sketch follows the example below.)
**Returns:**
None
**Example:**
.. code-block:: python
ofa_model.set_task('width')
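A task can also be restricted to one of its phases (a sketch; phase indices are assumed to start at 0):

.. code-block:: python

    ofa_model.set_task('kernel_size', phase=0)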
.. py:method:: set_net_config(config)
Manually specify the sub-network configuration to train; used to train one or several specific sub-networks of the supernet.
**Parameters:**
- **config(dict):** the per-layer training configuration of the sub-network.
**Returns:**
None
**Example:**
.. code-block:: python
config = ofa_model.current_config
ofa_model.set_net_config(config)
.. py:method:: calc_distill_loss()
If OFA training includes intermediate-layer distillation, call this method to obtain the intermediate distillation loss.
**Returns:**
The intermediate-layer distillation loss.
**Example:**
.. code-block:: python
distill_loss = ofa_model.calc_distill_loss()
.. py:method:: search()
### TODO
.. py:method:: export(config)
Export the parameters of the current sub-network according to the given sub-network configuration.
**Parameters:**
- **config(dict):** the per-layer configuration of the sub-network.
**Returns:**
TODO
**Example:**
TODO
This diff is collapsed.
# TinyERNIE Model Compression Tutorial
1. This tutorial explains how the TinyERNIE model is compressed, using the TinyERNIE model from the ERNIE repo as an example to show how to quickly port the whole compression pipeline to other NLP models.
2. The tutorial uses the training strategy from [DynaBERT-Dynamic BERT with Adaptive Width and Depth](https://arxiv.org/abs/2004.04037). The original model, which consists of several Transformer Blocks of the same size, serves as the largest sub-model of the supernet. Before each training step the sub-model to train in the current round is selected; each sub-model consists of the same number of Sub Transformer Blocks, where each Sub Transformer Block is obtained by choosing a smaller width for a Transformer Block. A Transformer Block contains one Multi-Head Attention and one Feed-Forward Network, and a Sub Transformer Block is obtained as follows (a small sketch of the width arithmetic follows this list):<br/>
&emsp;&emsp;a. A Multi-Head Attention layer contains multiple heads. When a narrower sub-model is selected, the number of heads is reduced by the same proportion. For example, if the original model has 12 heads and the sub-model selected for this round has 75% of the original width, every Transformer Block in this round uses 9 heads.<br/>
&emsp;&emsp;b. The Linear parameters of the Feed-Forward Network are shrunk by the same proportion. For example, if the FFN feature dimension of the original model is 3072 and the sub-model selected for this round has 75% of the original width, the FFN feature dimension of every Transformer Block in this round is 2304.
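As a minimal sketch of the width arithmetic above (the helper below is purely illustrative, not part of the repo):
```python
def sub_block_shape(width_mult, num_heads=12, ffn_dim=3072):
    # shrink the attention heads and the FFN feature dimension proportionally
    return int(num_heads * width_mult), int(ffn_dim * width_mult)

print(sub_block_shape(0.75))  # (9, 2304)
```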
## How It Works
1. First, the parameters and heads of the pretrained model are reordered by importance so that the important parameters and heads come first, which guarantees that parameter pruning during training does not remove them. Parameter importance is computed by running the dev data once to obtain the gradient of each parameter and then combining the gradient with the overall magnitude of the parameter; head importance is computed by feeding an all-ones mask over the heads, computing the gradient of that mask, and judging the importance of each head in every Multi-Head Attention layer by the magnitude of its mask gradient.
2. The original pretrained model is used as the teacher network during distillation. A supernet is also defined whose largest sub-network has the same structure as the teacher network; the smaller sub-networks are obtained by selecting different widths of the largest network, where selecting a width means pruning parameters of the network, and all sub-networks share parameters throughout training.
3. The supernet is initialized with the reordered pretrained parameters and used as the student network. Distillation losses are added on the embedding layer, on every transformer block, and on the final logits.
4. Before each batch is trained, the sub-network configuration to train is selected (currently the configuration only selects the width of the whole model); only the parameters used by that sub-network are updated.
5. The whole supernet is optimized in this way, and after training a sub-model that meets the speed and accuracy requirements is selected.
<p align="center">
<img src="../../images/algo/ofa_bert.jpg" width="950"/><br />
Overall flow chart
</p>
## Compression Based on the ERNIE Repo
This tutorial requires PaddleSlim 2.0 or later, Paddle 1.8.5, and ERNIE 0.0.4dev or later. Please make sure Paddle, PaddleSlim, and ERNIE are installed correctly.
For the complete code example based on TinyERNIE from the ERNIE repo, see [TinyERNIE](../../../demo/ofa/ernie/README.md).
### 1. Define the initial network
Define the original TinyERNIE model and a dict that keeps the original model parameters. After an ordinary model is converted into a supernet, the parameters loaded into the original model become invalid because the OPs that build the network change, so a dict holding the original parameters is needed to initialize the supernet. Set 'return_additional_info' to True so that intermediate-layer results are returned, which makes it easy to add distillation.
```python
model = ErnieModelForSequenceClassification.from_pretrained(args.from_pretrained, num_labels=3, name='')
setattr(model, 'return_additional_info', True)
origin_weights = {}
for name, param in model.named_parameters():
origin_weights[name] = param
```
### 2. Build the supernet
Define the search space and convert the ordinary network into a supernet according to it.
```python
# define the search space
sp_config = supernet(expand_ratio=[0.25, 0.5, 0.75, 1.0])
# convert the model into a supernet
model = Convert(sp_config).convert(model)
paddleslim.nas.ofa.utils.set_state_dict(model, origin_weights)
```
### 3. Define the teacher network
Construct the teacher network directly through the ERNIE repo interface. Set 'return_additional_info' to True so that intermediate-layer results are returned, which makes it easy to add distillation.
```python
teacher_model = ErnieModelForSequenceClassification.from_pretrained(args.from_pretrained, num_labels=3, name='teacher')
setattr(teacher_model, 'return_additional_info', True)
```
### 4. Configure the distillation parameters
The teacher model instance is the parameter that needs to be configured. Since the TinyERNIE model returns the hidden-layer and Embedding-layer outputs, those return values are used directly for distillation.
```python
default_distill_config = {
'teacher_model': teacher_model
}
distill_config = DistillConfig(**default_distill_config)
```
### 5. Define the Once-For-All model
Pass the ordinary model and the distillation configuration to the OFA interface; it automatically adds the distillation process and converts supernet training into OFA training.
```python
ofa_model = paddleslim.nas.ofa.OFA(model, distill_config=distill_config)
```
### 6. Compute neuron and head importance and reorder the parameters accordingly
The importance computation, implemented for Paddle 1.8.5, lives in [importance.py](../../../demo/ofa/ernie/ernie_supernet/importance.py).
```python
head_importance, neuron_importance = compute_neuron_head_importance(
args,
ofa_model.model,
dev_ds,
place,
model_cfg)
reorder_neuron_head(ofa_model.model, head_importance, neuron_importance)
```
### 7. Set the current stage of OFA training
```python
ofa_model.set_epoch(epoch)
ofa_model.set_task('width')
```
### 8. Pass in the network configuration and start training
This example trains the supernet in the DynaBERT style.
```python
width_mult_list = [1.0, 0.75, 0.5, 0.25]
lambda_logit = 0.1
# Before paddle 2.0rc1, dygraph gradients are not accumulated automatically; keep a dict of per-parameter gradients and accumulate them manually
accumulate_gradients = dict()
for param in opt._parameter_list:
accumulate_gradients[param.name] = 0.0
for width_mult in width_mult_list:
net_config = paddleslim.nas.ofa.utils.dynabert_config(ofa_model, width_mult)
ofa_model.set_net_config(net_config)
student_output, teacher_output = ofa_model(ids, sids, labels=label,
num_layers=model_cfg['num_hidden_layers'])
loss, student_logit, student_reps = student_output[
0], student_output[1], student_output[2]['hiddens']
teacher_logit, teacher_reps = teacher_output[
1], teacher_output[2]['hiddens']
    logit_loss = soft_cross_entropy(student_logit, teacher_logit.detach())
rep_loss = 0.0
for stu_rep, tea_rep in zip(student_reps, teacher_reps):
tmp_loss = L.mse_loss(stu_rep, tea_rep.detach())
rep_loss += tmp_loss
loss = rep_loss + lambda_logit * logit_loss
loss.backward()
param_grads = opt.backward(loss)
    # accumulate gradients
for param in opt._parameter_list:
accumulate_gradients[param.name] += param.gradient()
# update the model with the accumulated gradients
for k, v in param_grads:
assert k.name in accumulate_gradients.keys(
), "{} not in accumulate_gradients".format(k.name)
v.set_value(accumulate_gradients[k.name])
opt.apply_optimize(
loss, startup_program=None, params_grads=param_grads)
ofa_model.model.clear_gradients()
```
---
**NOTE**
Because a mask is used to collect gradients when computing head importance, the forward functions of some related TinyERNIE classes have to be re-implemented through monkey patching. The concrete forward implementations can be found in [model_ernie_supernet.py](../../../demo/ofa/ernie/ernie_supernet/modeling_ernie_supernet.py).
---
# BERT Model Compression Tutorial
1. This tutorial explains how the BERT model is compressed, using the BERT-base model from the PaddleNLP repo as an example to show how to quickly port the whole compression pipeline to other NLP models.
2. The tutorial uses the training strategy from [DynaBERT-Dynamic BERT with Adaptive Width and Depth](https://arxiv.org/abs/2004.04037). The original model, which consists of several Transformer Blocks of the same size, serves as the largest sub-model of the supernet. Before each training step the sub-model to train in the current round is selected; each sub-model consists of the same number of Sub Transformer Blocks, where each Sub Transformer Block is obtained by choosing a smaller width for a Transformer Block. A Transformer Block contains one Multi-Head Attention and one Feed-Forward Network, and a Sub Transformer Block is obtained as follows:<br/>
&emsp;&emsp;a. A Multi-Head Attention layer contains multiple heads. When a narrower sub-model is selected, the number of heads is reduced by the same proportion. For example, if the original model has 12 heads and the sub-model selected for this round has 75% of the original width, every Transformer Block in this round uses 9 heads.<br/>
&emsp;&emsp;b. The Linear parameters of the Feed-Forward Network are shrunk by the same proportion. For example, if the FFN feature dimension of the original model is 3072 and the sub-model selected for this round has 75% of the original width, the FFN feature dimension of every Transformer Block in this round is 2304.
## How It Works
1. First, the parameters and heads of the pretrained model are reordered by importance so that the important parameters and heads come first, which guarantees that parameter pruning during training does not remove them. Parameter importance is computed by running the dev data once to obtain the gradient of each parameter and then combining the gradient with the overall magnitude of the parameter; head importance is computed by feeding an all-ones mask over the heads, computing the gradient of that mask, and judging the importance of each head in every Multi-Head Attention layer by the magnitude of its mask gradient.
2. The original pretrained model is used as the teacher network during distillation. A supernet is also defined whose largest sub-network has the same structure as the teacher network; the smaller sub-networks are obtained by selecting different widths of the largest network, where selecting a width means pruning parameters of the network, and all sub-networks share parameters throughout training.
3. The supernet is initialized with the reordered pretrained parameters and used as the student network. Distillation losses are added on the embedding layer, on every transformer block, and on the final logits.
4. Before each batch is trained, the sub-network configuration to train is selected (currently the configuration only selects the width of the whole model); only the parameters used by that sub-network are updated.
5. The whole supernet is optimized in this way, and after training a sub-model that meets the speed and accuracy requirements is selected.
<p align="center">
<img src="../../images/algo/ofa_bert.jpg" width="950"/><br />
Overall flow chart
</p>
## Compression Based on the PaddleNLP Repo
This tutorial requires PaddleSlim 2.0 or later, Paddle 2.0rc1 or later, and PaddleNLP 2.0beta or later. Please make sure Paddle, PaddleSlim, and PaddleNLP are installed correctly.
For the complete code example based on BERT-base from the PaddleNLP repo, see [BERT-base](../../../demo/ofa/bert/README.md).
### 1. Define the initial network
Define the original BERT-base model and a dict that keeps the original model parameters. After an ordinary model is converted into a supernet, the parameters loaded into the original model become invalid because the OPs that build the network change, so a dict holding the original parameters is needed to initialize the supernet.
```python
model = BertForSequenceClassification.from_pretrained('bert', num_classes=2)
origin_weights = {}
for name, param in model.named_parameters():
origin_weights[name] = param
```
### 2. Build the supernet
Define the search space and convert the ordinary network into a supernet according to it.
```python
# define the search space
sp_config = supernet(expand_ratio=[0.25, 0.5, 0.75, 1.0])
# convert the model into a supernet
model = Convert(sp_config).convert(model)
paddleslim.nas.ofa.utils.set_state_dict(model, origin_weights)
```
### 3. Define the teacher network
Construct the teacher network directly through the paddlenlp interface.
```python
teacher_model = BertForSequenceClassification.from_pretrained('bert', num_classes=2)
```
### 4. Configure the distillation parameters
The parameters to configure are the teacher model instance and the layers to distill: distillation losses are added between the Embedding layers and between every Transformer Block of the teacher and student networks, and the intermediate-layer distillation loss uses the default MSE loss. The 'lambda_distill' parameter scales the overall distillation loss.
```python
mapping_layers = ['bert.embeddings']
for idx in range(model.bert.config['num_hidden_layers']):
mapping_layers.append('bert.encoder.layers.{}'.format(idx))
default_distill_config = {
'lambda_distill': 0.1,
'teacher_model': teacher_model,
'mapping_layers': mapping_layers,
}
distill_config = DistillConfig(**default_distill_config)
```
### 5. Define the Once-For-All model
Pass the ordinary model and the distillation configuration to the OFA interface; it automatically adds the distillation process and converts supernet training into OFA training.
```python
ofa_model = paddleslim.nas.ofa.OFA(model, distill_config=distill_config)
```
### 6. Compute neuron and head importance and reorder the parameters accordingly
```python
head_importance, neuron_importance = utils.compute_neuron_head_importance(
'sst-2',
ofa_model.model,
dev_data_loader,
num_layers=model.bert.config['num_hidden_layers'],
num_heads=model.bert.config['num_attention_heads'])
reorder_neuron_head(ofa_model.model, head_importance, neuron_importance)
```
### 7. Set the current stage of OFA training
```python
ofa_model.set_epoch(epoch)
ofa_model.set_task('width')
```
### 8. Pass in the network configuration and start training
This example trains the supernet in the DynaBERT style.
```python
width_mult_list = [1.0, 0.75, 0.5, 0.25]
lambda_logit = 0.1
for width_mult in width_mult_list:
net_config = paddleslim.nas.ofa.utils.dynabert_config(ofa_model, width_mult)
ofa_model.set_net_config(net_config)
logits, teacher_logits = ofa_model(input_ids, segment_ids, attention_mask=[None, None])
rep_loss = ofa_model.calc_distill_loss()
logit_loss = soft_cross_entropy(logits, teacher_logits.detach())
loss = rep_loss + lambda_logit * logit_loss
loss.backward()
optimizer.step()
lr_scheduler.step()
ofa_model.model.clear_gradients()
```
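The snippet above assumes a `soft_cross_entropy` helper, i.e. the soft-label cross entropy between student and teacher logits. A minimal Paddle 2.0 sketch, equivalent to the helper defined in the TinyERNIE demo:
```python
import paddle
import paddle.nn.functional as F

def soft_cross_entropy(inp, target):
    # soft-label cross entropy between student logits and (detached) teacher logits
    inp_likelihood = F.log_softmax(inp, axis=-1)
    target_prob = F.softmax(target, axis=-1)
    return -1. * paddle.mean(paddle.sum(inp_likelihood * target_prob, axis=-1))
```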
---
**NOTE**
Because a mask is used to collect gradients when computing head importance, BERT's forward function has to be re-implemented through monkey patching. For example:
```python
from paddlenlp.transformers import BertModel
def bert_forward(self,
input_ids,
token_type_ids=None,
position_ids=None,
attention_mask=[None, None]):
wtype = self.pooler.dense.fn.weight.dtype if hasattr(
self.pooler.dense, 'fn') else self.pooler.dense.weight.dtype
if attention_mask[0] is None:
attention_mask[0] = paddle.unsqueeze(
(input_ids == self.pad_token_id).astype(wtype) * -1e9, axis=[1, 2])
embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids)
encoder_outputs = self.encoder(embedding_output, attention_mask)
sequence_output = encoder_outputs
pooled_output = self.pooler(sequence_output)
return sequence_output, pooled_output
BertModel.forward = bert_forward
```
---
......@@ -18,6 +18,6 @@ from .convert_super import supernet
from .utils.utils import get_paddle_version
pd_ver = get_paddle_version()
if pd_ver == 185:
from .layers import *
from .layers_old import *
else:
from .layers_new import *
from .layers import *
......@@ -24,15 +24,15 @@ if pd_ver == 185:
import paddle.fluid.dygraph.nn as nn
from paddle.fluid.dygraph.nn import Conv2D, Conv2DTranspose, Linear, LayerNorm, Embedding
from paddle.fluid import ParamAttr
from .layers import *
from . import layers
from .layers_old import *
from . import layers_old as layers
Layer = paddle.fluid.dygraph.Layer
else:
import paddle.nn as nn
from paddle.nn import Conv2D, Conv2DTranspose, Linear, LayerNorm, Embedding
from paddle import ParamAttr
from .layers_new import *
from . import layers_new as layers
from .layers import *
from . import layers
Layer = paddle.nn.Layer
_logger = get_logger(__name__, level=logging.INFO)
......@@ -43,6 +43,17 @@ WEIGHT_LAYER = ['conv', 'linear', 'embedding']
class Convert:
"""
Convert network to the supernet according to the search space.
Parameters:
context(paddleslim.nas.ofa.supernet): search space defined by the user.
Examples:
.. code-block:: python
from paddleslim.nas.ofa import supernet, Convert
sp_net_config = supernet(kernel_size=(3, 5, 7), expand_ratio=[1, 2, 4])
convert = Convert(sp_net_config)
"""
def __init__(self, context):
self.context = context
......@@ -63,6 +74,17 @@ class Convert:
layer._bias_attr.name = 'super_' + layer._bias_attr.name
def convert(self, network):
"""
The function to convert the network to a supernet.
Parameters:
network(paddle.nn.Layer|list(paddle.nn.Layer)): instance of the model or list of instance of layers.
Examples:
.. code-block:: python
from paddle.vision.models import mobilenet_v1
from paddleslim.nas.ofa import supernet, Convert
sp_net_config = supernet(kernel_size=(3, 5, 7), expand_ratio=[1, 2, 4])
convert = Convert(sp_net_config).convert(mobilenet_v1())
"""
# search the first and last weight layer, don't change out channel of the last weight layer
# don't change in channel of the first weight layer
model = []
......@@ -641,6 +663,14 @@ class Convert:
class supernet:
"""
Search space of the network.
Parameters:
kernel_size(list|tuple, optional): search space for the kernel size of the Conv2D.
        expand_ratio(list|tuple, optional): the search space for the expand ratio of the number of channels of Conv2D and of the output dimensions of Embedding or Linear. The number of channels of each OP in the converted supernet is obtained by scaling the channels of the corresponding OP in the original model, so a single list of ratios is enough. Set only one of this parameter and ``channel``.
        channel(list|tuple, optional): the search space for the number of channels of Conv2D and of the output dimensions of Embedding or Linear. This parameter directly sets the number of channels of each OP in the supernet, so its length must equal the total number of Conv2D, Embedding, and Linear layers in the network. Set only one of this parameter and ``expand_ratio``.
"""
def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
......
This diff is collapsed.
......@@ -20,10 +20,10 @@ import paddle.fluid as fluid
from .utils.utils import get_paddle_version
pd_ver = get_paddle_version()
if pd_ver == 185:
from .layers import BaseBlock, SuperConv2D, SuperLinear
from .layers_old import BaseBlock, SuperConv2D, SuperLinear
Layer = paddle.fluid.dygraph.Layer
else:
from .layers_new import BaseBlock, SuperConv2D, SuperLinear
from .layers import BaseBlock, SuperConv2D, SuperLinear
Layer = paddle.nn.Layer
from .utils.utils import search_idx
from ...common import get_logger
......@@ -32,16 +32,40 @@ _logger = get_logger(__name__, level=logging.INFO)
__all__ = ['OFA', 'RunConfig', 'DistillConfig']
RunConfig = namedtuple('RunConfig', [
'train_batch_size', 'n_epochs', 'save_frequency', 'eval_frequency',
'init_learning_rate', 'total_images', 'elastic_depth', 'dynamic_batch_size'
])
RunConfig = namedtuple(
'RunConfig',
[
# int, batch_size in training, used to get current epoch, default: None
'train_batch_size',
# list, the number of epoch of every task in training, default: None
'n_epochs',
        # list, initial learning rate of every task in training, NOT used now. Default: None.
'init_learning_rate',
# int, total images of train dataset, used to get current epoch, default: None
'total_images',
        # list, elastic depth of the model in training, default: None
'elastic_depth',
# list, the number of sub-network to train per mini-batch data, used to get current epoch, default: None
'dynamic_batch_size'
])
RunConfig.__new__.__defaults__ = (None, ) * len(RunConfig._fields)
DistillConfig = namedtuple('DistillConfig', [
'lambda_distill', 'teacher_model', 'mapping_layers', 'teacher_model_path',
'distill_fn', 'mapping_op'
])
DistillConfig = namedtuple(
'DistillConfig',
[
# float, lambda scale of distillation loss, default: None.
'lambda_distill',
# instance of model, instance of teacher model, default: None.
'teacher_model',
# list(str), name of the layers which need a distillation, default: None.
'mapping_layers',
# str, the path of teacher pretrained model, default: None.
'teacher_model_path',
# instance of loss layer, the loss function used in distillation, if set to None, use mse_loss default, default: None.
'distill_fn',
# str, define which op append between teacher model and student model used in distillation, choice in ['conv', 'linear', None], default: None.
'mapping_op'
])
DistillConfig.__new__.__defaults__ = (None, ) * len(DistillConfig._fields)
......@@ -89,15 +113,31 @@ class OFABase(Layer):
class OFA(OFABase):
"""
    Convert the training process to the Once-For-All training process. A detailed description is in the paper: `Once-for-All: Train One Network and Specialize it for Efficient Deployment <https://arxiv.org/abs/1908.09791>`_ . The paper proposes a training process named progressive shrinking (PS), which means we start with training the largest neural network with the maximum kernel size (i.e., 7), depth (i.e., 4), and width (i.e., 6). Next, we progressively fine-tune the network to support smaller sub-networks by gradually adding them into the sampling space (larger sub-networks may also be sampled). Specifically, after training the largest network, we first support elastic kernel size, which can be chosen from {3, 5, 7} at each layer, while the depth and width remain at their maximum values. Then, we support elastic depth and elastic width sequentially.
Parameters:
model(paddle.nn.Layer): instance of model.
run_config(paddleslim.ofa.RunConfig, optional): config in ofa training, can reference `<>`_ . Default: None.
distill_config(paddleslim.ofa.DistillConfig, optional): config of distilltion in ofa training, can reference `<>`_. Default: None.
elastic_order(list, optional): define the training order, if it set to None, use the default order in the paper. Default: None.
train_full(bool, optional): whether to train the largest sub-network only. Default: False.
Examples:
.. code-block:: python
            from paddleslim.nas.ofa import OFA
ofa_model = OFA(model)
"""
def __init__(self,
model,
run_config=None,
net_config=None,
distill_config=None,
elastic_order=None,
train_full=False):
super(OFA, self).__init__(model)
self.net_config = net_config
self.net_config = None
self.run_config = run_config
self.distill_config = distill_config
self.elastic_order = elastic_order
......@@ -278,12 +318,29 @@ class OFA(OFABase):
self.layers, sample_type=sample_type, task=task, phase=phase)
return config
def set_task(self, task=None, phase=None):
def set_task(self, task, phase=None):
"""
        Set the task in the OFA training process.
        Parameters:
            task(list(str)|str): the task to train at the current stage of the training process.
            phase(int, optional): the search space is enlarged gradually; use this parameter to specify the phase of the current task. If set to None, the whole search space of the task is used. Default: None.
Examples:
.. code-block:: python
ofa_model.set_task('width')
"""
self.manual_set_task = True
self.task = task
self.phase = phase
def set_epoch(self, epoch):
"""
        Set the epoch in the OFA training process.
        Parameters:
            epoch(int): the current epoch of the training process.
Examples:
.. code-block:: python
ofa_model.set_epoch(3)
"""
self.epoch = epoch
def _progressive_shrinking(self):
......@@ -302,6 +359,12 @@ class OFA(OFABase):
return self._sample_config(task=self.task, phase=phase_idx)
def calc_distill_loss(self):
"""
        Calculate the distillation loss if distillation is used.
Examples:
.. code-block:: python
dis_loss = ofa_model.calc_distill_loss()
"""
losses = []
assert len(self.netAs) > 0
for i, netA in enumerate(self.netAs):
......@@ -319,6 +382,8 @@ class OFA(OFABase):
else:
Sact = Sact
Sact = Sact[0] if isinstance(Sact, tuple) else Sact
Tact = Tact[0] if isinstance(Tact, tuple) else Tact
if self.distill_config.distill_fn == None:
loss = fluid.layers.mse_loss(Sact, Tact.detach())
else:
......@@ -337,6 +402,15 @@ class OFA(OFABase):
pass
def set_net_config(self, net_config):
"""
        Set the config of the specific sub-network to be trained.
        Parameters:
            net_config(dict): the config of the sub-network.
Examples:
.. code-block:: python
config = ofa_model.current_config
ofa_model.set_net_config(config)
"""
self.net_config = net_config
def forward(self, *inputs, **kwargs):
......
......@@ -13,6 +13,7 @@
# limitations under the License.
from .utils import *
from .special_config import *
from .utils import get_paddle_version
pd_ver = get_paddle_version()
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
__all__ = ['dynabert_config']
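# Build the per-layer OFA net config for DynaBERT-style width/depth selection:
# embedding layers, the layers after the last transformer block, and the
# sub-layers matched by fix_exp below keep expand_ratio 1.0, all other layers
# get `width_mult`, and an optional 'depth' entry is set to `depth_mult`.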
def dynabert_config(model, width_mult, depth_mult=1.0):
new_config = dict()
block_num = np.floor((len(model.layers.items()) - 3) / 6)
block_name = block_num * 6 + 2
def fix_exp(idx):
if (idx - 3) % 6 == 0 or (idx - 5) % 6 == 0:
return True
return False
for idx, (block_k, block_v) in enumerate(model.layers.items()):
if isinstance(block_v, dict) and len(block_v.keys()) != 0:
name, name_idx = block_k.split('_'), int(block_k.split('_')[1])
if fix_exp(name_idx) or 'emb' in block_k or idx >= block_name:
block_v['expand_ratio'] = 1.0
else:
block_v['expand_ratio'] = width_mult
if block_k == 'depth':
block_v = depth_mult
new_config[block_k] = block_v
return new_config
......@@ -22,7 +22,7 @@ from paddle.nn import ReLU
from paddleslim.nas import ofa
from paddleslim.nas.ofa import OFA, RunConfig, DistillConfig
from paddleslim.nas.ofa.convert_super import supernet
from paddleslim.nas.ofa.layers_new import Block, SuperSeparableConv2D
from paddleslim.nas.ofa.layers import Block, SuperSeparableConv2D
class ModelConv(nn.Layer):
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import numpy as np
import unittest
import paddle
import paddle.nn as nn
from paddle.nn import ReLU
from paddleslim.nas import ofa
from paddleslim.nas.ofa import OFA, RunConfig, DistillConfig
from paddleslim.nas.ofa.convert_super import supernet
from paddleslim.nas.ofa.layers import *
class ModelCase1(nn.Layer):
def __init__(self):
super(ModelCase1, self).__init__()
models = [SuperConv2D(3, 4, 3, bias_attr=False)]
models += [SuperConv2D(4, 4, 3, groups=4)]
models += [SuperConv2D(4, 4, 3, groups=2)]
models += [SuperConv2DTranspose(4, 4, 3, bias_attr=False)]
models += [SuperConv2DTranspose(4, 4, 3, groups=4)]
models += [nn.Conv2DTranspose(4, 4, 3, groups=2)]
models += [SuperConv2DTranspose(4, 4, 3, groups=2)]
models += [
SuperSeparableConv2D(
4,
4,
1,
padding=1,
bias_attr=False,
candidate_config={'expand_ratio': (1.0, 2.0)}),
]
self.models = paddle.nn.Sequential(*models)
def forward(self, inputs):
return self.models(inputs)
class TestCase(unittest.TestCase):
def setUp(self):
self.model = ModelCase1()
data_np = np.random.random((1, 3, 64, 64)).astype(np.float32)
self.data = paddle.to_tensor(data_np)
def test_ofa(self):
ofa_model = OFA(self.model)
out = self.model(self.data)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import numpy as np
import unittest
import paddle
import paddle.nn as nn
from paddleslim.nas import ofa
from paddleslim.nas.ofa import OFA
from paddleslim.nas.ofa.layers_old import *
class ModelCase1(nn.Layer):
def __init__(self):
super(ModelCase1, self).__init__()
models = [SuperConv2D(3, 4, 3, bias_attr=False)]
models += [
SuperConv2D(
4,
4,
7,
candidate_config={
'expand_ratio': (0.5, 1.0),
'kernel_size': (3, 5, 7)
},
transform_kernel=True)
]
models += [SuperConv2D(4, 4, 3, groups=4)]
models += [SuperConv2D(4, 4, 3, groups=2)]
models += [SuperBatchNorm(4)]
models += [SuperConv2DTranspose(4, 4, 3, bias_attr=False)]
models += [
SuperConv2DTranspose(
4,
4,
7,
candidate_config={
'expand_ratio': (0.5, 1.0),
'kernel_size': (3, 5, 7)
},
transform_kernel=True)
]
models += [SuperConv2DTranspose(4, 4, 3, groups=4)]
models += [SuperInstanceNorm(4)]
models += [nn.Conv2DTranspose(4, 4, 3, groups=2)]
models += [SuperConv2DTranspose(4, 4, 3, groups=2)]
models += [
SuperSeparableConv2D(
4,
4,
1,
padding=1,
bias_attr=False,
candidate_config={'expand_ratio': (0.5, 1.0)}),
]
models += [
SuperSeparableConv2D(
4, 4, 1, padding=1, candidate_config={'channel': (2, 4)}),
]
self.models = paddle.nn.Sequential(*models)
def forward(self, inputs):
return self.models(inputs)
class ModelCase2(nn.Layer):
def __init__(self):
super(ModelCase2, self).__init__()
models = [
SuperEmbedding(
size=(64, 64), candidate_config={'expand_ratio': (0.5, 1.0)})
]
models += [
SuperLinear(
64, 64, candidate_config={'expand_ratio': (0.5, 1.0)})
]
models += [SuperLayerNorm(64)]
models += [SuperLinear(64, 64, candidate_config={'channel': (32, 64)})]
models += [
SuperLinear(
64, 64, bias_attr=False,
candidate_config={'channel': (32, 64)})
]
self.models = paddle.nn.Sequential(*models)
def forward(self, inputs):
return self.models(inputs)
class ModelCase3(nn.Layer):
def __init__(self):
super(ModelCase3, self).__init__()
self.conv1 = SuperConv2D(
3,
4,
7,
candidate_config={'kernel_size': (3, 5, 7)},
transform_kernel=True)
self.conv2 = SuperConv2DTranspose(
4,
4,
7,
candidate_config={'kernel_size': (3, 5, 7)},
transform_kernel=True)
def forward(self, inputs):
inputs = self.conv1(inputs, kernel_size=3)
inputs = self.conv2(inputs, kernel_size=3)
return inputs
class TestCase(unittest.TestCase):
def setUp(self):
self.model = ModelCase1()
data_np = np.random.random((1, 3, 64, 64)).astype(np.float32)
self.data = paddle.to_tensor(data_np)
def test_ofa(self):
ofa_model = OFA(self.model)
out = self.model(self.data)
class TestCase2(TestCase):
def setUp(self):
self.model = ModelCase2()
data_np = np.random.random((64, 64)).astype(np.int64)
self.data = paddle.to_tensor(data_np)
class TestCase3(TestCase):
def setUp(self):
self.model = ModelCase3()
data_np = np.random.random((1, 3, 64, 64)).astype(np.float32)
self.data = paddle.to_tensor(data_np)
if __name__ == '__main__':
unittest.main()
......@@ -20,40 +20,36 @@ import paddle
import paddle.nn as nn
from paddle.vision.models import mobilenet_v1
from paddleslim.nas.ofa.convert_super import Convert, supernet
from paddleslim.nas.ofa.utils import compute_neuron_head_importance, reorder_head, reorder_neuron, set_state_dict
from paddleslim.nas.ofa.utils import compute_neuron_head_importance, reorder_head, reorder_neuron, set_state_dict, dynabert_config
from paddleslim.nas.ofa import OFA
class TestModel(nn.Layer):
def __init__(self):
super(TestModel, self).__init__()
encoder_layer = nn.TransformerEncoderLayer(
312,
12,
1024,
dropout=0.1,
activation='gelu',
attn_dropout=0.1,
act_dropout=0)
self.encoder = nn.TransformerEncoder(encoder_layer, 3)
self.fc = nn.Linear(312, 3)
def forward(self, input_ids, segment_ids, attention_mask=[None, None]):
src = input_ids + segment_ids
out = self.encoder(src, attention_mask)
out = self.fc(out[:, 0])
return out
class TestComputeImportance(unittest.TestCase):
def setUp(self):
self.model = self.init_model()
self.model = TestModel()
self.data_loader = self.init_data()
def init_model(self):
class TestModel(nn.Layer):
def __init__(self):
super(TestModel, self).__init__()
encoder_layer = nn.TransformerEncoderLayer(
312,
12,
1024,
dropout=0.1,
activation='gelu',
attn_dropout=0.1,
act_dropout=0)
self.encoder = nn.TransformerEncoder(encoder_layer, 3)
self.fc = nn.Linear(312, 3)
def forward(self,
input_ids,
segment_ids,
attention_mask=[None, None]):
src = input_ids + segment_ids
out = self.encoder(src, attention_mask)
out = self.fc(out[:, 0])
return out
return TestModel()
def init_data(self):
batch_size = 16
hidden_size = 312
......@@ -67,8 +63,7 @@ class TestComputeImportance(unittest.TestCase):
paddle.to_tensor(labels)), )
return data
def reorder_reorder_neuron_head(self, model, head_importance,
neuron_importance):
def reorder_neuron_head(self, model, head_importance, neuron_importance):
# reorder heads and ffn neurons
for layer, current_importance in enumerate(neuron_importance):
# reorder heads
......@@ -89,8 +84,7 @@ class TestComputeImportance(unittest.TestCase):
num_heads=12)
assert (len(head_importance) == 3)
assert (len(neuron_importance) == 3)
self.reorder_reorder_neuron_head(self.model, head_importance,
neuron_importance)
self.reorder_neuron_head(self.model, head_importance, neuron_importance)
class TestComputeImportanceCase1(TestComputeImportance):
......@@ -125,5 +119,14 @@ class TestSetStateDict(unittest.TestCase):
set_state_dict(sp_model, self.origin_weights)
class TestSpecialConfig(unittest.TestCase):
def test_dynabert(self):
self.model = TestModel()
sp_net_config = supernet(expand_ratio=[0.5, 1.0])
self.model = Convert(sp_net_config).convert(self.model)
ofa_model = OFA(self.model)
config = dynabert_config(ofa_model, 0.5)
if __name__ == '__main__':
unittest.main()