# Copyright (c) 2021  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import paddle.nn as nn
from . import losses

__all__ = ['Distill', 'AdaptorBase']


class LayerConfig:
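    """
    Normalize one raw distill config entry into the fields used by ``Distill``.

    A minimal usage sketch (the index and type values are assumptions for
    illustration only):

        LayerConfig(s_feature_idx=0, t_feature_idx=0,
                    feature_type='hidden', loss_function='l2')
    """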
    def __init__(self,
                 s_feature_idx,
                 t_feature_idx,
                 feature_type,
                 loss_function,
                 weight=1.0,
                 align=False,
                 align_shape=None):
        self.s_feature_idx = s_feature_idx
        self.t_feature_idx = t_feature_idx
        self.feature_type = feature_type
        if loss_function in ['l1', 'l2', 'smooth_l1']:
            self.loss_function = 'DistillationDistanceLoss'
        elif loss_function in ['dml']:
            self.loss_function = 'DistillationDMLLoss'
        elif loss_function in ['rkl']:
            self.loss_function = 'DistillationRKDLoss'
        elif hasattr(losses, loss_function):
            self.loss_function = loss_function
        else:
            raise NotImplementedError("loss function is not supported!")
        self.weight = weight
        self.align = align
        self.align_shape = align_shape


class AdaptorBase:
    def __init__(self, model):
        self.model = model
        self.add_tensor = False

    def _get_activation(self, outs, name):
        def get_output_hook(layer, input, output):
            outs[name] = output

        return get_output_hook

    def _add_distill_hook(self, outs, mapping_layers_name, layers_type):
        """
C
cc 已提交
64
            Get output by layer name.
C
ceci3 已提交
65 66 67 68
            outs(dict): save the middle outputs of model according to the name.
            mapping_layers(list): name of middle layers.
            layers_type(list): type of the middle layers to calculate distill loss.
        """
        ### TODO: support DP model
        for idx, (n, m) in enumerate(self.model.named_sublayers()):
            if n in mapping_layers_name:
                midx = mapping_layers_name.index(n)
                m.register_forward_post_hook(
                    self._get_activation(outs, layers_type[midx]))

    def mapping_layers(self):
        raise NotImplementedError("function mapping_layers is not implemented")
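
# A minimal sketch of a concrete adaptor (the sublayer names and the
# ``hidden_out`` attribute below are hypothetical, for illustration only):
#
#     class StudentAdaptor(AdaptorBase):
#         def mapping_layers(self):
#             mapping = {}
#             # sublayer names (str) are hooked so their outputs are captured
#             mapping['hidden_0'] = 'blocks.0'
#             mapping['hidden_1'] = 'blocks.1'
#             if self.add_tensor:
#                 # non-str values are taken as tensors set during forward()
#                 mapping['out_0'] = self.model.hidden_out
#             return mapping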


class Distill(nn.Layer):
    ### TODO: support list of student model and teacher model
    def __init__(self, distill_configs, student_models, teacher_models,
                 adaptors_S, adaptors_T):
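        """
        Wrap a student and a teacher model and compute the distillation loss
        between the intermediate tensors exposed by their adaptors.
        Args (a summary inferred from how the arguments are used below):
            distill_configs(list[dict]): config dicts accepted by ``LayerConfig``.
            student_models(nn.Layer): the student model; must be in training mode.
            teacher_models(nn.Layer): the teacher model.
            adaptors_S: an ``AdaptorBase`` subclass for the student model.
            adaptors_T: an ``AdaptorBase`` subclass for the teacher model.
        """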
        super(Distill, self).__init__()
        assert student_models.training, "The student model should be in training mode."

        self._distill_configs = distill_configs
        self._student_models = student_models
        self._teacher_models = teacher_models
        self._adaptors_S = adaptors_S(self._student_models)
        self._adaptors_T = adaptors_T(self._teacher_models)

        self.stu_outs_dict, self.tea_outs_dict = self._prepare_outputs()

        self.configs = []
        for c in self._distill_configs:
            self.configs.append(LayerConfig(**c).__dict__)

        self.distill_idx = self._get_distill_idx()

        self._loss_config_list = []
        for c in self.configs:
            loss_config = {}
            loss_config[str(c['loss_function'])] = {}
            loss_config[str(c['loss_function'])]['weight'] = c['weight']
            loss_config[str(c['loss_function'])]['key'] = c[
                'feature_type'] + '_' + str(c['s_feature_idx']) + '_' + str(c[
                    't_feature_idx'])
            ### TODO: support list of student models and teacher_models
            loss_config[str(c['loss_function'])][
                'model_name_pairs'] = [['student', 'teacher']]
            self._loss_config_list.append(loss_config)
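
        # e.g. one resulting entry of self._loss_config_list (values are
        # illustrative): {'DistillationDistanceLoss': {'weight': 1.0,
        # 'key': 'hidden_0_0', 'model_name_pairs': [['student', 'teacher']]}}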

        # use self._loss_config_list to create all loss object
        self.distill_loss = losses.CombinedLoss(self._loss_config_list)

    def _prepare_outputs(self):
        """
        Add hook to get the output tensor of target layer.
        Returns:
            stu_outs_dict(dict): the name and tensor for the student model,
                such as {'hidden_0': tensor_0, ..}
            tea_outs_dict(dict): the name and tensor for the teacher model,
                such as {'hidden_0': tensor_0, ..}
        """
        stu_outs_dict = collections.OrderedDict()
        tea_outs_dict = collections.OrderedDict()
        stu_outs_dict = self._prepare_hook(self._adaptors_S, stu_outs_dict)
        tea_outs_dict = self._prepare_hook(self._adaptors_T, tea_outs_dict)
        return stu_outs_dict, tea_outs_dict

    def _prepare_hook(self, adaptors, outs_dict):
        """
        Add hook.
        """
        mapping_layers = adaptors.mapping_layers()
        for layer_type, layer in mapping_layers.items():
            if isinstance(layer, str):
                adaptors._add_distill_hook(outs_dict, [layer], [layer_type])
        return outs_dict

    def _get_distill_idx(self):
        """
        For each feature_type, get the feature index in the student and teacher models.
        Returns:
            distill_idx(dict): the feature index for each feature_type,
                such as {'hidden': [[0, 0], [1, 1]], 'out': [[0, 0]]}
        """
        distill_idx = {}
        for config in self._distill_configs:
            if config['feature_type'] not in distill_idx:
                distill_idx[config['feature_type']] = [[
                    int(config['s_feature_idx']), int(config['t_feature_idx'])
                ]]
            else:
                distill_idx[config['feature_type']].append([
                    int(config['s_feature_idx']), int(config['t_feature_idx'])
                ])
        return distill_idx

    def forward(self, *inputs, **kwargs):
        stu_batch_outs = self._student_models.forward(*inputs, **kwargs)
        tea_batch_outs = self._teacher_models.forward(*inputs, **kwargs)
        if not self._teacher_models.training:
            tea_batch_outs = [i.detach() for i in tea_batch_outs]

        # get all target tensors
        if not self._adaptors_S.add_tensor:
            self._adaptors_S.add_tensor = True
        if not self._adaptors_T.add_tensor:
            self._adaptors_T.add_tensor = True
        self.stu_outs_dict = self._get_model_intermediate_output(
            self._adaptors_S, self.stu_outs_dict)
        self.tea_outs_dict = self._get_model_intermediate_output(
            self._adaptors_T, self.tea_outs_dict)

        distill_inputs = self._process_outputs()

        ### batch is None for now
        distill_outputs = self.distill_loss(distill_inputs, None)
        distill_loss = distill_outputs['loss']
        return stu_batch_outs, tea_batch_outs, distill_loss

    def _get_model_intermediate_output(self, adaptors, outs_dict):
        """
        Use the adaptor to get the target tensors.
        Returns:
            outs_dict(dict): the name and tensor for the target model,
                such as {'hidden_0': tensor_0, ..}
        """
        mapping_layers = adaptors.mapping_layers()
        for layer_type, layer in mapping_layers.items():
            if isinstance(layer, str):
                continue
            outs_dict[layer_type] = layer
        return outs_dict

    def _process_outputs(self):
        """
        Process the target tensors into the format the distill loss expects.
        """
        ### TODO: support list of student models and teacher_models
        final_distill_dict = {
            "student": collections.OrderedDict(),
            "teacher": collections.OrderedDict()
        }

        for feature_type, dist_idx in self.distill_idx.items():
            for sidx, tidx in dist_idx:
                stu_out = self.stu_outs_dict[feature_type + '_' + str(sidx)]
                tea_out = self.tea_outs_dict[feature_type + '_' + str(tidx)]
                if not self._student_models.training:
                    stu_out = stu_out.detach()
                if not self._teacher_models.training:
                    tea_out = tea_out.detach()

                name_str = feature_type + '_' + str(sidx) + '_' + str(tidx)
                final_distill_dict['student'][name_str] = stu_out
                final_distill_dict['teacher'][name_str] = tea_out
        return final_distill_dict
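
# A minimal end-to-end sketch (the model objects, adaptor classes, and config
# values below are assumptions for illustration only):
#
#     distill_configs = [{
#         's_feature_idx': 0,
#         't_feature_idx': 0,
#         'feature_type': 'hidden',
#         'loss_function': 'l2',
#     }]
#     student.train()  # Distill asserts the student is in training mode
#     teacher.eval()   # eval-mode teacher outputs are detached automatically
#     distill_model = Distill(distill_configs, student, teacher,
#                             StudentAdaptor, TeacherAdaptor)
#     stu_outs, tea_outs, loss = distill_model(inputs)
#     loss.backward()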