Unverified · Commit 3fde095b authored by ceci3, committed by GitHub

[distill] support wrap functional with class (#897)

Parent 2ed20815
......@@ -10,19 +10,19 @@
The overall distillation process in TinyBERT: general distillation is performed first, and then task-specific distillation is performed on data-augmented data. This document mainly covers the second-stage distillation; the model is initialized with the general small model `tinybert-6l-768d-v2` obtained from the first stage.
<p align="center">
<img src="./tinybert.png" width="950"/><br />
TinyBERT distillation pipeline
</p>
In model distillation, the larger model (BERT base in this example) is usually called the teacher model, and the smaller model (a 6-layer BERT in this example, referred to as TinyBERT6 below) is usually called the student model.
Knowledge distillation is usually realized by letting the student model learn the corresponding distillation losses. In this experiment, the distillation objective consists of two parts: the intermediate-layer distillation loss and the prediction-layer distillation loss. Intermediate-layer distillation covers the embedding layer, the output of every Transformer layer, and the attention matrices (the values before softmax) in every Transformer layer; all three use the mean squared error loss. The prediction-layer distillation objective is the cross-entropy loss between the logits of the student model and the logits of the teacher model.
Since the teacher model has 12 layers and the student model has fewer layers than the teacher, a layer-mapping scheme is needed. The paper adopts a fixed mapping: when the student has half as many layers as the teacher, the attention matrix of the student's layer i learns from the attention matrix of the teacher's layer 2i+1, and the same mapping is used for the Transformer layer outputs (see the sketch at the end of this section).
The experiment consists of two major training stages: first fine-tune BERT-base to obtain the teacher model, then run the distillation training. The distillation itself is also split into two steps: first distill the intermediate layers for several epochs (10, 20, or 30 in the paper, depending on the task), then distill the prediction layer for 3 epochs.
Note that, when using different teacher models, in the two v2 pretrained models `tinybert-6l-768d-v2` and `tinybert-4l-312d-v2` the transform matrices that map the student's embedding output and intermediate Transformer-layer outputs to the corresponding teacher outputs are independent per layer, whereas `tinybert-6l-768d`, `tinybert-4l-312d`, `tinybert-6l-768d-zh` and `tinybert-4l-312-zh` share a single transform matrix across layers.
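A minimal sketch of the fixed layer mapping and the two loss terms described above, assuming a 12-layer teacher and a 6-layer student (the helper names and details are illustrative, not the repository's exact implementation):

```python
import paddle.nn.functional as F

def tinybert_layer_map(num_student_layers=6):
    # Student layer i (0-based) learns from teacher layer 2*i + 1,
    # i.e. teacher layers 1, 3, 5, 7, 9, 11 for a 6-layer student.
    return [(i, 2 * i + 1) for i in range(num_student_layers)]

def intermediate_loss(student_feats, teacher_feats, mapping):
    # MSE over embeddings, Transformer-layer outputs and pre-softmax
    # attention matrices, summed over the mapped layer pairs.
    return sum(F.mse_loss(student_feats[s], teacher_feats[t])
               for s, t in mapping)

def prediction_loss(student_logits, teacher_logits, temperature=1.0):
    # Cross entropy between student logits and teacher logits,
    # using the teacher distribution as a soft label.
    return F.cross_entropy(
        student_logits / temperature,
        F.softmax(teacher_logits / temperature),
        soft_label=True)
```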
### Install PaddleNLP and Paddle
This tutorial compresses the BERT model from PaddleNLP and therefore depends on PaddleNLP and Paddle.
......
- DistillConfig:
loss_function: MSELoss
- loss_function: MSELoss
model_name_pairs:
- - student_0
- teacher_0
weight: 1.0
- layers:
align_params:
align_type: linear
in_channel: 768
out_channel: 768
layers:
- layers_name: ['tinybert.embeddings', 'bert.embeddings']
- layers_name: ['tinybert.encoder.layers.0', 'bert.encoder.layers.1']
- layers_name: ['tinybert.encoder.layers.1', 'bert.encoder.layers.3']
......@@ -12,9 +16,22 @@
- layers_name: ['tinybert.encoder.layers.3', 'bert.encoder.layers.7']
- layers_name: ['tinybert.encoder.layers.4', 'bert.encoder.layers.9']
- layers_name: ['tinybert.encoder.layers.5', 'bert.encoder.layers.11']
- layers_name: ['tinybert.encoder.layers.0.self_attn', 'bert.encoder.layers.1.self_attn']
- layers_name: ['tinybert.encoder.layers.1.self_attn', 'bert.encoder.layers.3.self_attn']
- layers_name: ['tinybert.encoder.layers.2.self_attn', 'bert.encoder.layers.5.self_attn']
- layers_name: ['tinybert.encoder.layers.3.self_attn', 'bert.encoder.layers.7.self_attn']
- layers_name: ['tinybert.encoder.layers.4.self_attn', 'bert.encoder.layers.9.self_attn']
- layers_name: ['tinybert.encoder.layers.5.self_attn', 'bert.encoder.layers.11.self_attn']
- loss_function: MSELoss
model_name_pairs:
- - student_0
- teacher_0
weight: 1.0
layers:
- layers_name: ['tinybert.encoder.layers.0.self_attn.wrap_fn_softmax_0', 'bert.encoder.layers.1.self_attn.wrap_fn_softmax_2']
io: ['input', 'input']
- layers_name: ['tinybert.encoder.layers.1.self_attn.wrap_fn_softmax_2', 'bert.encoder.layers.3.self_attn.wrap_fn_softmax_6']
io: ['input', 'input']
- layers_name: ['tinybert.encoder.layers.2.self_attn.wrap_fn_softmax_4', 'bert.encoder.layers.5.self_attn.wrap_fn_softmax_10']
io: ['input', 'input']
- layers_name: ['tinybert.encoder.layers.3.self_attn.wrap_fn_softmax_6', 'bert.encoder.layers.7.self_attn.wrap_fn_softmax_14']
io: ['input', 'input']
- layers_name: ['tinybert.encoder.layers.4.self_attn.wrap_fn_softmax_8', 'bert.encoder.layers.9.self_attn.wrap_fn_softmax_18']
io: ['input', 'input']
- layers_name: ['tinybert.encoder.layers.5.self_attn.wrap_fn_softmax_10', 'bert.encoder.layers.11.self_attn.wrap_fn_softmax_22']
io: ['input', 'input']
- DistillConfig:
loss_function: CELoss
- loss_function: CELoss
model_name_pairs:
- - student_0
- teacher_0
weight: 1.0
- layers:
layers:
- layers_name: ['classifier', 'classifier']
temperature: 1.0
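The yaml files above can also be generated from a Python config through the `config2yaml` helper exported later in this commit (as done in the unit test at the bottom). A sketch of the prediction-layer config expressed in Python, with the import path and output file name assumed for illustration:

```python
# Sketch: the prediction-layer (CELoss) stage expressed as a Python config and
# dumped to yaml. The import path follows the distill_helpers module below.
from paddleslim.dygraph.dist.distill_helpers import config2yaml

distill_configs = [{
    'loss_function': 'CELoss',
    'temperature': 1.0,
    'model_name_pairs': [['student_0', 'teacher_0']],
    'weight': 1.0,
    'layers': [{'layers_name': ['classifier', 'classifier']}],
}]
config2yaml(distill_configs, './distill_stage2.yaml')  # file name is illustrative
```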
......@@ -17,4 +17,6 @@ python3.7 task_distill.py \
--logging_steps 10 \
--save_steps 10 \
--output_dir ./tmp/$TASK_NAME/ \
--distill_config ./distill_stage1.yaml \
--device gpu
......@@ -157,11 +157,6 @@ def parse_args():
"--use_aug",
action="store_true",
help="Whether to use augmentation data to train.", )
parser.add_argument(
"--intermediate_distill",
action="store_true",
help="Whether distilling intermediate layers. If False, it means prediction layer distillation.",
)
parser.add_argument(
"--weight_decay",
default=0.0,
......@@ -349,7 +344,6 @@ def do_train(args):
teacher_model_class, _ = MODEL_CLASSES[args.teacher_model_type]
teacher = teacher_model_class.from_pretrained(
args.teacher_path, num_classes=num_classes)
teacher.eval()
if paddle.distributed.get_world_size() > 1:
student = paddle.DataParallel(student, find_unused_parameters=True)
......@@ -368,10 +362,20 @@ def do_train(args):
lr_scheduler = T.LinearDecayWithWarmup(args.learning_rate,
num_training_steps, warmup)
### step1: load distill config
assert os.path.exists(
args.distill_config), "distill file {} not exist.".format(
args.distill_config)
### step2: wrap the student model and teacher model with paddleslim.dygraph.dist.Distill;
### the distill config needs to be passed to it.
distill_model = Distill(
args.distill_config, students=[student], teachers=[teacher])
### step3: add the parameters created by the align ops to the optimizer
# Generate parameter names needed to perform weight decay.
# All bias and LayerNorm parameters are excluded.
decay_params = [
p.name for n, p in student.named_parameters()
p.name for n, p in distill_model.named_parameters()
if not any(nd in n for nd in ["bias", "norm"])
]
optimizer = paddle.optimizer.AdamW(
......@@ -379,7 +383,7 @@ def do_train(args):
beta1=0.9,
beta2=0.999,
epsilon=args.adam_epsilon,
parameters=student.parameters(),
parameters=distill_model.parameters(),
weight_decay=args.weight_decay,
apply_decay_param_fun=lambda x: x in decay_params)
......@@ -390,17 +394,11 @@ def do_train(args):
tic_train = time.time()
best_res = 0.0
assert os.path.exists(
args.distill_config), "distill file {} not exist.".format(
args.distill_config)
distill_model = Distill(
args.distill_config, student_models=[student],
teacher_models=[teacher])
for epoch in range(num_train_epochs):
for step, batch in enumerate(train_data_loader):
global_step += 1
input_ids, segment_ids, labels = batch
### step4: call distill_model instead of calling the student model and teacher model independently.
loss, _, _ = distill_model(input_ids, segment_ids)
loss.backward()
......
......@@ -22,9 +22,12 @@ from .server import Server
from .client import Client
from .meter import AvgrageMeter
from .analyze_helper import VarCollector
from paddleslim.common import wrapper_function
__all__ = [
'EvolutionaryController', 'SAController', 'get_logger', 'ControllerServer',
'ControllerClient', 'lock', 'unlock', 'cached_reader', 'AvgrageMeter',
'Server', 'Client', 'RLBaseController', 'VarCollector'
]
__all__ += wrapper_function.__all__
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
__all__ = ['Counter', 'init_index', 'functional2layer']
class Counter:
"""
limit the number of function calls.
"""
def __init__(self, times=1):
self.calls = 0
self.times = times
def __call__(self, func):
def counter_wrapper(*args, **kwargs):
func(*args, **kwargs)
self.calls += 1
assert self.calls <= self.times, "function {} is only allowed to be called {} times.".format(
func.__name__, self.times)
return counter_wrapper
class FuncWrapper(nn.Layer):
"""
"""
def __init__(self, functional):
super(FuncWrapper, self).__init__()
self.fn = functional
def forward(self, *x, **kwargs):
return self.fn(*x, **kwargs)
def convert_fn(fn):
def new_fn(*x, **kwargs):
global global_idx
model = inspect.currentframe().f_back.f_locals['self']
### TODO(ceci3): a sublayer count of 0 means no nn.Layer is called in the __init__ function;
### this condition may not be rigorous and needs to be revisited later.
### The model.training check avoids wrapping layers when a model is only being evaluated.
if len(model.sublayers()) == 0 or model.training == False:
result = eval('F.origin_{}'.format(fn.__name__))(*x, **kwargs)
return result
else:
if getattr(model, 'wrap_fn_{}_{}'.format(fn.__name__, global_idx),
None) == None:
setattr(model, 'wrap_fn_{}_{}'.format(fn.__name__, global_idx),
FuncWrapper(fn))
result = getattr(model, 'wrap_fn_{}_{}'.format(
fn.__name__, global_idx))(*x, **kwargs)
global_idx += 1
return result
return new_fn
def init_index():
global global_idx
global_idx = 0
@Counter()
def functional2layer():
"""
Wrap the functions in paddle.nn.functional with a class inherited from paddle.nn.Layer.
The purpose of this operation is to make the outputs of paddle.nn.functional calls in the model observable.
For example:
```python
class Model(nn.Layer):
    def __init__(self):
        super(Model, self).__init__()
        self.fc = nn.Linear(12, 16)
    def forward(self, x):
        softmax_out = nn.functional.softmax(x)
        fc_out = self.fc(softmax_out)
        relu_out = nn.functional.relu(fc_out)
        return relu_out
```
Before calling ```paddleslim.common.functional2layer```, we can only get the output of this model
and the output of self.fc through ```register_forward_post_hook``` in paddle,
because the ```register_forward_post_hook``` interface can only be used to get the output of classes
inherited from paddle.nn.Layer. Please refer to:
[register_forward_post_hook](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/Layer_cn.html#register_forward_post_hook)
After calling ```paddleslim.common.functional2layer```, the model above is converted to
a new model:
```python
class Model(nn.Layer):
    def __init__(self):
        super(Model, self).__init__()
        self.fc = nn.Linear(12, 16)
        self.wrap_fn_softmax_0 = FuncWrapper(nn.functional.softmax)
        self.wrap_fn_relu_1 = FuncWrapper(nn.functional.relu)
    def forward(self, x):
        softmax_out = self.wrap_fn_softmax_0(x)
        fc_out = self.fc(softmax_out)
        relu_out = self.wrap_fn_relu_1(fc_out)
        return relu_out
```
After this conversion, we can get the output of softmax through ```register_forward_post_hook```.
The conversion applies to the functions in paddle.nn.functional.
"""
init_index()
not_convert = ['linear', 'conv1d', 'conv1d_transpose', \
'conv2d', 'conv2d_transpose', 'conv3d', \
'conv3d_transpose', 'one_hot', 'embedding']
for f in dir(F):
if not f.startswith('__') and f not in not_convert and not f.startswith(
'origin_'):
setattr(F, 'origin_{}'.format(f), eval('F.{}'.format(f)))
if inspect.isfunction(eval('F.{}'.format(f))):
new_fn = convert_fn(eval('F.{}'.format(f)))
setattr(F, '{}'.format(f), new_fn)
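A small end-to-end sketch of the conversion implemented above (the toy model and tensor shapes are assumptions for illustration; `functional2layer` may only be called once per process because of the `Counter` decorator):

```python
import paddle
import paddle.nn as nn
from paddleslim.common.wrapper_function import functional2layer, init_index

class ToyModel(nn.Layer):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.fc = nn.Linear(4, 4)

    def forward(self, x):
        x = nn.functional.softmax(x)
        return nn.functional.relu(self.fc(x))

functional2layer()   # patch paddle.nn.functional (allowed once per process)
model = ToyModel()

init_index()                 # reset the global wrap index before every forward
model(paddle.rand([2, 4]))   # first forward creates the wrap_fn_* sublayers

captured = {}
for name, sublayer in model.named_sublayers():
    if 'wrap_fn_' in name:
        # the wrapped functional is now observable through forward hooks
        sublayer.register_forward_post_hook(
            lambda layer, inp, out, name=name: captured.update({name: out}))

init_index()
model(paddle.rand([2, 4]))
print(sorted(captured))      # ['wrap_fn_relu_1', 'wrap_fn_softmax_0']
```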
......@@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import copy
import collections
from collections import namedtuple
import numpy as np
import paddle
import paddle.nn as nn
from ...common.wrapper_function import init_index, functional2layer
from . import losses
from .losses.basic_loss import BASIC_LOSS
from .distill_helpers import yaml2config
......@@ -30,6 +32,8 @@ class LayerConfig:
model_name_pairs,
layers_name,
loss_function,
io=["output", "output"],
idx=[None, None],
weight=1.0,
temperature=1.0,
align_params=None,
......@@ -42,6 +46,8 @@ class LayerConfig:
loss_function,
BASIC_LOSS.module_dict.keys()))
self.loss_function = loss_function
self.io = io
self.idx = idx
self.weight = weight
self.temperature = temperature
self.align_params = align_params
......@@ -49,7 +55,7 @@ class LayerConfig:
setattr(self, k, v)
def _add_hooks(model, outs, hook_layers_name):
def _add_hooks(model, outs, layers_name, hook_layers_name, io='o', idx="None"):
"""
Get output by layer name.
model(nn.Layer): the model to which the hook is added.
......@@ -57,75 +63,89 @@ def _add_hooks(model, outs, hook_layers_name):
hook_layers_name(list): names of the intermediate (hook) layers.
"""
def _get_activation(outs, name):
### TODO: need to support get input tensor
#outs[name] = {}
def _get_activation(outs, name, io, idx):
def get_output_hook(layer, input, output):
#outs[name]["output"] = output
#outs[name]["input"] = input
outs[name] = output
if io == 'o':
if idx == "None":
outs[name] = output
else:
outs[name] = output[idx]
else:
if idx == "None":
outs[name] = input
else:
outs[name] = input[idx]
return get_output_hook
### TODO: support DP model
for idx, (n, m) in enumerate(model.named_sublayers()):
if n in hook_layers_name:
m.register_forward_post_hook(_get_activation(outs, n))
for i, (n, m) in enumerate(model.named_sublayers()):
if n == layers_name:
hooks = m.register_forward_post_hook(
_get_activation(outs, hook_layers_name, io, idx))
return hooks
def _remove_hooks(hooks):
for hook in hooks:
hook.remove()
class Distill(nn.Layer):
"""
Distill API.
distill_configs(list(dict) | path): the list of distill config.
student_models(list(nn.Layer)): the list of student model, the state of student model must be training mode.
teacher_models(list(nn.Layer)): the list of teacher model, the state of student model must be evaluate mode.
return_model_outputs(bool): whether to return model output. Default: True.
configs(list(dict) | string): the list of distill configs, or the path of a yaml file that contains the distill config.
students(list(nn.Layer)): the list of student models; each student model must be in training mode.
teachers(list(nn.Layer)): the list of teacher models.
convert_fn(bool): whether to convert the functionals in paddle to nn.Layer. For details of this conversion, please
refer to ```paddleslim.common.functional2layer```. Default: True.
return_model_outputs(bool): whether to return the original outputs of the models. If set to True, the distill loss, the outputs of the students, and the outputs of the teachers are returned, each as a list. Default: True.
"""
def __init__(self,
distill_configs,
student_models,
teacher_models,
configs,
students,
teachers,
convert_fn=True,
return_model_outputs=True):
super(Distill, self).__init__()
if isinstance(student_models, nn.Layer):
student_models = [student_models]
if isinstance(teacher_models, nn.Layer):
teacher_models = [teacher_models]
for student_model in student_models:
assert student_model.training, "The student model should not be eval mode."
for teacher_model in teacher_models:
assert teacher_model.training is False, "The teacher model should be eval mode."
if isinstance(distill_configs, list):
self._distill_configs = distill_configs
elif os.path.exists(distill_configs):
if distill_configs.endswith(".yaml"):
self._distill_configs = yaml2config(distill_configs)
if convert_fn:
functional2layer()
if isinstance(students, nn.Layer):
students = [students]
if isinstance(teachers, nn.Layer):
teachers = [teachers]
if isinstance(configs, list):
self._configs = configs
elif os.path.exists(configs):
if configs.endswith(".yaml"):
self._configs = yaml2config(configs)
else:
raise NotImplementedError("distill config file type error!")
else:
raise NotImplementedError("distill config error!")
self._student_models = student_models
self._teacher_models = teacher_models
self._student_models = nn.LayerList(students)
self._teacher_models = nn.LayerList(teachers)
self._return_model_outputs = return_model_outputs
self._loss_config_list = []
for c in self._distill_configs:
self._transpose_config(c)
for c in self._configs:
unfold_layer_config = self._transpose_config(c)
self._loss_config_list.extend(unfold_layer_config)
self._hook_layers = self._extract_hook_position()
hook_layers = self._extract_hook_position()
self._hook_layers = hook_layers
# use self._loss_config_list to create all loss object
self.distill_loss = losses.CombinedLoss(self._loss_config_list)
self._output_tensor_dict = self._prepare_outputs()
self._output_tensor_dict = self._prepare_outputs(hook_layers)
self._check_hook_output = False
def parameters(self):
params = []
for s_model in self._student_models:
params.extend(s_model.parameters())
return params
return self._student_models.parameters() + self.distill_loss.parameters(
)
def _extract_hook_position(self):
""" extrat hook position according to config"""
......@@ -145,6 +165,7 @@ class Distill(nn.Layer):
def _transpose_config(self, config):
""" Transpose config to loss needed """
unfold_config = []
global_config = {}
if 'model_name_pairs' not in config:
global_config['model_name_pairs'] = [['student_0', 'teacher_0']]
......@@ -159,57 +180,114 @@ class Distill(nn.Layer):
global_config[key] = config[key]
for per_layer_config in config['layers']:
per_layer_config.update(global_config)
self._loss_config_list.append(
LayerConfig(**per_layer_config).__dict__)
per_layer_config.update(copy.deepcopy(global_config))
layer_config = LayerConfig(**per_layer_config).__dict__
for idx in range(len(layer_config['layers_name'])):
### Take the first character of "input" or "output"; the result is "i" or "o".
postfix = '#' + layer_config['io'][idx][0] + '#' + str(
layer_config['idx'][idx])
layer_config['layers_name'][idx] += postfix
### io and idx are only used to extract tensors from the hooks, so pop them here.
layer_config.pop('io')
layer_config.pop('idx')
unfold_config.append(layer_config)
return unfold_config
def _prepare_outputs(self):
def _prepare_outputs(self, hook_layers, in_forward=False):
"""
Add hook to get the output tensor of target layer.
"""
outputs_tensor = {}
for idx, m in enumerate(self._student_models):
hook_layers = self._hook_layers['student_{}'.format(idx)]
tmp_hook_layers = hook_layers['student_{}'.format(idx)]
stu_outs = collections.OrderedDict()
outputs_tensor['student_{}'.format(idx)] = self._prepare_hook(
m, hook_layers, stu_outs)
m, tmp_hook_layers, stu_outs, in_forward=in_forward)
for idx, m in enumerate(self._teacher_models):
hook_layers = self._hook_layers['teacher_{}'.format(idx)]
tmp_hook_layers = hook_layers['teacher_{}'.format(idx)]
tea_outs = collections.OrderedDict()
outputs_tensor['teacher_{}'.format(idx)] = self._prepare_hook(
m, hook_layers, tea_outs)
m, tmp_hook_layers, tea_outs, in_forward=in_forward)
return outputs_tensor
def _prepare_hook(self, model, hook_layers, outs_dict):
def _prepare_hook(self, model, hook_layers, outs_dict, in_forward):
"""
Add hook.
"""
self.forward_hooks = []
for layer in hook_layers:
if isinstance(layer, str):
_add_hooks(model, outs_dict, layer)
tmp = layer.strip().split('#')
layer_name, io, idx = tmp[0], tmp[1], tmp[2]
if idx != "None":
idx = int(idx)
if in_forward:
if 'wrap_fn_' in layer_name:
hooks = _add_hooks(model, outs_dict, layer_name, layer, io,
idx)
self.forward_hooks.append(hooks)
else:
if 'wrap_fn_' not in layer_name:
_add_hooks(model, outs_dict, layer_name, layer, io, idx)
return outs_dict
def _useless_forward(self, *inputs, **kwargs):
for idx, student_model in enumerate(self._student_models):
### initialize global index before each forward
init_index()
student_model.forward(*inputs, **kwargs)
for idx, teacher_model in enumerate(self._teacher_models):
### initialize global index before each forward
init_index()
teacher_model.forward(*inputs, **kwargs)
def forward(self, *inputs, **kwargs):
if self._check_hook_output is False:
### the first (dummy) forward is only there to convert the functionals to FuncWrapper layers.
self._useless_forward(*inputs, **kwargs)
update_output_tensor_dict = self._prepare_outputs(
self._hook_layers, in_forward=True)
students_batch_outs = []
teachers_batch_outs = []
for idx, student_model in enumerate(self._student_models):
### initialize global index before each forward
init_index()
stu_batch_outs = student_model.forward(*inputs, **kwargs)
students_batch_outs.append(stu_batch_outs)
for idx, teacher_model in enumerate(self._teacher_models):
### initialize global index before each forward
init_index()
tea_batch_outs = teacher_model.forward(*inputs, **kwargs)
if not teacher_model.training:
tea_batch_outs = [i.detach() for i in tea_batch_outs]
teachers_batch_outs.extend(tea_batch_outs)
### update hook information.
for model, _ in self._output_tensor_dict.items():
self._output_tensor_dict[model].update(update_output_tensor_dict[
model])
if len(self._student_models) == 1:
students_batch_outs = students_batch_outs[0]
if len(self._teacher_models) == 1:
teachers_batch_outs = teachers_batch_outs[0]
if self._check_hook_output is False:
self._check_hook_output = True
for mo, hook_out in self._output_tensor_dict.items():
for hook_name, hook_value in hook_out.items():
hook_name = hook_name.strip().split('#')[0]
assert type(hook_value) is paddle.Tensor or len(\
hook_value) == 1, \
"model: {} layer: {} has more than one output/input" \
", please specific the idx of output/input.".format(mo, hook_name)
### batch is None just for now
distill_outputs = self.distill_loss(self._output_tensor_dict, None)
distill_loss = distill_outputs['loss']
_remove_hooks(self.forward_hooks)
if self._return_model_outputs:
return distill_loss, students_batch_outs, teachers_batch_outs
else:
......
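For reference, the new `io`/`idx` information is carried inside the hook layer names themselves; a small standalone sketch of the `#io#idx` postfix convention that `_transpose_config` writes and `_prepare_hook` parses back (illustrative only, mirroring the code above):

```python
# Illustrative sketch of the "#io#idx" postfix convention used above.
def encode_hook_name(layer_name, io='output', idx=None):
    # 'output' -> 'o', 'input' -> 'i'; idx stays the string "None" when unset.
    return '{}#{}#{}'.format(layer_name, io[0], idx)

def decode_hook_name(hook_name):
    layer_name, io, idx = hook_name.strip().split('#')
    return layer_name, io, idx

# e.g. capture the second output tensor of the student's "conv2" sublayer:
name = encode_hook_name('conv2', io='output', idx=1)  # 'conv2#o#1'
print(decode_hook_name(name))                         # ('conv2', 'o', '1')
```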
......@@ -13,7 +13,7 @@
# limitations under the License.
import yaml
__all__ = ['config2yaml']
__all__ = ['config2yaml', 'yaml2config']
def yaml2config(yaml_path):
......
......@@ -38,9 +38,9 @@ class CombinedLoss(nn.Layer):
"""
def __init__(self, loss_config_list=None):
super().__init__()
super(CombinedLoss, self).__init__()
loss_config_list = copy.deepcopy(loss_config_list)
self.loss_func = []
self.loss_func = nn.LayerList()
self.loss_weight = []
assert isinstance(loss_config_list, list), (
'operator config should be a list')
......
......@@ -12,12 +12,13 @@
#See the License for the specific language governing permissions and
#limitations under the License.
import numpy as np
import paddle
import paddle.nn as nn
from .basic_loss import BASIC_LOSS
__all__ = ["DistillationLoss"]
__all__ = ["DistillationLoss", "ShapeAlign"]
class DistillationLoss(nn.Layer):
......@@ -44,8 +45,12 @@ class DistillationLoss(nn.Layer):
self.align_params = params.pop(
'align_params') if 'align_params' in params else None
if self.align_params is not None:
for attr, value in self.align_params.items():
setattr(self, attr, value)
if 'transpose_model' in self.align_params:
self.transpose_model = self.align_params['transpose_model']
self.align_params.pop('transpose_model')
else:
self.transpose_model = 'student'
self.align_func = ShapeAlign(**self.align_params)
self.loss_func = BASIC_LOSS.get(loss_function)(**params)
......@@ -59,6 +64,11 @@ class DistillationLoss(nn.Layer):
) == 2, "length of layers_name must be equal to 2."
out1 = out1[self.layers_name[0]]
out2 = out2[self.layers_name[1]]
if self.align_params is not None:
if self.transpose_model == 'student':
out1 = self.align_func(out1)
else:
out2 = self.align_func(out2)
if self.temperature != 1.0:
out1 = out1 / self.temperature
out2 = out2 / self.temperature
......@@ -66,3 +76,100 @@ class DistillationLoss(nn.Layer):
1], self.layers_name[0] if self.layers_name != None else "0", \
self.layers_name[1] if self.layers_name != None else "0")] = self.loss_func(out1, out2)
return loss_dict
class ShapeAlign(nn.Layer):
"""
Align the feature map between student and teacher.
Args:
align_type(str): the op used to align the tensor, chosen from ['1x1conv','3x3conv','1x1conv+bn','3x3conv+bn','linear']
in_channel(int): input channel number
out_channel(int): output channel number
"""
def __init__(self, align_type, in_channel, out_channel, weight_init=None):
super(ShapeAlign, self).__init__()
self._in_channel = in_channel
self._out_channel = out_channel
assert align_type.lower() in [
'1x1conv', '3x3conv', '1x1conv+bn', '3x3conv+bn', 'linear'
], "only support 1x1conv, 3x3conv, 1x1conv+bn, 3x3conv+bn, linear for now"
bias_attr = None
if weight_init is not None:
assert 'initializer' in weight_init
init_mode = weight_init.pop('initializer')
### load transpose weight from pretrained model.
if init_mode == 'Assign':
bias = None
assert 'params_path' in weight_init
assert 'params_name' in weight_init
params_path = weight_init['params_path']
params_name = weight_init['params_name']
if isinstance(weight_init['params_name'], (list, tuple)):
assert len(weight_init['params_name']) <= 2
weight = paddle.load(params_path)[params_name[0]]
bias = paddle.load(params_path)[params_name[1]]
else:
weight = paddle.load(params_path)[params_name]
weight_attr = paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Assign(weight))
if bias is not None:
bias_attr = paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Assign(bias))
else:
weight_attr = paddle.framework.ParamAttr(initializer=eval(
'paddle.nn.initializer.{}'.format(init_mode))(
**weight_init))
else:
weight_attr = None
if align_type.lower() == '1x1conv':
self.align_op = paddle.nn.Conv2D(
in_channel,
out_channel,
kernel_size=1,
stride=1,
padding=0,
weight_attr=weight_attr,
bias_attr=bias_attr)
elif align_type.lower() == '3x3conv':
self.align_op = paddle.nn.Conv2D(
in_channel,
out_channel,
kernel_size=3,
stride=1,
padding=1,
weight_attr=weight_attr,
bias_attr=bias_attr)
elif align_type.lower() == '1x1conv+bn':
self.align_op = paddle.nn.Sequential(
paddle.nn.Conv2D(
in_channel,
out_channel,
kernel_size=1,
stride=1,
padding=0,
weight_attr=weight_attr,
bias_attr=bias_attr),
paddle.nn.BatchNorm2D(out_channel))
elif align_type.lower() == '3x3conv+bn':
self.align_op = paddle.nn.Sequential(
paddle.nn.Conv2D(
in_channel,
out_channel,
kernel_size=3,
stride=1,
padding=1,
weight_attr=weight_attr,
bias_attr=bias_attr),
paddle.nn.BatchNorm2D(out_channel))
elif align_type.lower() == 'linear':
self.align_op = paddle.nn.Linear(
in_channel,
out_channel,
weight_attr=weight_attr,
bias_attr=bias_attr)
def forward(self, feat):
out = self.align_op(feat)
return out
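A quick sanity-check sketch of the alignment layer defined above (the shapes, channel counts and import path are assumptions for illustration):

```python
import paddle
# ShapeAlign is defined above; the import path is assumed, e.g.:
# from paddleslim.dygraph.dist.losses.distill_loss import ShapeAlign

# 1x1 convolution that maps a 64-channel NCHW feature map to 128 channels.
align = ShapeAlign(align_type='1x1conv', in_channel=64, out_channel=128)
feat = paddle.rand([4, 64, 16, 16])
print(align(feat).shape)                  # [4, 128, 16, 16]

# Linear projection from a 312-d student hidden size to the 768-d teacher size,
# with a Normal weight initializer passed through weight_init.
proj = ShapeAlign(
    align_type='linear', in_channel=312, out_channel=768,
    weight_init={'initializer': 'Normal', 'mean': 0.0, 'std': 0.02})
print(proj(paddle.rand([8, 312])).shape)  # [8, 768]
```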
......@@ -24,6 +24,7 @@ class TestImperativeDistill(unittest.TestCase):
return MobileNetV1(), MobileNetV1()
def prepare_config(self):
self.convert_fn = False
distill_configs = [{
'loss_function': 'MSELoss',
'layers': [
......@@ -99,47 +100,105 @@ class TestImperativeDistill(unittest.TestCase):
test(self.s_model)
self.s_model.train()
distill_model = Distill(self.distill_configs, self.s_model,
self.t_model)
distill_model = Distill(
self.distill_configs,
self.s_model,
self.t_model,
convert_fn=self.convert_fn)
train(distill_model)
class TestImperativeDistillCase1(TestImperativeDistill):
def prepare_model(self):
class convbn(nn.Layer):
def __init__(self):
super(convbn, self).__init__()
self.conv = nn.Conv2D(3, 3, 3, padding=1)
self.bn = nn.BatchNorm(3)
def forward(self, x):
conv_out = self.conv(x)
bn_out = self.bn(conv_out)
return tuple([conv_out, bn_out])
class Model(nn.Layer):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2D(3, 3, 3, padding=1)
self.conv2 = nn.Conv2D(3, 3, 3, padding=1)
self.conv2 = convbn()
self.conv3 = nn.Conv2D(3, 3, 3, padding=1)
self.fc = nn.Linear(3072, 10)
def forward(self, x):
self.conv1_out = self.conv1(x)
conv2_out = self.conv2(self.conv1_out)
self.conv3_out = self.conv3(conv2_out)
self.conv3_out = self.conv3(conv2_out[0])
out = paddle.reshape(self.conv3_out, shape=[x.shape[0], -1])
out = paddle.nn.functional.softmax(out)
out = self.fc(out)
return out
return Model(), Model()
def prepare_config(self):
self.convert_fn = True
distill_configs = [{
'loss_function': 'MSELoss',
'layers': [
{
"layers_name": ["conv1", "conv1"]
"layers_name": ["conv1", "conv1"],
'align_params': {
'align_type': '1x1conv',
'in_channel': 3,
'out_channel': 3
}
},
{
"layers_name": ["conv2", "conv3"]
"layers_name": ["conv2", "conv3"],
'io': ["input", "output"],
'align_params': {
'align_type': '3x3conv',
'in_channel': 3,
'out_channel': 3
}
},
{
"layers_name": ["conv2", "conv3"],
'io': ["output", "output"],
'idx': [1, None],
'align_params': {
'align_type': '1x1conv+bn',
'in_channel': 3,
'out_channel': 3
}
},
{
"layers_name": ["conv2", "conv3"],
'io': ["output", "output"],
'idx': [1, None],
'align_params': {
'align_type': '3x3conv+bn',
'in_channel': 3,
'out_channel': 3,
'transpose_model': 'student'
}
},
]
}, {
'loss_function': 'CELoss',
'temperature': 1.0,
'layers': [{
"layers_name": ["fc", "fc"]
"layers_name": ["fc", "fc"],
'align_params': {
'align_type': 'linear',
'in_channel': 10,
'out_channel': 10,
'weight_init': {
'initializer': 'Normal',
'mean': 0.0,
'std': 0.02
},
}
}, ]
}]
config2yaml(distill_configs, 'test.yaml')
......