# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
from paddleslim.core import GraphWrapper


def merge(teacher_program,
          student_program,
          data_name_map,
          place,
          scope=None,
          name_prefix='teacher_'):
    """Merge teacher program into student program and add a uniform prefix to the
Y
yangfukui 已提交
27
    names of all vars in teacher program

    Args:
        teacher_program(Program): The input teacher model paddle program 
        student_program(Program): The input student model paddle program
        data_name_map(dict): Mapping of teacher input interface name and student
                            input interface name, where key of dict is the
                            input name of teacher_program, and value is the
                            input name of student_program.
        place(CPUPlace()|CUDAPlace(N)): This parameter indicates the device
                                        on which the program runs.
        scope(Scope): This parameter indicates the variable scope used by
                      the program. If not specified, the default global scope
                      will be used. Default: None
        name_prefix(str): Name prefix added for all vars of the teacher program.
                          Default: 'teacher_'

    Returns:
        None
    """
    if scope is None:
        scope = paddle.static.global_scope()
    teacher_program = teacher_program.clone(for_test=True)
    for teacher_var in teacher_program.list_vars():
        skip_rename = False
        if teacher_var.name != 'fetch' and teacher_var.name != 'feed':
            if teacher_var.name in data_name_map.keys():
                new_name = data_name_map[teacher_var.name]
                if new_name == teacher_var.name:
                    skip_rename = True
            else:
                new_name = name_prefix + teacher_var.name
            if not skip_rename:
                # scope var rename
                old_var = scope.var(teacher_var.name).get_tensor()
                renamed_var = scope.var(new_name).get_tensor()
                renamed_var.set(np.array(old_var), place)

                # program var rename
                renamed_var = teacher_program.global_block()._rename_var(
                    teacher_var.name, new_name)

    for teacher_var in teacher_program.list_vars():
        if teacher_var.name != 'fetch' and teacher_var.name != 'feed':
            # student program add var
            new_var = student_program.global_block()._clone_variable(
                teacher_var, force_persistable=False)
            new_var.stop_gradient = True

    # Copy every non-feed/fetch op of the teacher program into the student
    # program, wiring its (renamed) teacher variables as inputs and outputs.
    for block in teacher_program.blocks:
        for op in block.ops:
            if op.type != 'feed' and op.type != 'fetch':
                inputs = {}
                outputs = {}
                attrs = {}
                for input_name in op.input_names:
                    inputs[input_name] = [
                        block.var(in_var_name)
                        for in_var_name in op.input(input_name)
                    ]

                for output_name in op.output_names:
                    outputs[output_name] = [
                        block.var(out_var_name)
                        for out_var_name in op.output(output_name)
                    ]
                for attr_name in op.attr_names:
                    attrs[attr_name] = op.attr(attr_name)
                student_program.global_block().append_op(
                    type=op.type, inputs=inputs, outputs=outputs, attrs=attrs)

    # Mark every op that consumes a teacher variable so it is skipped during
    # quantization.
    student_graph = GraphWrapper(student_program)
    for op in student_graph.ops():
        belongsto_teacher = False
        for inp in op.all_inputs():
            if 'teacher' in inp.name():
                belongsto_teacher = True
                break
        if belongsto_teacher:
            op._op._set_attr("skip_quant", True)
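
# Illustrative usage of merge() (a sketch only; the teacher/student programs
# and the feed name 'image' below are hypothetical and not defined here):
#
#     place = paddle.CPUPlace()
#     data_name_map = {'image': 'image'}  # teacher feed name -> student feed name
#     merge(teacher_program, student_program, data_name_map, place)
#
# After merging, the teacher variables live inside student_program under the
# 'teacher_' prefix, so the loss helpers below can reference both models in a
# single program.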


def fsp_loss(teacher_var1_name,
             teacher_var2_name,
             student_var1_name,
             student_var2_name,
             program=None):
    """Combine variables from student model and teacher model by fsp-loss.

    Args:
        teacher_var1_name(str): The name of teacher_var1.
        teacher_var2_name(str): The name of teacher_var2. Except for the
            second dimension, all other dimensions should
            be consistent with teacher_var1.
        student_var1_name(str): The name of student_var1.
        student_var2_name(str): The name of student_var2. Except for the
            second dimension, all other dimensions should
            be consistent with student_var1.
        program(Program): The input distiller program. If not specified,
                          the default program will be used. Default: None

    Returns:
        Variable: fsp distiller loss.
    """
    if program is None:
        program = paddle.static.default_main_program()
    teacher_var1 = program.global_block().var(teacher_var1_name)
    teacher_var2 = program.global_block().var(teacher_var2_name)
    student_var1 = program.global_block().var(student_var1_name)
    student_var2 = program.global_block().var(student_var2_name)
    teacher_fsp_matrix = paddle.fluid.layers.fsp_matrix(teacher_var1,
                                                        teacher_var2)
    student_fsp_matrix = paddle.fluid.layers.fsp_matrix(student_var1,
                                                        student_var2)
    fsp_loss = paddle.mean(
        paddle.nn.functional.square_error_cost(student_fsp_matrix,
                                               teacher_fsp_matrix))
    return fsp_loss
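
# Illustrative call (a sketch; the four variable names are hypothetical and
# must exist in the program produced by merge() above):
#
#     distill_loss = fsp_loss('teacher_block1_out.tmp_0', 'teacher_block2_out.tmp_0',
#                             'block1_out.tmp_0', 'block2_out.tmp_0',
#                             program=student_program)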


def l2_loss(teacher_var_name, student_var_name, program=None):
    """Combine variables from student model and teacher model by l2-loss.

    Args:
        teacher_var_name(str): The name of teacher_var.
        student_var_name(str): The name of student_var.
        program(Program): The input distiller program. If not specified,
                          the default program will be used. Default: None

    Returns: 
        Variable: l2 distiller loss.
    """
    if program is None:
        program = paddle.static.default_main_program()
    student_var = program.global_block().var(student_var_name)
    teacher_var = program.global_block().var(teacher_var_name)
    l2_loss = paddle.mean(
        paddle.nn.functional.square_error_cost(student_var, teacher_var))
    return l2_loss
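
# Illustrative call (a sketch; 'teacher_fc_0.tmp_0' and 'fc_0.tmp_0' are
# hypothetical variable names in the merged program):
#
#     distill_loss = l2_loss('teacher_fc_0.tmp_0', 'fc_0.tmp_0',
#                            program=student_program)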


def soft_label_loss(teacher_var_name,
                    student_var_name,
                    program=None,
                    teacher_temperature=1.,
                    student_temperature=1.):
    """Combine variables from student model and teacher model by soft-label-loss.

    Args:
        teacher_var_name(str): The name of teacher_var.
        student_var_name(str): The name of student_var.
        program(Program): The input distiller program. If not specified,
                          the default program will be used. Default: None
        teacher_temperature(float): Temperature used to divide
            teacher_feature_map before softmax. Default: 1.0
        student_temperature(float): Temperature used to divide 
            student_feature_map before softmax. Default: 1.0

    Returns:
        Variable: soft label distiller loss.
    """
    if program is None:
        program = paddle.static.default_main_program()
    student_var = program.global_block().var(student_var_name)
    teacher_var = program.global_block().var(teacher_var_name)
    teacher_var.stop_gradient = True

    student_var = paddle.nn.functional.softmax(student_var /
                                               student_temperature)
    teacher_var = paddle.nn.functional.softmax(teacher_var /
                                               teacher_temperature)
    soft_label_loss = paddle.mean(
        paddle.fluid.layers.cross_entropy(
            student_var, teacher_var, soft_label=True))
    return soft_label_loss
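
# Illustrative call (a sketch; the logits variable names are hypothetical):
#
#     distill_loss = soft_label_loss('teacher_fc_0.tmp_0', 'fc_0.tmp_0',
#                                    program=student_program,
#                                    teacher_temperature=4.,
#                                    student_temperature=4.)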


def loss(loss_func, program=None, **kwargs):
    """Combine variables from student model and teacher model by self defined loss.

    Args:
        loss_func(function): The user-defined loss function. Keyword arguments
                             given in **kwargs are passed to it; string values
                             are resolved to variables of the program by name.
        program(Program): The input distiller program. If not specified,
                          the default program will be used. Default: None

    Returns:
        Variable: self-defined distiller loss.
    """
    if program is None:
        program = paddle.static.default_main_program()
    func_parameters = {}
    for item in kwargs.items():
        if isinstance(item[1], str):
            func_parameters.setdefault(item[0],
                                       program.global_block().var(item[1]))
        else:
            func_parameters.setdefault(item[0], item[1])
    loss = loss_func(**func_parameters)
    return loss
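
# Illustrative use of loss() with a self-defined loss function (a sketch; the
# adaptor function and the variable names are hypothetical). String-valued
# keyword arguments are resolved to program variables before the call:
#
#     def my_l1_loss(s_var, t_var):
#         return paddle.mean(paddle.abs(s_var - t_var))
#
#     distill_loss = loss(my_l1_loss, program=student_program,
#                         s_var='fc_0.tmp_0', t_var='teacher_fc_0.tmp_0')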