# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
from paddleslim.core import GraphWrapper


def merge(teacher_program,
          student_program,
          data_name_map,
          place,
          scope=None,
          teacher_scope=None,
          name_prefix='teacher_',
          merge_feed=True):
    """Merge teacher program into student program and add a uniform prefix to the
Y
yangfukui 已提交
29
    names of all vars in teacher program
B
Bai Yifan 已提交
30

    Args:
        teacher_program(Program): The input teacher model paddle program.
        student_program(Program): The input student model paddle program.
        data_name_map(dict): Mapping of teacher input interface names to
                             student input interface names: each key is an
                             input name of teacher_program and its value is
                             the corresponding input name of student_program.
        place(CPUPlace()|CUDAPlace(N)): The device on which the program runs.
        scope(Scope): The variable scope used by the student program. If not
                      specified, the default global scope will be used.
                      Default: None
        teacher_scope(Scope): The variable scope used by the teacher program.
                              If not specified, `scope` will be used.
                              Default: None
        name_prefix(str): Name prefix added for all vars of the teacher program.
                          Default: 'teacher_'
        merge_feed(bool): Whether to merge the feed op when merging programs.
                          Default: True.

    Returns:
        None
    """
    if scope is None:
        scope = paddle.static.global_scope()
    if teacher_scope is None:
        teacher_scope = scope
    teacher_program = teacher_program.clone(for_test=True)
    for teacher_var in teacher_program.list_vars():
        skip_rename = False
        if teacher_var.name != 'fetch' and (not merge_feed or
                                            teacher_var.name != 'feed'):
            if teacher_var.name in data_name_map.keys():
                new_name = data_name_map[teacher_var.name]
                if new_name == teacher_var.name:
                    skip_rename = True
            else:
                new_name = name_prefix + teacher_var.name
            if not skip_rename:
                # scope var rename
                old_var = teacher_scope.var(teacher_var.name).get_tensor()
                renamed_var = scope.var(new_name).get_tensor()
                renamed_var.set(np.array(old_var), place)

                # program var rename
                renamed_var = teacher_program.global_block()._rename_var(
                    teacher_var.name, new_name)

    for teacher_var in teacher_program.list_vars():
        if teacher_var.name != 'fetch' and (not merge_feed or
                                            teacher_var.name != 'feed'):
            # add the teacher var to the student program
            new_var = student_program.global_block()._clone_variable(
                teacher_var, force_persistable=False)
            new_var.stop_gradient = True

    # Copy every teacher op (inputs, outputs, attrs) into the student
    # program's global block.
    for block in teacher_program.blocks:
        for op in block.ops:
            if (not merge_feed or op.type != 'feed') and op.type != 'fetch':
                inputs = {}
                outputs = {}
                attrs = {}
                for input_name in op.input_names:
                    inputs[input_name] = [
                        block.var(in_var_name)
                        for in_var_name in op.input(input_name)
                    ]

                for output_name in op.output_names:
                    outputs[output_name] = [
                        block.var(out_var_name)
                        for out_var_name in op.output(output_name)
                    ]
                for attr_name in op.attr_names:
                    attrs[attr_name] = op.attr(attr_name)
                student_program.global_block().append_op(
                    type=op.type, inputs=inputs, outputs=outputs, attrs=attrs)

    # Mark every op that consumes a teacher variable so that quantization
    # passes skip the teacher sub-graph.
    student_graph = GraphWrapper(student_program)
    for op in student_graph.ops():
        belongsto_teacher = False
        for inp in op.all_inputs():
            if 'teacher' in inp.name():
                belongsto_teacher = True
                break
        if belongsto_teacher:
            op._op._set_attr("skip_quant", True)
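
# A minimal usage sketch of `merge` (illustrative only; the program and
# variable names below are hypothetical). It assumes both static-graph
# programs share one input named 'image':
#
#     import paddle
#     paddle.enable_static()
#     # ... build teacher_program and student_program, both feeding 'image' ...
#     merge(teacher_program,
#           student_program,
#           data_name_map={'image': 'image'},
#           place=paddle.CPUPlace())
#     # Teacher vars now live in student_program under the 'teacher_' prefix
#     # and can be referenced by the loss helpers below.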


def fsp(teacher_var1_name,
        teacher_var2_name,
        student_var1_name,
        student_var2_name,
        program=None):
    """Combine variables from student model and teacher model by fsp-loss.

    Args:
        teacher_var1_name(str): The name of teacher_var1.
        teacher_var2_name(str): The name of teacher_var2. Except for the
            second dimension, all other dimensions should
            be consistent with teacher_var1.
        student_var1_name(str): The name of student_var1.
        student_var2_name(str): The name of student_var2. Except for the
            second dimension, all other dimensions should
            be consistent with student_var1.
        program(Program): The input distiller program. If not specified,
                          the default program will be used. Default: None

    Returns:
        Variable: fsp distiller loss.
    """
    if program is None:
        program = paddle.static.default_main_program()
    teacher_var1 = program.global_block().var(teacher_var1_name)
    teacher_var2 = program.global_block().var(teacher_var2_name)
    student_var1 = program.global_block().var(student_var1_name)
    student_var2 = program.global_block().var(student_var2_name)
    # The FSP matrix captures the "flow" between two feature maps; the loss
    # is the mean squared error between teacher and student FSP matrices.
    teacher_fsp_matrix = paddle.fluid.layers.fsp_matrix(teacher_var1,
                                                        teacher_var2)
    student_fsp_matrix = paddle.fluid.layers.fsp_matrix(student_var1,
                                                        student_var2)
    fsp_loss = paddle.mean(
        paddle.nn.functional.square_error_cost(student_fsp_matrix,
                                               teacher_fsp_matrix))
    return fsp_loss
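
# Usage sketch (hypothetical variable names): pair two intermediate feature
# maps from each network, typically after `merge` has prefixed teacher vars:
#
#     distill_loss = fsp('teacher_res2_out', 'teacher_res3_out',
#                        'student_res2_out', 'student_res3_out')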


def l2(teacher_var_name, student_var_name, program=None):
    """Combine variables from student model and teacher model by l2-loss.

    Args:
        teacher_var_name(str): The name of teacher_var.
        student_var_name(str): The name of student_var.
        program(Program): The input distiller program. If not specified,
                          the default program will be used. Default: None

    Returns: 
        Variable: l2 distiller loss.
    """
    if program is None:
        program = paddle.static.default_main_program()
    student_var = program.global_block().var(student_var_name)
    teacher_var = program.global_block().var(teacher_var_name)
    l2_loss = paddle.mean(
        paddle.nn.functional.square_error_cost(student_var, teacher_var))
    return l2_loss
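
# Usage sketch (hypothetical variable names): an L2 term between one teacher
# feature map and its student counterpart:
#
#     distill_loss = l2('teacher_fc_0.tmp_0', 'fc_0.tmp_0')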


def soft_label(teacher_var_name,
               student_var_name,
               program=None,
               teacher_temperature=1.,
               student_temperature=1.):
    """Combine variables from student model and teacher model by soft-label-loss.

    Args:
        teacher_var_name(str): The name of teacher_var.
        student_var_name(str): The name of student_var.
        program(Program): The input distiller program. If not specified,
                          the default program will be used. Default: None
        teacher_temperature(float): Temperature used to divide
            teacher_feature_map before softmax. Default: 1.0
        student_temperature(float): Temperature used to divide 
            student_feature_map before softmax. Default: 1.0

    Returns:
        Variable: soft label distiller loss.
    """
    if program is None:
        program = paddle.static.default_main_program()
    student_var = program.global_block().var(student_var_name)
    teacher_var = program.global_block().var(teacher_var_name)
    teacher_var.stop_gradient = True

    student_var = paddle.nn.functional.softmax(student_var /
                                               student_temperature)
    teacher_var = paddle.nn.functional.softmax(teacher_var /
                                               teacher_temperature)
    soft_label_loss = paddle.mean(
        paddle.nn.functional.cross_entropy(
            input=student_var, label=teacher_var, soft_label=True))
    return soft_label_loss
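
# Usage sketch (hypothetical variable names): soften both logits with a
# temperature of 4 and distill the teacher's class distribution:
#
#     distill_loss = soft_label('teacher_fc_0.tmp_0', 'fc_0.tmp_0',
#                               teacher_temperature=4.0,
#                               student_temperature=4.0)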


def loss(loss_func, program=None, **kwargs):
    """Combine variables from student model and teacher model by self defined loss.

    Args:
        loss_func(function): The user-defined loss function. Keyword
                             arguments are supplied through **kwargs.
        program(Program): The input distiller program. If not specified,
                          the default program will be used. Default: None
        **kwargs: Arguments forwarded to loss_func. A string value is
                  treated as a variable name and replaced with the
                  corresponding variable from `program`.

    Returns:
        Variable: self-defined distiller loss.
    """
    if program is None:
        program = paddle.static.default_main_program()
    func_parameters = {}
    for key, value in kwargs.items():
        # A string names a variable in `program`; pass the variable itself.
        if isinstance(value, str):
            func_parameters[key] = program.global_block().var(value)
        else:
            func_parameters[key] = value
    loss = loss_func(**func_parameters)
    return loss
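
# Sketch of a self-defined distillation term (hypothetical names): string
# kwargs are resolved to program variables before loss_func is called:
#
#     def my_l1(teacher_feat, student_feat, weight=1.0):
#         return weight * paddle.mean(
#             paddle.abs(teacher_feat - student_feat))
#
#     distill_loss = loss(my_l1,
#                         teacher_feat='teacher_fc_0.tmp_0',
#                         student_feat='fc_0.tmp_0',
#                         weight=0.5)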


def _top_mask(x):
    """Return an int32 one-hot mask marking the top-1 entry of each row."""
    top_value, _ = paddle.topk(x, 1)
    return paddle.cast(x == top_value, "int32")
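
# Worked example (by hand): for the logits row [0.2, 1.5, 0.3], the top-1
# value is 1.5, so _top_mask yields [0, 1, 0].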


def _cal_tc_nc_pred(x, top_mask):
    """Calculate the predictions of target class and non-target class.
    The predictions of target class is a binary distribution.
    And after removing the target class, the softmax on the remaining
    parts produces the non-target predictions.
    """
    pred = paddle.nn.functional.softmax(x)
    fp_mask = paddle.cast(top_mask, "float32")
    top_value = paddle.sum(fp_mask * pred, axis=1, keepdim=True)
    tc_pred = paddle.concat([top_value, 1 - top_value], axis=1)
    # Push the target-class logit to a large negative value so the softmax
    # redistributes all probability mass over the non-target classes.
    masked_logits = x - 100000.0 * fp_mask
    nc_pred = paddle.nn.functional.softmax(masked_logits)
    return tc_pred, nc_pred
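
# Worked sketch: for a row with softmax probabilities [0.7, 0.2, 0.1] and
# target class 0, tc_pred is the binary distribution [0.7, 0.3], while
# nc_pred is the renormalized non-target part, roughly [0.0, 0.67, 0.33].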


def _dkd_loss(student_logits,
              teacher_logits,
              temperature=1.0,
              alpha=1.0,
              beta=1.0):
    # The target class is taken from the teacher's top-1 prediction.
    mask = _top_mask(teacher_logits)
    s_tc_pred, s_nc_pred = _cal_tc_nc_pred(student_logits / temperature, mask)
    t_tc_pred, t_nc_pred = _cal_tc_nc_pred(teacher_logits / temperature, mask)
    # paddle.nn.functional.kl_div expects log-probabilities as its first
    # argument, so log-transform the student predictions (clipped to
    # avoid log(0) on the masked target class).
    tc_loss = paddle.nn.functional.kl_div(
        paddle.log(paddle.clip(s_tc_pred, min=1e-10)),
        t_tc_pred,
        reduction='mean')
    nc_loss = paddle.nn.functional.kl_div(
        paddle.log(paddle.clip(s_nc_pred, min=1e-10)),
        t_nc_pred,
        reduction='mean')
    # Scale by T^2 so gradient magnitudes stay comparable across temperatures.
    loss = alpha * tc_loss + beta * nc_loss
    return loss * temperature**2


def dkd(teacher_var_name,
        student_var_name,
        program=None,
        temperature=1.0,
        alpha=1.0,
        beta=1.0):
    """Combine variables from student model and teacher model
    by Decoupled Knowledge Distillation loss (aka. dkd-loss).
    Reference: https://github.com/megvii-research/mdistiller
    Args:
        teacher_var_name(str): The name of teacher_var.
        student_var_name(str): The name of student_var.
        program(Program): The input distiller program. If not specified,
                          the default program will be used. Default: None
        temperature(float): Temperature used to divide
            teacher_feature_map before softmax. Default: 1.0
        alpha(float): The weight of target class loss. Default: 1.0
        beta(float): The weight of none-target class loss. Default: 1.0

    Returns: 
        Variable: dkd distiller loss.
    """
    if program is None:
        program = paddle.static.default_main_program()
    student_var = program.global_block().var(student_var_name)
    teacher_var = program.global_block().var(teacher_var_name)
    return _dkd_loss(
        student_var,
        teacher_var,
        temperature=temperature,
        alpha=alpha,
        beta=beta)
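
# Usage sketch (hypothetical variable names): after `merge`, point dkd at
# the teacher and student logits and weight the decoupled terms:
#
#     distill_loss = dkd('teacher_fc_0.tmp_0', 'fc_0.tmp_0',
#                        temperature=4.0, alpha=1.0, beta=8.0)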