single_distiller.py 18.0 KB
Newer Older
Y
yangfukui 已提交
1 2 3 4 5
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Y
yangfukui 已提交
15
import numpy as np
16
import paddle
I
itminner 已提交
17
from paddleslim.core import GraphWrapper
Z
zhouzj 已提交
18
import paddle.nn.functional as F
Y
yangfukui 已提交
19 20


C
ceci3 已提交
21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44
def _find_var_from_program(program, var_name):
    for block in program.blocks:
        if block.has_var(var_name):
            return block.var(var_name)
    raise ValueError("var {} not in this program".format(var_name))


def _except_feed_fetch(var_name, merge_feed):
    if var_name != 'fetch' and (not merge_feed or var_name != 'feed'):
        return True
    return False


def _is_same_block(block1, block2):
    if len(block1.ops) != len(block2.ops):
        return False

    for op1, op2 in zip(block1.ops, block2.ops):
        if op1.type != op2.type:
            return False

    return True


Y
yangfukui 已提交
45 46 47 48
def merge(teacher_program,
          student_program,
          data_name_map,
          place,
          scope=None,
          teacher_scope=None,
          name_prefix='teacher_',
          merge_feed=True):
    """Merge teacher program into student program and add a uniform prefix to the
    names of all vars in teacher program

    The teacher program is cloned (``for_test=True``). Every var of the clone
    except 'fetch' (and 'feed' when `merge_feed` is True) is renamed: vars
    listed in `data_name_map` take the mapped student name, all others get
    `name_prefix` prepended. The teacher tensors are copied from
    `teacher_scope` into `scope` under the new names. The renamed vars and ops
    are then added to `student_program` in place. Finally, every student op
    that reads a var whose name starts with `name_prefix` gets the
    ``skip_quant`` attribute set to True.

    Args:
        teacher_program(Program): The input teacher model paddle program.
        student_program(Program): The input student model paddle program.
                                  It is modified in place.
        data_name_map(dict): Mapping of teacher input interface name and student
                             input interface name, where key of dict is the
                             input name of teacher_program, and value is the
                             input name of student_program.
        place(CPUPlace()|CUDAPlace(N)): This parameter represents
                                        paddle run on which device.
        scope(Scope): This parameter indicates the variable scope used by
                      the program. If not specified, the default global scope
                      will be used. Default: None
        teacher_scope(Scope): The scope the teacher's tensors are read from.
                              If not specified, `scope` is used. Default: None
        name_prefix(str): Name prefix added for all vars of the teacher program.
                          Default: 'teacher_'
        merge_feed(bool): Whether to merge feed op when merge program. Default: True.

    Returns:
        None
    """
    if scope is None:
        scope = paddle.static.global_scope()
    if teacher_scope is None:
        teacher_scope = scope
    # Every rename below is applied to this clone, never to the caller's
    # teacher program.
    teacher_program = teacher_program.clone(for_test=True)

    def _outer_var_new_name(block, var_name):
        # Merged name for a var referenced by `block` but not owned by it,
        # or None when the reference must stay unchanged.
        if (not _except_feed_fetch(var_name, merge_feed) or
                block.has_var(var_name)):
            return None
        # Mapped inputs take the student's name (not `name_prefix + name`),
        # matching the rename applied in the block that owns the var.
        new_name = data_name_map.get(var_name, name_prefix + var_name)
        return None if new_name == var_name else new_name

    # Teacher vars/ops go into the matching student block only when both
    # programs have the same number of blocks and the same op types in each.
    is_same_model = (
        len(student_program.blocks) == len(teacher_program.blocks) and
        all(_is_same_block(block, student_program.block(block.idx))
            for block in teacher_program.blocks))

    if is_same_model:
        for block in student_program.blocks:
            for op in block.ops:
                if op.type != 'while':
                    continue
                # Let the student's while op also take the teacher copies of
                # its loop inputs. They are named `name_prefix + name` by the
                # renaming below, so the same prefix must be used here.
                loop_inputs = list(op.input('X'))
                op.desc.set_input(
                    "X",
                    [name_prefix + name for name in loop_inputs] + loop_inputs)

    for block in teacher_program.blocks:
        for teacher_var in list(block.vars.values()):
            old_name = teacher_var.name
            if not _except_feed_fetch(old_name, merge_feed):
                continue
            if old_name in data_name_map:
                new_name = data_name_map[old_name]
                if new_name == old_name:
                    # Shared input: teacher and student use the same var.
                    continue
            else:
                new_name = name_prefix + old_name
            # scope var rename: copy the teacher tensor under the new name.
            old_tensor = teacher_scope.var(old_name).get_tensor()
            scope.var(new_name).get_tensor().set(np.array(old_tensor), place)
            # program var rename
            block._rename_var(old_name, new_name)

        # Ops of a sub-block may reference vars owned by an enclosing block
        # (they fail `block.has_var`), so those references are renamed here.
        for op in block.ops:
            # `_rename_input` / `_rename_output` take no slot argument, so one
            # call renames the name wherever it occurs. Deduplicate first so a
            # name used in several slots is not prefixed twice.
            in_names = dict.fromkeys(
                name for slot in op.input_names for name in op.input(slot))
            for in_var_name in in_names:
                new_name = _outer_var_new_name(block, in_var_name)
                if new_name is not None:
                    op._rename_input(in_var_name, new_name)

            out_names = dict.fromkeys(
                name for slot in op.output_names for name in op.output(slot))
            for out_var_name in out_names:
                new_name = _outer_var_new_name(block, out_var_name)
                if new_name is not None:
                    op._rename_output(out_var_name, new_name)

    # Declare every renamed teacher var (feed/fetch excepted) in the student.
    place_per_block = len(student_program.blocks) > 1 and is_same_model
    for block in teacher_program.blocks:
        if place_per_block:
            target_block = student_program.block(block.idx)
        else:
            target_block = student_program.global_block()
        for teacher_var in list(block.vars.values()):
            if not _except_feed_fetch(teacher_var.name, merge_feed):
                continue
            new_var = target_block._clone_variable(
                teacher_var, force_persistable=False)
            new_var.stop_gradient = True

    # Copy the teacher ops into the student program, last block first.
    for block in reversed(teacher_program.blocks):
        for op_idx, op in enumerate(block.ops):
            # Same feed/fetch filter as for vars, applied to the op type.
            if not _except_feed_fetch(op.type, merge_feed):
                continue
            inputs = {
                slot:
                [block._find_var_recursive(name) for name in op.input(slot)]
                for slot in op.input_names
            }
            outputs = {
                slot:
                [block._find_var_recursive(name) for name in op.output(slot)]
                for slot in op.output_names
            }
            attrs = {}
            for attr_name in op.attr_names:
                if attr_name == 'sub_block':
                    # Point control-flow ops at the student block with the
                    # same index as the teacher's sub-block.
                    attrs[attr_name] = student_program.block(
                        op._block_attr("sub_block").idx)
                else:
                    attrs[attr_name] = op.attr(attr_name)
            if place_per_block:
                # Interleave with the student's ops of the same block: when no
                # op is skipped, inserting at 2 * op_idx puts teacher op i
                # right before student op i.
                student_program.block(op.block.idx)._insert_op(
                    2 * op_idx,
                    type=op.type,
                    inputs=inputs,
                    outputs=outputs,
                    attrs=attrs)
            else:
                student_program.global_block().append_op(
                    type=op.type,
                    inputs=inputs,
                    outputs=outputs,
                    attrs=attrs)

        student_program._sync_with_cpp()

    # Tag every student op that reads a teacher var (a name starting with
    # `name_prefix`) with `skip_quant`, presumably so quantization passes
    # leave the teacher ops alone.
    student_graph = GraphWrapper(student_program)
    for op in student_graph.ops():
        if any(inp.name().startswith(name_prefix) for inp in op.all_inputs()):
            op._op._set_attr("skip_quant", True)

Y
yangfukui 已提交
212

C
ceci3 已提交
213 214 215 216 217
def fsp(teacher_var1_name,
        teacher_var2_name,
        student_var1_name,
        student_var2_name,
        program=None):
    """Combine variables from student model and teacher model by fsp-loss.

    An FSP matrix is built from each (var1, var2) pair with
    ``paddle.fluid.layers.fsp_matrix``. The loss is the mean of the
    element-wise squared error between the student and teacher matrices.

    Args:
        teacher_var1_name(str): The name of teacher_var1.
        teacher_var2_name(str): The name of teacher_var2. Except for the
            second dimension, all other dimensions should
            be consistent with teacher_var1.
        student_var1_name(str): The name of student_var1.
        student_var2_name(str): The name of student_var2. Except for the
            second dimension, all other dimensions should
            be consistent with student_var1.
        program(Program): The input distiller program. If not specified,
                          the default program will be used. Default: None

    Returns:
        Variable: fsp distiller loss.

    Raises:
        ValueError: If one of the var names is not found in `program`.
    """
    if program is None:
        program = paddle.static.default_main_program()
    teacher_var1 = _find_var_from_program(program, teacher_var1_name)
    teacher_var2 = _find_var_from_program(program, teacher_var2_name)
    student_var1 = _find_var_from_program(program, student_var1_name)
    student_var2 = _find_var_from_program(program, student_var2_name)
    teacher_fsp_matrix = paddle.fluid.layers.fsp_matrix(teacher_var1,
                                                        teacher_var2)
    student_fsp_matrix = paddle.fluid.layers.fsp_matrix(student_var1,
                                                        student_var2)
    fsp_loss = paddle.mean(
        paddle.nn.functional.square_error_cost(student_fsp_matrix,
                                               teacher_fsp_matrix))
    return fsp_loss


C
ceci3 已提交
251
def l2(teacher_var_name, student_var_name, program=None):
    """Combine variables from student model and teacher model by l2-loss.

    The loss is the mean of the element-wise squared error between the
    student var and the teacher var.

    Args:
        teacher_var_name(str): The name of teacher_var.
        student_var_name(str): The name of student_var.
        program(Program): The input distiller program. If not specified,
                          the default program will be used. Default: None

    Returns:
        Variable: l2 distiller loss.

    Raises:
        ValueError: If either var name is not found in `program`.
    """
    if program is None:
        program = paddle.static.default_main_program()
    student_var = _find_var_from_program(program, student_var_name)
    teacher_var = _find_var_from_program(program, teacher_var_name)
    l2_loss = paddle.mean(
        paddle.nn.functional.square_error_cost(student_var, teacher_var))
    return l2_loss


C
ceci3 已提交
272 273 274 275 276
def soft_label(teacher_var_name,
               student_var_name,
               program=None,
               teacher_temperature=1.,
               student_temperature=1.):
    """Combine variables from student model and teacher model by soft-label-loss.

    Each var is divided by its temperature and passed through softmax. The
    loss is the mean soft-label cross entropy of the student distribution
    against the teacher distribution. The teacher var is marked
    ``stop_gradient = True``.

    Args:
        teacher_var_name(str): The name of teacher_var.
        student_var_name(str): The name of student_var.
        program(Program): The input distiller program. If not specified,
                          the default program will be used. Default: None
        teacher_temperature(float): Temperature used to divide
            teacher_feature_map before softmax. Default: 1.0
        student_temperature(float): Temperature used to divide
            student_feature_map before softmax. Default: 1.0

    Returns:
        Variable: soft label distiller loss.

    Raises:
        ValueError: If either var name is not found in `program`.
    """
    if program is None:
        program = paddle.static.default_main_program()
    student_var = _find_var_from_program(program, student_var_name)
    teacher_var = _find_var_from_program(program, teacher_var_name)
    # No gradient flows back into the teacher.
    teacher_var.stop_gradient = True

    student_var = paddle.nn.functional.softmax(student_var /
                                               student_temperature)
    teacher_var = paddle.nn.functional.softmax(teacher_var /
                                               teacher_temperature)
    # Both inputs are already probabilities, hence use_softmax=False.
    soft_label_loss = paddle.mean(
        paddle.nn.functional.cross_entropy(
            input=student_var,
            label=teacher_var,
            soft_label=True,
            use_softmax=False))
    return soft_label_loss


B
Bai Yifan 已提交
311 312 313
def loss(loss_func, program=None, **kwargs):
    """Combine variables from student model and teacher model by self defined loss.

    Args:
        loss_func(function): The user self defined loss function. It is
            called with the (resolved) keyword arguments in `kwargs`.
        program(Program): The input distiller program. If not specified,
                          the default program will be used. Default: None
        **kwargs: Arguments forwarded to `loss_func`. A ``str`` value is
            treated as a var name and replaced by that var from `program`.
            Any other value is passed through unchanged.

    Returns:
        Variable: self defined distiller loss (whatever `loss_func` returns).

    Raises:
        ValueError: If a ``str`` argument names a var not in `program`.
    """
    if program is None:
        program = paddle.static.default_main_program()
    func_parameters = {
        key: (_find_var_from_program(program, value)
              if isinstance(value, str) else value)
        for key, value in kwargs.items()
    }
    return loss_func(**func_parameters)
W
whs 已提交
333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399


def _top_mask(x):
    """Return an int32 mask that is 1 where `x` equals its top-1 value.

    The top-1 value comes from ``paddle.topk(x, 1)``. Every entry equal to it
    (ties included) is marked 1 and all other entries 0.
    """
    max_vals = paddle.topk(x, 1)[0]
    is_top = x == max_vals
    return paddle.cast(is_top, "int32")


def _cal_tc_nc_pred(x, top_mask):
    """Split softmax predictions into target-class and non-target parts.

    Args:
        x: Logits.
        top_mask: int mask with 1 at the target class (see `_top_mask`).

    Returns:
        tuple: ``(tc_pred, nc_pred)`` where ``tc_pred`` concatenates the
        target-class probability p with ``1 - p`` along axis 1 (a binary
        distribution), and ``nc_pred`` is the softmax over `x` after the
        target-class logits are lowered by 100000, leaving the remaining
        classes with (almost) all of the mass.
    """
    mask_f = paddle.cast(top_mask, "float32")
    probs = paddle.nn.functional.softmax(x)
    target_prob = paddle.sum(mask_f * probs, axis=1, keepdim=True)
    tc_pred = paddle.concat([target_prob, 1 - target_prob], axis=1)

    masked_logits = paddle.assign(x) + (-100000 * top_mask)
    nc_pred = paddle.nn.functional.softmax(masked_logits)
    return tc_pred, nc_pred


def _dkd_loss(student_logits,
              teacher_logits,
              temperature=1.0,
              alpha=1.0,
              beta=1.0):
    """Compute the Decoupled Knowledge Distillation loss from raw logits.

    The teacher's top-1 class (via `_top_mask`) is the target class. Both
    logits are divided by `temperature` and split by `_cal_tc_nc_pred` into a
    binary target/non-target distribution and a distribution over the
    non-target classes. Each student/teacher pair is compared with
    ``kl_div`` (``reduction='mean'``).

    Args:
        student_logits(Variable): Student logits (presumably the same shape
            as `teacher_logits` -- the teacher mask is applied to both).
        teacher_logits(Variable): Teacher logits.
        temperature(float): Softening temperature. The result is scaled by
            ``temperature**2``.
        alpha(float): Weight of the target-class term.
        beta(float): Weight of the non-target-class term.

    Returns:
        Variable: ``(alpha * tc_loss + beta * nc_loss) * temperature**2``.
    """
    mask = _top_mask(teacher_logits)
    s_tc_pred, s_nc_pred = _cal_tc_nc_pred(student_logits / temperature, mask)
    t_tc_pred, t_nc_pred = _cal_tc_nc_pred(teacher_logits / temperature, mask)
    # NOTE(review): `paddle.nn.functional.kl_div` documents `input` as
    # log-probabilities (loss = label * (log(label) - input)), but the student
    # terms passed here are plain softmax probabilities. The reference
    # implementation (mdistiller) applies log / log_softmax to the student
    # side first -- confirm whether this deviation is intended.
    tc_loss = paddle.nn.functional.kl_div(
        s_tc_pred, t_tc_pred, reduction='mean')
    nc_loss = paddle.nn.functional.kl_div(
        s_nc_pred, t_nc_pred, reduction='mean')
    loss = alpha * tc_loss + beta * nc_loss
    return loss * temperature**2


def dkd(teacher_var_name,
        student_var_name,
        program=None,
        temperature=1.0,
        alpha=1.0,
        beta=1.0):
    """Combine variables from student model and teacher model
    by Decoupled Knowledge Distillation loss (aka. dkd-loss).

    Reference: https://github.com/megvii-research/mdistiller

    Args:
        teacher_var_name(str): The name of teacher_var.
        student_var_name(str): The name of student_var.
        program(Program): The input distiller program. If not specified,
                          the default program will be used. Default: None
        temperature(float): Temperature used to divide both the student and
            the teacher logits before softmax. The loss is scaled by
            ``temperature**2``. Default: 1.0
        alpha(float): The weight of target class loss. Default: 1.0
        beta(float): The weight of non-target class loss. Default: 1.0

    Returns:
        Variable: dkd distiller loss.

    Raises:
        ValueError: If either var name is not found in `program`.
    """
    if program is None:
        program = paddle.static.default_main_program()
    student_var = _find_var_from_program(program, student_var_name)
    teacher_var = _find_var_from_program(program, teacher_var_name)
    return _dkd_loss(
        student_var,
        teacher_var,
        temperature=temperature,
        alpha=alpha,
        beta=beta)
Z
zhouzj 已提交
408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457


def skd(teacher_var_name, student_var_name, program=None, multiplier=None):
    """Combine variables from student model and teacher model
    by Spherical Knowledge Distillation loss (aka. skd-loss).

    Reference: https://github.com/forjiuzhou/Spherical-Knowledge-Distillation

    Both logits are layer-normalized over all dimensions except the first
    (no affine weight/bias, epsilon 1e-7) and scaled by `multiplier`. The loss
    is the mean soft-label cross entropy between their softmax (axis 1)
    distributions. The teacher var is marked ``stop_gradient = True``.

    Args:
        teacher_var_name(str): The name of teacher_var.
        student_var_name(str): The name of student_var.
        program(Program): The input distiller program. If not specified,
                          the default program will be used. Default: None
        multiplier(float): The multiplier to recover its norm to the original
            level. When it's None, it is computed from the teacher's logits
            as ``paddle.std(teacher_var, axis=1, keepdim=True)``.
            Default: None.

    Returns:
        Variable: skd distiller loss.

    Raises:
        ValueError: If either var name is not found in `program`.
    """
    if program is None:
        program = paddle.static.default_main_program()

    # Search all blocks, like the other distiller losses in this module.
    student_var = _find_var_from_program(program, student_var_name)
    teacher_var = _find_var_from_program(program, teacher_var_name)
    teacher_var.stop_gradient = True

    if multiplier is None:
        multiplier = paddle.std(teacher_var, axis=1, keepdim=True)

    logits_student = F.layer_norm(
        student_var,
        student_var.shape[1:],
        weight=None,
        bias=None,
        epsilon=1e-7) * multiplier
    logits_teacher = F.layer_norm(
        teacher_var,
        teacher_var.shape[1:],
        weight=None,
        bias=None,
        epsilon=1e-7) * multiplier

    student_out = F.softmax(logits_student, axis=1)
    teacher_out = F.softmax(logits_teacher, axis=1)
    # Both inputs are already probabilities, hence use_softmax=False.
    skd_loss = paddle.mean(
        F.cross_entropy(
            input=student_out,
            label=teacher_out,
            soft_label=True,
            use_softmax=False))
    return skd_loss