From 3fde095bc2a6f302968389302b662f4634e2f114 Mon Sep 17 00:00:00 2001 From: ceci3 Date: Mon, 6 Dec 2021 11:48:17 +0800 Subject: [PATCH] [distill] support wrap functional with class (#897) --- demo/dygraph/dist/bert/README.md | 6 +- demo/dygraph/dist/bert/distill_stage1.yaml | 33 +++- demo/dygraph/dist/bert/distill_stage2.yaml | 4 +- demo/dygraph/dist/bert/run.sh | 2 + demo/dygraph/dist/bert/task_distill.py | 28 ++- paddleslim/common/__init__.py | 3 + paddleslim/common/wrapper_function.py | 130 ++++++++++++ paddleslim/dygraph/dist/distill.py | 186 +++++++++++++----- paddleslim/dygraph/dist/distill_helpers.py | 2 +- paddleslim/dygraph/dist/losses/__init__.py | 4 +- .../dygraph/dist/losses/distillation_loss.py | 113 ++++++++++- tests/dygraph/test_distill.py | 73 ++++++- 12 files changed, 489 insertions(+), 95 deletions(-) create mode 100644 paddleslim/common/wrapper_function.py diff --git a/demo/dygraph/dist/bert/README.md b/demo/dygraph/dist/bert/README.md index d8f72b37..3d73a2dd 100644 --- a/demo/dygraph/dist/bert/README.md +++ b/demo/dygraph/dist/bert/README.md @@ -10,19 +10,19 @@ TinyBERT中蒸馏的整体过程:首先进行通用蒸馏,然后用数据增强后的数据,在特定任务上进行蒸馏,本文主要进行了第二阶段的蒸馏,模型是利用第一阶段得到的通用小模型`tinybert-6l-768d-v2`进行初始化。

-
+
TinyBERT蒸馏流程图

在模型蒸馏中,较大的模型(在本例中是BERT base)通常被称为教师模型,较小的模型(在本例中是层数为6的BERT,下文都称TinyBERT6)通常被称为学生模型。 -知识的蒸馏通常是通过让学生模型学习相关的蒸馏相损失函数实现,在本实验中,蒸馏的学习目标由两个部分组成,分别是中间层的蒸馏损失和预测层的蒸馏损失。其中,中间层的蒸馏包括对Embedding层的蒸馏、对每个Transformer layer输出的蒸馏、以及对每个Transformer中attention矩阵(softmax之前的结果)的蒸馏,三者均采用的是均方误差损失函数。而预测层蒸馏的学习目标则是学生模型输出的logits和教师模型输出的logits的交叉熵损失。 +知识的蒸馏通常是通过让学生模型学习相关的蒸馏损失函数实现,在本实验中,蒸馏的学习目标由两个部分组成,分别是中间层的蒸馏损失和预测层的蒸馏损失。其中,中间层的蒸馏包括对Embedding层的蒸馏、对每个Transformer layer输出的蒸馏、以及对每个Transformer中attention矩阵(softmax之前的结果)的蒸馏,三者均采用的是均方误差损失函数。而预测层蒸馏的学习目标则是学生模型输出的logits和教师模型输出的logits的交叉熵损失。 由于教师模型是12层,学生模型的层数少于教师模型的层数,因此需要选择一种layer mapping的方式。论文中采用了一种固定的映射方式,当学生模型的层数为教师模型的1/2时,学生第i层的attention矩阵,需要学习教师的第2i+1层的attention矩阵,Transformer layer输出同理。 实验分为两个大的训练过程:先对BERT-base进行微调,得到教师模型,再进行蒸馏的训练。其中,蒸馏过程也分为两个步骤:先对中间层进行蒸馏多个epochs(论文中针对具体任务可能是10、20或者30个),再对预测层蒸馏3个epochs。 -需要注意的是,在使用不同教师模型时,`tinybert-6l-768d-v2`、`tinybert-4l-312d-v2`这两个v2版本的预训练模型中开放的从学生embedding输出、transformer中间层输出到教师相应输出的转换矩阵是每层独立的,而其他的`tinybert-6l-768d`、`tinybert-4l-312d`、`tinybert-6l-768d-zh`、`tinybert-4l-312-zh`则是多层之间的参数共用一个转换矩阵的。 +需要注意的是,在使用不同教师模型时,`tinybert-6l-768d-v2`、`tinybert-4l-312d-v2`这两个v2版本的预训练模型中从学生embedding输出、transformer中间层输出到教师相应输出的转换矩阵是每层独立的,而其他的`tinybert-6l-768d`、`tinybert-4l-312d`、`tinybert-6l-768d-zh`、`tinybert-4l-312-zh`则是多层之间的参数共用一个转换矩阵的。 ### 安装PaddleNLP和Paddle 本教程基于PaddleNLP中BERT模型进行压缩,依赖PaddleNLP和Paddle。 diff --git a/demo/dygraph/dist/bert/distill_stage1.yaml b/demo/dygraph/dist/bert/distill_stage1.yaml index 1899721a..6ec0640a 100644 --- a/demo/dygraph/dist/bert/distill_stage1.yaml +++ b/demo/dygraph/dist/bert/distill_stage1.yaml @@ -1,10 +1,14 @@ - DistillConfig: - loss_function: MSELoss + - loss_function: MSELoss model_name_pairs: - - student_0 - teacher_0 weight: 1.0 - - layers: + align_params: + align_type: linear + in_channel: 768 + out_channel: 768 + layers: - layers_name: ['tinybert.embeddings', 'bert.embeddings'] - layers_name: ['tinybert.encoder.layers.0', 'bert.encoder.layers.1'] - layers_name: ['tinybert.encoder.layers.1', 'bert.encoder.layers.3'] @@ -12,9 +16,22 @@ - layers_name: ['tinybert.encoder.layers.3', 'bert.encoder.layers.7'] - layers_name: ['tinybert.encoder.layers.4', 'bert.encoder.layers.9'] - layers_name: ['tinybert.encoder.layers.5', 'bert.encoder.layers.11'] - - layers_name: ['tinybert.encoder.layers.0.self_attn', 'bert.encoder.layers.1.self_attn'] - - layers_name: ['tinybert.encoder.layers.1.self_attn', 'bert.encoder.layers.3.self_attn'] - - layers_name: ['tinybert.encoder.layers.2.self_attn', 'bert.encoder.layers.5.self_attn'] - - layers_name: ['tinybert.encoder.layers.3.self_attn', 'bert.encoder.layers.7.self_attn'] - - layers_name: ['tinybert.encoder.layers.4.self_attn', 'bert.encoder.layers.9.self_attn'] - - layers_name: ['tinybert.encoder.layers.5.self_attn', 'bert.encoder.layers.11.self_attn'] + + - loss_function: MSELoss + model_name_pairs: + - - student_0 + - teacher_0 + weight: 1.0 + layers: + - layers_name: ['tinybert.encoder.layers.0.self_attn.wrap_fn_softmax_0', 'bert.encoder.layers.1.self_attn.wrap_fn_softmax_2'] + io: ['input', 'input'] + - layers_name: ['tinybert.encoder.layers.1.self_attn.wrap_fn_softmax_2', 'bert.encoder.layers.3.self_attn.wrap_fn_softmax_6'] + io: ['input', 'input'] + - layers_name: ['tinybert.encoder.layers.2.self_attn.wrap_fn_softmax_4', 'bert.encoder.layers.5.self_attn.wrap_fn_softmax_10'] + io: ['input', 'input'] + - layers_name: ['tinybert.encoder.layers.3.self_attn.wrap_fn_softmax_6', 'bert.encoder.layers.7.self_attn.wrap_fn_softmax_14'] + io: ['input', 'input'] + - layers_name: ['tinybert.encoder.layers.4.self_attn.wrap_fn_softmax_8', 'bert.encoder.layers.9.self_attn.wrap_fn_softmax_18'] + io: ['input', 'input'] + - layers_name: ['tinybert.encoder.layers.5.self_attn.wrap_fn_softmax_10', 'bert.encoder.layers.11.self_attn.wrap_fn_softmax_22'] + io: ['input', 'input'] diff --git a/demo/dygraph/dist/bert/distill_stage2.yaml b/demo/dygraph/dist/bert/distill_stage2.yaml index 6d448a78..51009316 100644 --- a/demo/dygraph/dist/bert/distill_stage2.yaml +++ b/demo/dygraph/dist/bert/distill_stage2.yaml @@ -1,9 +1,9 @@ - DistillConfig: - loss_function: CELoss + - loss_function: CELoss model_name_pairs: - - student_0 - teacher_0 weight: 1.0 - - layers: + layers: - layers_name: ['classifier', 'classifier'] temperature: 1.0 diff --git a/demo/dygraph/dist/bert/run.sh b/demo/dygraph/dist/bert/run.sh index 58e07166..e1ab62ba 100644 --- a/demo/dygraph/dist/bert/run.sh +++ b/demo/dygraph/dist/bert/run.sh @@ -17,4 +17,6 @@ python3.7 task_distill.py \ --logging_steps 10 \ --save_steps 10 \ --output_dir ./tmp/$TASK_NAME/ \ + --distill_config ./distill_stage1.yaml \ --device gpu + diff --git a/demo/dygraph/dist/bert/task_distill.py b/demo/dygraph/dist/bert/task_distill.py index b76d453e..ce85a6be 100644 --- a/demo/dygraph/dist/bert/task_distill.py +++ b/demo/dygraph/dist/bert/task_distill.py @@ -157,11 +157,6 @@ def parse_args(): "--use_aug", action="store_true", help="Whether to use augmentation data to train.", ) - parser.add_argument( - "--intermediate_distill", - action="store_true", - help="Whether distilling intermediate layers. If False, it means prediction layer distillation.", - ) parser.add_argument( "--weight_decay", default=0.0, @@ -349,7 +344,6 @@ def do_train(args): teacher_model_class, _ = MODEL_CLASSES[args.teacher_model_type] teacher = teacher_model_class.from_pretrained( args.teacher_path, num_classes=num_classes) - teacher.eval() if paddle.distributed.get_world_size() > 1: student = paddle.DataParallel(student, find_unused_parameters=True) @@ -368,10 +362,20 @@ def do_train(args): lr_scheduler = T.LinearDecayWithWarmup(args.learning_rate, num_training_steps, warmup) + ### step1: load distill config + assert os.path.exists( + args.distill_config), "distill file {} not exist.".format( + args.distill_config) + ### step2: wrap the student model and teacher model by paddleslim.dygraph.dist.Distill + ### the distill config need to be passed into it. + distill_model = Distill( + args.distill_config, students=[student], teachers=[teacher]) + + ### step3: add parameter created by align op to optimizer # Generate parameter names needed to perform weight decay. # All bias and LayerNorm parameters are excluded. decay_params = [ - p.name for n, p in student.named_parameters() + p.name for n, p in distill_model.named_parameters() if not any(nd in n for nd in ["bias", "norm"]) ] optimizer = paddle.optimizer.AdamW( @@ -379,7 +383,7 @@ def do_train(args): beta1=0.9, beta2=0.999, epsilon=args.adam_epsilon, - parameters=student.parameters(), + parameters=distill_model.parameters(), weight_decay=args.weight_decay, apply_decay_param_fun=lambda x: x in decay_params) @@ -390,17 +394,11 @@ def do_train(args): tic_train = time.time() best_res = 0.0 - assert os.path.exists( - args.distill_config), "distill file {} not exist.".format( - args.distill_config) - distill_model = Distill( - args.distill_config, student_models=[student], - teacher_models=[teacher]) - for epoch in range(num_train_epochs): for step, batch in enumerate(train_data_loader): global_step += 1 input_ids, segment_ids, labels = batch + ### step4: call distill_model instead of call student model and teacher model independently. loss, _, _ = distill_model(input_ids, segment_ids) loss.backward() diff --git a/paddleslim/common/__init__.py b/paddleslim/common/__init__.py index 2e9e660e..0c1b226b 100644 --- a/paddleslim/common/__init__.py +++ b/paddleslim/common/__init__.py @@ -22,9 +22,12 @@ from .server import Server from .client import Client from .meter import AvgrageMeter from .analyze_helper import VarCollector +from paddleslim.common import wrapper_function __all__ = [ 'EvolutionaryController', 'SAController', 'get_logger', 'ControllerServer', 'ControllerClient', 'lock', 'unlock', 'cached_reader', 'AvgrageMeter', 'Server', 'Client', 'RLBaseController', 'VarCollector' ] + +__all__ += wrapper_function.__all__ diff --git a/paddleslim/common/wrapper_function.py b/paddleslim/common/wrapper_function.py new file mode 100644 index 00000000..f19b35d1 --- /dev/null +++ b/paddleslim/common/wrapper_function.py @@ -0,0 +1,130 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +__all__ = ['Counter', 'init_index', 'functional2layer'] + + +class Counter: + """ + limit the number of function calls. + """ + + def __init__(self, times=1): + self.calls = 0 + self.times = times + + def __call__(self, func): + def counter_wrapper(*args, **kwargs): + func(*args, **kwargs) + self.calls += 1 + assert self.calls <= self.times, "function {} only allow to call {} times.".format( + func.__name__, self.times) + + return counter_wrapper + + +class FuncWrapper(nn.Layer): + """ + """ + + def __init__(self, functional): + super(FuncWrapper, self).__init__() + self.fn = functional + + def forward(self, *x, **kwargs): + return self.fn(*x, **kwargs) + + +def convert_fn(fn): + def new_fn(*x, **kwargs): + global global_idx + model = inspect.currentframe().f_back.f_locals['self'] + ### TODO(ceci3): length of sublayers is 0 means not call a nn.Layer in __init__ function. + ### this condition maybe not rigorous, need to change it later. + ### model.training set to False is to avoid only eval student model. + if len(model.sublayers()) == 0 or model.training == False: + result = eval('F.origin_{}'.format(fn.__name__))(*x, **kwargs) + return result + else: + if getattr(model, 'wrap_fn_{}_{}'.format(fn.__name__, global_idx), + None) == None: + setattr(model, 'wrap_fn_{}_{}'.format(fn.__name__, global_idx), + FuncWrapper(fn)) + result = getattr(model, 'wrap_fn_{}_{}'.format( + fn.__name__, global_idx))(*x, **kwargs) + global_idx += 1 + return result + + return new_fn + + +def init_index(): + global global_idx + global_idx = 0 + + +@Counter() +def functional2layer(): + """ + Wrap the function in paddle.nn.functional with class inherited from paddle.nn.Layer. + The purpose of this operation is to get the output of paddle.nn.functional in the model. + For example: + ```python + class Model(nn.Layer): + def __init__(self): + self.fc = nn.Linear(12, 16) + def forward(x): + softmax_out = nn.functional.softmax(x) + fc_out = self.fc(softmax_out) + relu_out = nn.functional.relu(fc_out) + return relu_out + ``` + Before call the ```paddleslim.common.functional2layer``` function, we can get the output + of this model and the output of self.fc function through ```register_forward_post_hook``` in the paddle. + And the ```register_forward_post_hook``` interface can only used to get the output of class which is + inherited from paddle.nn.Layer. Please reference to: + [register_forward_post_hook](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/Layer_cn.html#register_forward_post_hook) + After call the ```paddleslim.common.functional2layer```, the model above will be converted to + a new model: + ```python + class Model(nn.Layer): + def __init__(self): + self.fc = nn.Linear(12, 16) + self.wrap_fn_softmax_0 = FuncWrapper() + self.wrap_fn_relu_1 = FuncWrapper() + def forward(x): + softmax_out = self.wrap_fn_softmax_0(x) + fc_out = self.fc(softmax_out) + relu_out = self.wrap_fn_relu_1(fc_out) + return relu_out + ``` + after this convert operation, we can get the output of softmax through ```register_forward_post_hook```. + The convert operation can applies to the layers in paddle.nn.functional.py. + """ + init_index() + not_convert = ['linear', 'conv1d', 'conv1d_transpose', \ + 'conv2d', 'conv2d_transpose', 'conv3d', \ + 'conv3d_transpose', 'one_hot', 'embedding'] + for f in dir(F): + if not f.startswith('__') and f not in not_convert and not f.startswith( + 'origin_'): + setattr(F, 'origin_{}'.format(f), eval('F.{}'.format(f))) + if inspect.isfunction(eval('F.{}'.format(f))): + new_fn = convert_fn(eval('F.{}'.format(f))) + setattr(F, '{}'.format(f), new_fn) diff --git a/paddleslim/dygraph/dist/distill.py b/paddleslim/dygraph/dist/distill.py index d067cd63..f9294747 100644 --- a/paddleslim/dygraph/dist/distill.py +++ b/paddleslim/dygraph/dist/distill.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -import numpy as np +import copy import collections -from collections import namedtuple +import numpy as np +import paddle import paddle.nn as nn +from ...common.wrapper_function import init_index, functional2layer from . import losses from .losses.basic_loss import BASIC_LOSS from .distill_helpers import yaml2config @@ -30,6 +32,8 @@ class LayerConfig: model_name_pairs, layers_name, loss_function, + io=["output", "output"], + idx=[None, None], weight=1.0, temperature=1.0, align_params=None, @@ -42,6 +46,8 @@ class LayerConfig: loss_function, BASIC_LOSS.module_dict.keys())) self.loss_function = loss_function + self.io = io + self.idx = idx self.weight = weight self.temperature = temperature self.align_params = align_params @@ -49,7 +55,7 @@ class LayerConfig: setattr(self, k, v) -def _add_hooks(model, outs, hook_layers_name): +def _add_hooks(model, outs, layers_name, hook_layers_name, io='o', idx="None"): """ Get output by layer name. models(nn.Layer): model need to be add hook. @@ -57,75 +63,89 @@ def _add_hooks(model, outs, hook_layers_name): hook_layers_name(list): name of middle layers. """ - def _get_activation(outs, name): - ### TODO: need to support get input tensor - #outs[name] = {} + def _get_activation(outs, name, io, idx): def get_output_hook(layer, input, output): - #outs[name]["output"] = output - #outs[name]["input"] = input - outs[name] = output + if io == 'o': + if idx == "None": + outs[name] = output + else: + outs[name] = output[idx] + else: + if idx == "None": + outs[name] = input + else: + outs[name] = input[idx] return get_output_hook ### TODO: support DP model - for idx, (n, m) in enumerate(model.named_sublayers()): - if n in hook_layers_name: - m.register_forward_post_hook(_get_activation(outs, n)) + for i, (n, m) in enumerate(model.named_sublayers()): + if n == layers_name: + hooks = m.register_forward_post_hook( + _get_activation(outs, hook_layers_name, io, idx)) + return hooks + + +def _remove_hooks(hooks): + for hook in hooks: + hook.remove() class Distill(nn.Layer): """ Distill API. - distill_configs(list(dict) | path): the list of distill config. - student_models(list(nn.Layer)): the list of student model, the state of student model must be training mode. - teacher_models(list(nn.Layer)): the list of teacher model, the state of student model must be evaluate mode. - return_model_outputs(bool): whether to return model output. Default: True. + configs(list(dict) | string): the list of distill config or the path of yaml file which contain the distill config. + students(list(nn.Layer)): the list of student model, the state of student model must be training mode. + teachers(list(nn.Layer)): the list of teacher model. + convert_fn(bool): convert the functional in paddlepaddle to nn.Layer. The detail of this convert operation please + reference to ```paddleslim.common.functional2layer```. Default: True. + return_model_outputs(bool): whether to return the origin outputs of the model. If set to True, will return distill loss, the output of students and the output of teachers, the output of each part will be returned as a list. Default: True. """ def __init__(self, - distill_configs, - student_models, - teacher_models, + configs, + students, + teachers, + convert_fn=True, return_model_outputs=True): super(Distill, self).__init__() - if isinstance(student_models, nn.Layer): - student_models = [student_models] - if isinstance(teacher_models, nn.Layer): - teacher_models = [teacher_models] - for student_model in student_models: - assert student_model.training, "The student model should not be eval mode." - for teacher_model in teacher_models: - assert teacher_model.training is False, "The teacher model should be eval mode." - - if isinstance(distill_configs, list): - self._distill_configs = distill_configs - elif os.path.exists(distill_configs): - if distill_configs.endswith(".yaml"): - self._distill_configs = yaml2config(distill_configs) + if convert_fn: + functional2layer() + if isinstance(students, nn.Layer): + students = [students] + if isinstance(teachers, nn.Layer): + teachers = [teachers] + + if isinstance(configs, list): + self._configs = configs + elif os.path.exists(configs): + if configs.endswith(".yaml"): + self._configs = yaml2config(configs) else: raise NotImplementedError("distill config file type error!") else: raise NotImplementedError("distill config error!") - self._student_models = student_models - self._teacher_models = teacher_models + self._student_models = nn.LayerList(students) + self._teacher_models = nn.LayerList(teachers) self._return_model_outputs = return_model_outputs self._loss_config_list = [] - for c in self._distill_configs: - self._transpose_config(c) + for c in self._configs: + unfold_layer_config = self._transpose_config(c) + self._loss_config_list.extend(unfold_layer_config) - self._hook_layers = self._extract_hook_position() + hook_layers = self._extract_hook_position() + self._hook_layers = hook_layers # use self._loss_config_list to create all loss object self.distill_loss = losses.CombinedLoss(self._loss_config_list) - self._output_tensor_dict = self._prepare_outputs() + self._output_tensor_dict = self._prepare_outputs(hook_layers) + self._check_hook_output = False def parameters(self): - params = [] - for s_model in self._student_models: - params.extend(s_model.parameters()) - return params + return self._student_models.parameters() + self.distill_loss.parameters( + ) def _extract_hook_position(self): """ extrat hook position according to config""" @@ -145,6 +165,7 @@ class Distill(nn.Layer): def _transpose_config(self, config): """ Transpose config to loss needed """ + unfold_config = [] global_config = {} if 'model_name_pairs' not in config: global_config['model_name_pairs'] = [['student_0', 'teacher_0']] @@ -159,57 +180,114 @@ class Distill(nn.Layer): global_config[key] = config[key] for per_layer_config in config['layers']: - per_layer_config.update(global_config) - self._loss_config_list.append( - LayerConfig(**per_layer_config).__dict__) + per_layer_config.update(copy.deepcopy(global_config)) + layer_config = LayerConfig(**per_layer_config).__dict__ + for idx in range(len(layer_config['layers_name'])): + ### slice 0 from string "input" or "output", results is "i" or "o". + postfix = '#' + layer_config['io'][idx][0] + '#' + str( + layer_config['idx'][idx]) + layer_config['layers_name'][idx] += postfix + ### io and idx only use to extract tensor from hook, so pop it here. + layer_config.pop('io') + layer_config.pop('idx') + unfold_config.append(layer_config) + return unfold_config - def _prepare_outputs(self): + def _prepare_outputs(self, hook_layers, in_forward=False): """ Add hook to get the output tensor of target layer. """ outputs_tensor = {} for idx, m in enumerate(self._student_models): - hook_layers = self._hook_layers['student_{}'.format(idx)] + tmp_hook_layers = hook_layers['student_{}'.format(idx)] stu_outs = collections.OrderedDict() outputs_tensor['student_{}'.format(idx)] = self._prepare_hook( - m, hook_layers, stu_outs) + m, tmp_hook_layers, stu_outs, in_forward=in_forward) for idx, m in enumerate(self._teacher_models): - hook_layers = self._hook_layers['teacher_{}'.format(idx)] + tmp_hook_layers = hook_layers['teacher_{}'.format(idx)] tea_outs = collections.OrderedDict() outputs_tensor['teacher_{}'.format(idx)] = self._prepare_hook( - m, hook_layers, tea_outs) + m, tmp_hook_layers, tea_outs, in_forward=in_forward) return outputs_tensor - def _prepare_hook(self, model, hook_layers, outs_dict): + def _prepare_hook(self, model, hook_layers, outs_dict, in_forward): """ Add hook. """ + self.forward_hooks = [] for layer in hook_layers: - if isinstance(layer, str): - _add_hooks(model, outs_dict, layer) + tmp = layer.strip().split('#') + layer_name, io, idx = tmp[0], tmp[1], tmp[2] + if idx != "None": + idx = int(idx) + if in_forward: + if 'wrap_fn_' in layer_name: + hooks = _add_hooks(model, outs_dict, layer_name, layer, io, + idx) + self.forward_hooks.append(hooks) + else: + if 'wrap_fn_' not in layer_name: + _add_hooks(model, outs_dict, layer_name, layer, io, idx) return outs_dict + def _useless_forward(self, *inputs, **kwargs): + for idx, student_model in enumerate(self._student_models): + ### initialize global index before each forward + init_index() + student_model.forward(*inputs, **kwargs) + for idx, teacher_model in enumerate(self._teacher_models): + ### initialize global index before each forward + init_index() + teacher_model.forward(*inputs, **kwargs) + def forward(self, *inputs, **kwargs): + if self._check_hook_output is False: + ### the first useless forward is to convert function to class. + self._useless_forward(*inputs, **kwargs) + + update_output_tensor_dict = self._prepare_outputs( + self._hook_layers, in_forward=True) + students_batch_outs = [] teachers_batch_outs = [] for idx, student_model in enumerate(self._student_models): + ### initialize global index before each forward + init_index() stu_batch_outs = student_model.forward(*inputs, **kwargs) students_batch_outs.append(stu_batch_outs) for idx, teacher_model in enumerate(self._teacher_models): + ### initialize global index before each forward + init_index() tea_batch_outs = teacher_model.forward(*inputs, **kwargs) if not teacher_model.training: tea_batch_outs = [i.detach() for i in tea_batch_outs] teachers_batch_outs.extend(tea_batch_outs) + ### update hook information. + for model, _ in self._output_tensor_dict.items(): + self._output_tensor_dict[model].update(update_output_tensor_dict[ + model]) + if len(self._student_models) == 1: students_batch_outs = students_batch_outs[0] if len(self._teacher_models) == 1: teachers_batch_outs = teachers_batch_outs[0] + if self._check_hook_output is False: + self._check_hook_output = True + for mo, hook_out in self._output_tensor_dict.items(): + for hook_name, hook_value in hook_out.items(): + hook_name = hook_name.strip().split('#')[0] + assert type(hook_value) is paddle.Tensor or len(\ + hook_value) == 1, \ + "model: {} layer: {} has more than one output/input" \ + ", please specific the idx of output/input.".format(mo, hook_name) ### batch is None just for now distill_outputs = self.distill_loss(self._output_tensor_dict, None) distill_loss = distill_outputs['loss'] + _remove_hooks(self.forward_hooks) + if self._return_model_outputs: return distill_loss, students_batch_outs, teachers_batch_outs else: diff --git a/paddleslim/dygraph/dist/distill_helpers.py b/paddleslim/dygraph/dist/distill_helpers.py index 07c3532e..71b4ed2c 100644 --- a/paddleslim/dygraph/dist/distill_helpers.py +++ b/paddleslim/dygraph/dist/distill_helpers.py @@ -13,7 +13,7 @@ # limitations under the License. import yaml -__all__ = ['config2yaml'] +__all__ = ['config2yaml', 'yaml2config'] def yaml2config(yaml_path): diff --git a/paddleslim/dygraph/dist/losses/__init__.py b/paddleslim/dygraph/dist/losses/__init__.py index d583a275..37547aaf 100644 --- a/paddleslim/dygraph/dist/losses/__init__.py +++ b/paddleslim/dygraph/dist/losses/__init__.py @@ -38,9 +38,9 @@ class CombinedLoss(nn.Layer): """ def __init__(self, loss_config_list=None): - super().__init__() + super(CombinedLoss, self).__init__() loss_config_list = copy.deepcopy(loss_config_list) - self.loss_func = [] + self.loss_func = nn.LayerList() self.loss_weight = [] assert isinstance(loss_config_list, list), ( 'operator config should be a list') diff --git a/paddleslim/dygraph/dist/losses/distillation_loss.py b/paddleslim/dygraph/dist/losses/distillation_loss.py index 87abce11..38f37141 100644 --- a/paddleslim/dygraph/dist/losses/distillation_loss.py +++ b/paddleslim/dygraph/dist/losses/distillation_loss.py @@ -12,12 +12,13 @@ #See the License for the specific language governing permissions and #limitations under the License. +import numpy as np import paddle import paddle.nn as nn from .basic_loss import BASIC_LOSS -__all__ = ["DistillationLoss"] +__all__ = ["DistillationLoss", "ShapeAlign"] class DistillationLoss(nn.Layer): @@ -44,8 +45,12 @@ class DistillationLoss(nn.Layer): self.align_params = params.pop( 'align_params') if 'align_params' in params else None if self.align_params is not None: - for attr, value in self.align_params.items(): - setattr(self, attr, value) + if 'transpose_model' in self.align_params: + self.transpose_model = self.align_params['transpose_model'] + self.align_params.pop('transpose_model') + else: + self.transpose_model = 'student' + self.align_func = ShapeAlign(**self.align_params) self.loss_func = BASIC_LOSS.get(loss_function)(**params) @@ -59,6 +64,11 @@ class DistillationLoss(nn.Layer): ) == 2, "length of layers_name must be equal to 2." out1 = out1[self.layers_name[0]] out2 = out2[self.layers_name[1]] + if self.align_params is not None: + if self.transpose_model == 'student': + out1 = self.align_func(out1) + else: + out2 = self.align_func(out2) if self.temperature != 1.0: out1 = out1 / self.temperature out2 = out2 / self.temperature @@ -66,3 +76,100 @@ class DistillationLoss(nn.Layer): 1], self.layers_name[0] if self.layers_name != None else "0", \ self.layers_name[1] if self.layers_name != None else "0")] = self.loss_func(out1, out2) return loss_dict + + +class ShapeAlign(nn.Layer): + """ + Align the feature map between student and teacher. + Args: + align_type(str): reshape tensor by which op, choice in ['1x1conv','3x3conv','1x1conv+bn','3x3conv+bn','linear'] + in_channel(int): input channel number + out_channel(int): output channel number + """ + + def __init__(self, align_type, in_channel, out_channel, weight_init=None): + super(ShapeAlign, self).__init__() + self._in_channel = in_channel + self._out_channel = out_channel + assert align_type.lower() in [ + '1x1conv', '3x3conv', '1x1conv+bn', '3x3conv+bn', 'linear' + ], "only support 1x1conv, 3x3conv, 1x1conv+bn, 3x3conv+bn, linear for now" + + bias_attr = None + if weight_init is not None: + assert 'initializer' in weight_init + init_mode = weight_init.pop('initializer') + ### load transpose weight from pretrained model. + if init_mode == 'Assign': + bias = None + assert 'params_path' in weight_init + assert 'params_name' in weight_init + params_path = weight_init['params_path'] + params_name = weight_init['params_name'] + if isinstance(weight_init['params_name'], (list, tuple)): + assert len(weight_init['params_name']) <= 2 + weight = paddle.load(params_path)[params_name[0]] + bias = paddle.load(params_path)[params_name[1]] + else: + weight = paddle.load(params_path)[params_name] + weight_attr = paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Assign(weight)) + if bias is not None: + bias_attr = paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Assign(bias)) + else: + weight_attr = paddle.framework.ParamAttr(initializer=eval( + 'paddle.nn.initializer.{}'.format(init_mode))( + **weight_init)) + else: + weight_attr = None + if align_type.lower() == '1x1conv': + self.align_op = paddle.nn.Conv2D( + in_channel, + out_channel, + kernel_size=1, + stride=1, + padding=0, + weight_attr=weight_attr, + bias_attr=bias_attr) + elif align_type.lower() == '3x3conv': + self.align_op = paddle.nn.Conv2D( + in_channel, + out_channel, + kernel_size=3, + stride=1, + padding=1, + weight_attr=weight_attr, + bias_attr=bias_attr) + elif align_type.lower() == '1x1conv+bn': + self.align_op = paddle.nn.Sequential( + paddle.nn.Conv2D( + in_channel, + out_channel, + kernel_size=1, + stride=1, + padding=0, + weight_attr=weight_attr, + bias_attr=bias_attr), + paddle.nn.BatchNorm2D(out_channel)) + elif align_type.lower() == '3x3conv+bn': + self.align_op = paddle.nn.Sequential( + paddle.nn.Conv2D( + in_channel, + out_channel, + kernel_size=3, + stride=1, + padding=1, + weight_attr=weight_attr, + bias_attr=bias_attr), + paddle.nn.BatchNorm2D(out_channel)) + elif align_type.lower() == 'linear': + self.align_op = paddle.nn.Linear( + in_channel, + out_channel, + weight_attr=weight_attr, + bias_attr=bias_attr) + + def forward(self, feat): + out = self.align_op(feat) + return out diff --git a/tests/dygraph/test_distill.py b/tests/dygraph/test_distill.py index d3ebafaf..c5303871 100644 --- a/tests/dygraph/test_distill.py +++ b/tests/dygraph/test_distill.py @@ -24,6 +24,7 @@ class TestImperativeDistill(unittest.TestCase): return MobileNetV1(), MobileNetV1() def prepare_config(self): + self.convert_fn = False distill_configs = [{ 'loss_function': 'MSELoss', 'layers': [ @@ -99,47 +100,105 @@ class TestImperativeDistill(unittest.TestCase): test(self.s_model) self.s_model.train() - distill_model = Distill(self.distill_configs, self.s_model, - self.t_model) + distill_model = Distill( + self.distill_configs, + self.s_model, + self.t_model, + convert_fn=self.convert_fn) train(distill_model) class TestImperativeDistillCase1(TestImperativeDistill): def prepare_model(self): + class convbn(nn.Layer): + def __init__(self): + super(convbn, self).__init__() + self.conv = nn.Conv2D(3, 3, 3, padding=1) + self.bn = nn.BatchNorm(3) + + def forward(self, x): + conv_out = self.conv(x) + bn_out = self.bn(conv_out) + return tuple([conv_out, bn_out]) + class Model(nn.Layer): def __init__(self): super(Model, self).__init__() self.conv1 = nn.Conv2D(3, 3, 3, padding=1) - self.conv2 = nn.Conv2D(3, 3, 3, padding=1) + self.conv2 = convbn() self.conv3 = nn.Conv2D(3, 3, 3, padding=1) self.fc = nn.Linear(3072, 10) def forward(self, x): self.conv1_out = self.conv1(x) conv2_out = self.conv2(self.conv1_out) - self.conv3_out = self.conv3(conv2_out) + self.conv3_out = self.conv3(conv2_out[0]) out = paddle.reshape(self.conv3_out, shape=[x.shape[0], -1]) + out = paddle.nn.functional.softmax(out) out = self.fc(out) return out return Model(), Model() def prepare_config(self): + self.convert_fn = True distill_configs = [{ 'loss_function': 'MSELoss', 'layers': [ { - "layers_name": ["conv1", "conv1"] + "layers_name": ["conv1", "conv1"], + 'align_params': { + 'align_type': '1x1conv', + 'in_channel': 3, + 'out_channel': 3 + } }, { - "layers_name": ["conv2", "conv3"] + "layers_name": ["conv2", "conv3"], + 'io': ["input", "output"], + 'align_params': { + 'align_type': '3x3conv', + 'in_channel': 3, + 'out_channel': 3 + } + }, + { + "layers_name": ["conv2", "conv3"], + 'io': ["output", "output"], + 'idx': [1, None], + 'align_params': { + 'align_type': '1x1conv+bn', + 'in_channel': 3, + 'out_channel': 3 + } + }, + { + "layers_name": ["conv2", "conv3"], + 'io': ["output", "output"], + 'idx': [1, None], + 'align_params': { + 'align_type': '3x3conv+bn', + 'in_channel': 3, + 'out_channel': 3, + 'transpose_model': 'student' + } }, ] }, { 'loss_function': 'CELoss', 'temperature': 1.0, 'layers': [{ - "layers_name": ["fc", "fc"] + "layers_name": ["fc", "fc"], + 'align_params': { + 'align_type': 'linear', + 'in_channel': 10, + 'out_channel': 10, + 'weight_init': { + 'initializer': 'Normal', + 'mean': 0.0, + 'std': 0.02 + }, + } }, ] }] config2yaml(distill_configs, 'test.yaml') -- GitLab