# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
from paddleslim.core import GraphWrapper
import paddle.nn.functional as F


def merge(teacher_program,
          student_program,
          data_name_map,
          place,
          scope=None,
          teacher_scope=None,
          name_prefix='teacher_',
          merge_feed=True):
    """Merge teacher program into student program and add a uniform prefix to the
    names of all vars in teacher program

    Args:
        teacher_program(Program): The input teacher model Paddle program.
        student_program(Program): The input student model Paddle program.
        data_name_map(dict): Mapping of teacher input interface name and student
                            input interface name, where key of dict is the
                            input name of teacher_program, and value is the
                            input name of student_program.
        place(CPUPlace()|CUDAPlace(N)): The device on which the program runs.
        scope(Scope): This parameter indicates the variable scope used by
                      the program. If not specified, the default global scope
                      will be used. Default: None
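        teacher_scope(Scope): The scope from which the teacher program's
                      variables are read before being copied into `scope`.
                      If not specified, `scope` will be used. Default: None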
        name_prefix(str): Name prefix added for all vars of the teacher program.
                          Default: 'teacher_'
        merge_feed(bool): Whether to merge the feed op when merging programs. Default: True.

    Returns:
        None
    """
    if scope is None:
        scope = paddle.static.global_scope()
    if teacher_scope is None:
        teacher_scope = scope
    teacher_program = teacher_program.clone(for_test=True)
    for teacher_var in teacher_program.list_vars():
        skip_rename = False
        if teacher_var.name != 'fetch' and (not merge_feed or
                                            teacher_var.name != 'feed'):
            if teacher_var.name in data_name_map.keys():
                new_name = data_name_map[teacher_var.name]
                if new_name == teacher_var.name:
                    skip_rename = True
            else:
                new_name = name_prefix + teacher_var.name
            if not skip_rename:
                # scope var rename
                old_var = teacher_scope.var(teacher_var.name).get_tensor()
                renamed_var = scope.var(new_name).get_tensor()
                renamed_var.set(np.array(old_var), place)

                # program var rename
                renamed_var = teacher_program.global_block()._rename_var(
                    teacher_var.name, new_name)

    for teacher_var in teacher_program.list_vars():
        if teacher_var.name != 'fetch' and (not merge_feed or
                                            teacher_var.name != 'feed'):
            # student program add var
            new_var = student_program.global_block()._clone_variable(
                teacher_var, force_persistable=False)
            new_var.stop_gradient = True

    for block in teacher_program.blocks:
        for op in block.ops:
            if (not merge_feed or op.type != 'feed') and op.type != 'fetch':
                inputs = {}
                outputs = {}
                attrs = {}
                for input_name in op.input_names:
                    inputs[input_name] = [
                        block.var(in_var_name)
                        for in_var_name in op.input(input_name)
                    ]

                for output_name in op.output_names:
                    outputs[output_name] = [
                        block.var(out_var_name)
                        for out_var_name in op.output(output_name)
                    ]
                for attr_name in op.attr_names:
                    attrs[attr_name] = op.attr(attr_name)
                student_program.global_block().append_op(
                    type=op.type, inputs=inputs, outputs=outputs, attrs=attrs)

    student_graph = GraphWrapper(student_program)
    # Mark every op that consumes a teacher variable so that quantization
    # passes skip it.
    for op in student_graph.ops():
        belongsto_teacher = False
        for inp in op.all_inputs():
            if 'teacher' in inp.name():
                belongsto_teacher = True
                break
        if belongsto_teacher:
            op._op._set_attr("skip_quant", True)
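
# A minimal usage sketch (illustrative only, not part of the original module):
# the static programs and the input name 'image' below are assumptions.
#
#     paddle.enable_static()
#     place = paddle.CPUPlace()
#     # teacher_program and student_program are static graphs built elsewhere,
#     # both consuming an input variable named 'image'.
#     data_name_map = {'image': 'image'}  # teacher input name -> student input name
#     merge(teacher_program, student_program, data_name_map, place)
#     # After merging, teacher variables appear in student_program with the
#     # 'teacher_' prefix and can be referenced by the loss helpers below.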


def fsp(teacher_var1_name,
        teacher_var2_name,
        student_var1_name,
        student_var2_name,
        program=None):
    """Combine variables from student model and teacher model by fsp-loss.

    Args:
        teacher_var1_name(str): The name of teacher_var1.
        teacher_var2_name(str): The name of teacher_var2. Except for the
            second dimension, all other dimensions should
            be consistent with teacher_var1.
        student_var1_name(str): The name of student_var1.
        student_var2_name(str): The name of student_var2. Except for the
            second dimension, all other dimensions should
            be consistent with student_var1.
        program(Program): The input distiller program. If not specified,
                          the default program will be used. Default: None

    Returns:
        Variable: fsp distiller loss.
    """
    if program is None:
        program = paddle.static.default_main_program()
    teacher_var1 = program.global_block().var(teacher_var1_name)
    teacher_var2 = program.global_block().var(teacher_var2_name)
    student_var1 = program.global_block().var(student_var1_name)
    student_var2 = program.global_block().var(student_var2_name)
    teacher_fsp_matrix = paddle.fluid.layers.fsp_matrix(teacher_var1,
                                                        teacher_var2)
    student_fsp_matrix = paddle.fluid.layers.fsp_matrix(student_var1,
                                                        student_var2)
    fsp_loss = paddle.mean(
        paddle.nn.functional.square_error_cost(student_fsp_matrix,
                                               teacher_fsp_matrix))
    return fsp_loss
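
# Usage sketch (illustrative; the variable names are assumptions): build an
# fsp loss between two feature-map pairs of a merged program.
#
#     distill_loss = fsp('teacher_res2a_out.tmp_0', 'teacher_res3a_out.tmp_0',
#                        'res2a_out.tmp_0', 'res3a_out.tmp_0',
#                        program=student_program)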


def l2(teacher_var_name, student_var_name, program=None):
    """Combine variables from student model and teacher model by l2-loss.

    Args:
        teacher_var_name(str): The name of teacher_var.
        student_var_name(str): The name of student_var.
        program(Program): The input distiller program. If not specified,
                          the default program will be used. Default: None

    Returns: 
        Variable: l2 distiller loss.
    """
    if program is None:
        program = paddle.static.default_main_program()
    student_var = program.global_block().var(student_var_name)
    teacher_var = program.global_block().var(teacher_var_name)
    l2_loss = paddle.mean(
        paddle.nn.functional.square_error_cost(student_var, teacher_var))
    return l2_loss
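
# Usage sketch (illustrative; the variable names are assumptions): an l2 loss
# between a teacher output and the corresponding student output after merge().
#
#     distill_loss = l2('teacher_fc_0.tmp_0', 'fc_0.tmp_0',
#                       program=student_program)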


def soft_label(teacher_var_name,
               student_var_name,
               program=None,
               teacher_temperature=1.,
               student_temperature=1.):
    """Combine variables from student model and teacher model by soft-label-loss.

    Args:
        teacher_var_name(str): The name of teacher_var.
        student_var_name(str): The name of student_var.
        program(Program): The input distiller program. If not specified,
                          the default program will be used. Default: None
        teacher_temperature(float): Temperature used to divide
            teacher_feature_map before softmax. Default: 1.0
        student_temperature(float): Temperature used to divide 
            student_feature_map before softmax. Default: 1.0

    Returns:
        Variable: soft label distiller loss.
    """
    if program is None:
        program = paddle.static.default_main_program()
    student_var = program.global_block().var(student_var_name)
    teacher_var = program.global_block().var(teacher_var_name)
    teacher_var.stop_gradient = True

    student_var = paddle.nn.functional.softmax(student_var /
                                               student_temperature)
    teacher_var = paddle.nn.functional.softmax(teacher_var /
                                               teacher_temperature)
    soft_label_loss = paddle.mean(
        paddle.nn.functional.cross_entropy(
            input=student_var,
            label=teacher_var,
            soft_label=True,
            use_softmax=False))
    return soft_label_loss
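
# Usage sketch (illustrative; the variable names and temperatures are
# assumptions): a soft-label loss on the logits of a merged program. Higher
# temperatures soften both distributions before the cross entropy.
#
#     distill_loss = soft_label('teacher_fc_0.tmp_0', 'fc_0.tmp_0',
#                               program=student_program,
#                               teacher_temperature=4.0,
#                               student_temperature=4.0)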


def loss(loss_func, program=None, **kwargs):
    """Combine variables from student model and teacher model by self defined loss.

    Args:
        loss_func(function): The user-defined loss function.
        program(Program): The input distiller program. If not specified,
                          the default program will be used. Default: None
        **kwargs: Variables or values passed to loss_func. String values are
                  resolved to variables of the same name in program; all other
                  values are passed through unchanged.

    Returns:
        Variable: self-defined distiller loss.
    """
    if program is None:
        program = paddle.static.default_main_program()
    func_parameters = {}
    for key, value in kwargs.items():
        # String-valued arguments are treated as variable names in `program`
        # and replaced by the corresponding variables before calling loss_func.
        if isinstance(value, str):
            func_parameters[key] = program.global_block().var(value)
        else:
            func_parameters[key] = value
    loss = loss_func(**func_parameters)
    return loss
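
# Usage sketch (illustrative; the function and variable names are assumptions):
# string-valued keyword arguments are looked up as variables in the program and
# passed to the user-defined loss function.
#
#     def adaptation_loss(t_var, s_var):
#         return paddle.mean(
#             paddle.nn.functional.square_error_cost(s_var, t_var))
#
#     distill_loss = loss(adaptation_loss, program=student_program,
#                         t_var='teacher_fc_0.tmp_0', s_var='fc_0.tmp_0')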


def _top_mask(x):
    """Return an int32 mask marking the top-1 (target) class of each row."""
    top_value, _ = paddle.topk(x, 1)
    return paddle.cast(x == top_value, "int32")


def _cal_tc_nc_pred(x, top_mask):
    """Calculate the predictions of target class and non-target class.
    The predictions of target class is a binary distribution.
    And after removing the target class, the softmax on the remaining
    parts produces the non-target predictions.
    """
    pred = paddle.nn.functional.softmax(x)
    fp_mask = paddle.cast(top_mask, "float32")
    top_value = paddle.sum(fp_mask * pred, axis=1, keepdim=True)
    tc_pred = paddle.concat([top_value, 1 - top_value], axis=1)
    tmp = paddle.assign(x)
    # Mask out the target class with a large negative logit so the softmax
    # over the remaining classes yields the non-target predictions.
    tmp = tmp + (-100000 * top_mask)
    nc_pred = paddle.nn.functional.softmax(tmp)
    return tc_pred, nc_pred


def _dkd_loss(student_logits,
              teacher_logits,
              temperature=1.0,
              alpha=1.0,
              beta=1.0):
    mask = _top_mask(teacher_logits)
    print(f"mask: {mask.shape}")
    print(
        f"student_logits: {student_logits.shape}; teacher_logits: {teacher_logits.shape}"
    )
    s_tc_pred, s_nc_pred = _cal_tc_nc_pred(student_logits / temperature, mask)
    t_tc_pred, t_nc_pred = _cal_tc_nc_pred(teacher_logits / temperature, mask)
    tc_loss = paddle.nn.functional.kl_div(
        s_tc_pred, t_tc_pred, reduction='mean')
    nc_loss = paddle.nn.functional.kl_div(
        s_nc_pred, t_nc_pred, reduction='mean')
    loss = alpha * tc_loss + beta * nc_loss
    return loss * temperature**2


def dkd(teacher_var_name,
        student_var_name,
        program=None,
        temperature=1.0,
        alpha=1.0,
        beta=1.0):
    """Combine variables from student model and teacher model
    by Decoupled Knowledge Distillation loss (aka. dkd-loss).
    Reference: https://github.com/megvii-research/mdistiller
    Args:
        teacher_var_name(str): The name of teacher_var.
        student_var_name(str): The name of student_var.
        program(Program): The input distiller program. If not specified,
                          the default program will be used. Default: None
        temperature(float): Temperature used to divide teacher and student
            logits before softmax. Default: 1.0
        alpha(float): The weight of the target-class loss. Default: 1.0
        beta(float): The weight of the non-target-class loss. Default: 1.0

    Returns: 
        Variable: dkd distiller loss.
    """
    if program is None:
        program = paddle.static.default_main_program()
    student_var = program.global_block().var(student_var_name)
    teacher_var = program.global_block().var(teacher_var_name)
    return _dkd_loss(
        student_var,
        teacher_var,
        temperature=temperature,
        alpha=alpha,
        beta=beta)
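
# Usage sketch (illustrative; the variable names and weights are assumptions):
# a decoupled knowledge distillation loss on teacher/student logits, with a
# larger weight on the non-target-class term.
#
#     distill_loss = dkd('teacher_fc_0.tmp_0', 'fc_0.tmp_0',
#                        program=student_program,
#                        temperature=4.0, alpha=1.0, beta=8.0)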


def skd(teacher_var_name, student_var_name, program=None, multiplier=None):
    """Combine variables from student model and teacher model 
    by Spherical Knowledge Distillation loss (aka. skd-loss).
    Reference: https://github.com/forjiuzhou/Spherical-Knowledge-Distillation
    Args:
        teacher_var_name(str): The name of teacher_var.
        student_var_name(str): The name of student_var.
        program(Program): The input distiller program. If not specified,
                          the default program will be used. Default: None
        multiplier(float): The multiplier used to recover the norm of the
            normalized logits to its original level. If None, it is computed
            from the teacher's logits as paddle.std(teacher_var, axis=1,
            keepdim=True). Default: None

    Returns:
        Variable: skd distiller loss.
    """
    if program is None:
        program = paddle.static.default_main_program()

    student_var = program.global_block().var(student_var_name)
    teacher_var = program.global_block().var(teacher_var_name)
    teacher_var.stop_gradient = True

    if multiplier is None:
        multiplier = paddle.std(teacher_var, axis=1, keepdim=True)

    logits_student = F.layer_norm(
        student_var,
        student_var.shape[1:],
        weight=None,
        bias=None,
        epsilon=1e-7) * multiplier
    logits_teacher = F.layer_norm(
        teacher_var,
        teacher_var.shape[1:],
        weight=None,
        bias=None,
        epsilon=1e-7) * multiplier

    student_out = F.softmax(logits_student, axis=1)
    teacher_out = F.softmax(logits_teacher, axis=1)
    skd_loss = paddle.mean(
        F.cross_entropy(
            input=student_out,
            label=teacher_out,
            soft_label=True,
            use_softmax=False))
    return skd_loss
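
# Usage sketch (illustrative; the variable names are assumptions): a spherical
# knowledge distillation loss; leaving multiplier as None lets it be derived
# from the teacher's logits via paddle.std.
#
#     distill_loss = skd('teacher_fc_0.tmp_0', 'fc_0.tmp_0',
#                        program=student_program, multiplier=None)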