# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
from paddleslim.core import GraphWrapper


def merge(teacher_program,
          student_program,
          data_name_map,
          place,
          scope=None,
          teacher_scope=None,
          name_prefix='teacher_',
          merge_feed=True):
    """Merge teacher program into student program and add a uniform prefix to the
    names of all vars in teacher program.

    Args:
        teacher_program(Program): The input teacher model paddle program.
        student_program(Program): The input student model paddle program.
        data_name_map(dict): Mapping of teacher input interface name and student
                             input interface name, where key of dict is the
                             input name of teacher_program, and value is the
                             input name of student_program.
        place(CPUPlace()|CUDAPlace(N)): The device on which the program runs.
        scope(Scope): This parameter indicates the variable scope used by
                      the program. If not specified, the default global scope
                      will be used. Default: None
        teacher_scope(Scope): The variable scope that holds the teacher
                              variables. If not specified, `scope` will be used.
                              Default: None
        name_prefix(str): Name prefix added for all vars of the teacher program.
                          Default: 'teacher_'
        merge_feed(bool): Whether to merge the feed op when merging the programs.
                          Default: True.

    Returns:
        None
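
    Examples:
        A minimal usage sketch in static-graph mode. The toy fc networks and
        the input names 'image' and 'data' below are illustrative only:

        .. code-block:: python

            import paddle
            from paddleslim.dist import merge

            paddle.enable_static()
            place = paddle.CPUPlace()
            exe = paddle.static.Executor(place)

            # Toy student network.
            student_program = paddle.static.Program()
            student_startup = paddle.static.Program()
            with paddle.static.program_guard(student_program, student_startup):
                image = paddle.static.data(
                    name='image', shape=[None, 784], dtype='float32')
                s_out = paddle.static.nn.fc(image, size=10)

            # Toy teacher network (normally a pretrained model).
            teacher_program = paddle.static.Program()
            teacher_startup = paddle.static.Program()
            with paddle.static.program_guard(teacher_program, teacher_startup):
                data = paddle.static.data(
                    name='data', shape=[None, 784], dtype='float32')
                t_out = paddle.static.nn.fc(data, size=10)

            # Initialize (in practice, load pretrained) teacher parameters,
            # then merge the teacher into the student program; teacher vars
            # are renamed with the 'teacher_' prefix.
            exe.run(teacher_startup)
            merge(teacher_program, student_program, {'data': 'image'}, place)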
    """
    if scope is None:
        scope = paddle.static.global_scope()
    if teacher_scope is None:
        teacher_scope = scope
    teacher_program = teacher_program.clone(for_test=True)
    for teacher_var in teacher_program.list_vars():
        skip_rename = False
        if teacher_var.name != 'fetch' and (not merge_feed or
                                            teacher_var.name != 'feed'):
            if teacher_var.name in data_name_map.keys():
                new_name = data_name_map[teacher_var.name]
                if new_name == teacher_var.name:
                    # The teacher input already uses the student input's name,
                    # so no renaming is needed.
                    skip_rename = True
            else:
                new_name = name_prefix + teacher_var.name
            if not skip_rename:
                # scope var rename
                old_var = teacher_scope.var(teacher_var.name).get_tensor()
                renamed_var = scope.var(new_name).get_tensor()
                renamed_var.set(np.array(old_var), place)

                # program var rename
                renamed_var = teacher_program.global_block()._rename_var(
                    teacher_var.name, new_name)

    for teacher_var in teacher_program.list_vars():
        if teacher_var.name != 'fetch' and (not merge_feed or
                                            teacher_var.name != 'feed'):
            # student program add var
            new_var = student_program.global_block()._clone_variable(
                teacher_var, force_persistable=False)
            new_var.stop_gradient = True

    # Append the teacher ops to the student program's global block (fetch ops,
    # and feed ops when merge_feed is True, are skipped) so that both networks
    # run in a single program.
    for block in teacher_program.blocks:
        for op in block.ops:
            if (not merge_feed or op.type != 'feed') and op.type != 'fetch':
                inputs = {}
                outputs = {}
                attrs = {}
                for input_name in op.input_names:
                    inputs[input_name] = [
                        block.var(in_var_name)
                        for in_var_name in op.input(input_name)
                    ]

                for output_name in op.output_names:
                    outputs[output_name] = [
                        block.var(out_var_name)
                        for out_var_name in op.output(output_name)
                    ]
                for attr_name in op.attr_names:
                    attrs[attr_name] = op.attr(attr_name)
                student_program.global_block().append_op(
                    type=op.type, inputs=inputs, outputs=outputs, attrs=attrs)

    # Ops whose inputs come from teacher variables belong to the merged
    # teacher network; mark them so that quantization passes skip them.
    student_graph = GraphWrapper(student_program)
    for op in student_graph.ops():
        belongsto_teacher = False
        for inp in op.all_inputs():
            if 'teacher' in inp.name():
                belongsto_teacher = True
                break
        if belongsto_teacher:
            op._op._set_attr("skip_quant", True)


def fsp_loss(teacher_var1_name,
             teacher_var2_name,
             student_var1_name,
             student_var2_name,
             program=None):
    """Combine variables from student model and teacher model by fsp-loss.

    Args:
        teacher_var1_name(str): The name of teacher_var1.
        teacher_var2_name(str): The name of teacher_var2. Except for the
            second dimension, all other dimensions should
            be consistent with teacher_var1.
        student_var1_name(str): The name of student_var1.
        student_var2_name(str): The name of student_var2. Except for the
            second dimension, all other dimensions should
            be consistent with student_var1.
        program(Program): The input distiller program. If not specified,
                          the default program will be used. Default: None

    Returns:
        Variable: fsp distiller loss.
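
    Examples:
        A minimal sketch, assuming ``merge`` has already been called and the
        four variable names below exist in ``student_program`` (the names are
        illustrative only):

        .. code-block:: python

            import paddle

            with paddle.static.program_guard(student_program):
                distill_loss = fsp_loss(
                    'teacher_conv1.tmp_1', 'teacher_conv2.tmp_1',
                    'conv1.tmp_1', 'conv2.tmp_1', program=student_program)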
    """
    if program is None:
        program = paddle.static.default_main_program()
    teacher_var1 = program.global_block().var(teacher_var1_name)
    teacher_var2 = program.global_block().var(teacher_var2_name)
    student_var1 = program.global_block().var(student_var1_name)
    student_var2 = program.global_block().var(student_var2_name)
    teacher_fsp_matrix = paddle.fluid.layers.fsp_matrix(teacher_var1,
                                                        teacher_var2)
    student_fsp_matrix = paddle.fluid.layers.fsp_matrix(student_var1,
                                                        student_var2)
    fsp_loss = paddle.mean(
        paddle.nn.functional.square_error_cost(student_fsp_matrix,
                                               teacher_fsp_matrix))
    return fsp_loss


def l2_loss(teacher_var_name, student_var_name, program=None):
    """Combine variables from student model and teacher model by l2-loss.

    Args:
        teacher_var_name(str): The name of teacher_var.
        student_var_name(str): The name of student_var.
        program(Program): The input distiller program. If not specified,
                          the default program will be used. Default: None

    Returns: 
        Variable: l2 distiller loss.
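
    Examples:
        A minimal sketch, assuming ``merge`` has already been called and that
        'teacher_fc_0.tmp_1' and 'fc_0.tmp_1' are variables of
        ``student_program`` (illustrative names only):

        .. code-block:: python

            import paddle

            with paddle.static.program_guard(student_program):
                distill_loss = l2_loss('teacher_fc_0.tmp_1', 'fc_0.tmp_1',
                                       program=student_program)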
    """
    if program is None:
        program = paddle.static.default_main_program()
    student_var = program.global_block().var(student_var_name)
    teacher_var = program.global_block().var(teacher_var_name)
    l2_loss = paddle.mean(
        paddle.nn.functional.square_error_cost(student_var, teacher_var))
    return l2_loss


def soft_label_loss(teacher_var_name,
                    student_var_name,
                    program=None,
                    teacher_temperature=1.,
                    student_temperature=1.):
    """Combine variables from student model and teacher model by soft-label-loss.

    Args:
        teacher_var_name(str): The name of teacher_var.
        student_var_name(str): The name of student_var.
        program(Program): The input distiller program. If not specified,
                          the default program will be used. Default: None
        teacher_temperature(float): Temperature used to divide
            teacher_feature_map before softmax. Default: 1.0
        student_temperature(float): Temperature used to divide 
            student_feature_map before softmax. Default: 1.0

    Returns:
        Variable: soft label distiller loss.
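
    Examples:
        A minimal sketch, assuming ``merge`` has already been called and that
        the logits 'teacher_fc_0.tmp_1' and 'fc_0.tmp_1' exist in
        ``student_program`` (illustrative names only):

        .. code-block:: python

            import paddle

            with paddle.static.program_guard(student_program):
                distill_loss = soft_label_loss(
                    'teacher_fc_0.tmp_1', 'fc_0.tmp_1',
                    program=student_program,
                    teacher_temperature=4., student_temperature=4.)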
    """
    if program is None:
        program = paddle.static.default_main_program()
    student_var = program.global_block().var(student_var_name)
    teacher_var = program.global_block().var(teacher_var_name)
    teacher_var.stop_gradient = True

    student_var = paddle.nn.functional.softmax(student_var /
                                               student_temperature)
    teacher_var = paddle.nn.functional.softmax(teacher_var /
                                               teacher_temperature)
    soft_label_loss = paddle.mean(
        paddle.fluid.layers.cross_entropy(
            student_var, teacher_var, soft_label=True))
    return soft_label_loss


def loss(loss_func, program=None, **kwargs):
    """Combine variables from student model and teacher model by self defined loss.

    Args:
        program(Program): The input distiller program. If not specified,
                          the default program will be used. Default: None
        loss_func(function): The user-defined loss function.
        **kwargs: Keyword arguments passed to ``loss_func``. String values are
                  treated as variable names and replaced by the corresponding
                  variables of ``program``.

    Returns: 
        Variable: user-defined distiller loss.
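
    Examples:
        A minimal sketch of a user-defined distillation loss, assuming
        ``merge`` has already been called; the variable names are
        illustrative only:

        .. code-block:: python

            import paddle

            def mae_loss(student_var, teacher_var):
                return paddle.mean(paddle.abs(student_var - teacher_var))

            with paddle.static.program_guard(student_program):
                distill_loss = loss(mae_loss, program=student_program,
                                    student_var='fc_0.tmp_1',
                                    teacher_var='teacher_fc_0.tmp_1')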
    """
    if program is None:
        program = paddle.static.default_main_program()
    func_parameters = {}
    # String-valued kwargs are treated as variable names and resolved in the
    # given program; any other value is passed to loss_func unchanged.
    for name, value in kwargs.items():
        if isinstance(value, str):
            func_parameters.setdefault(name,
                                       program.global_block().var(value))
        else:
            func_parameters.setdefault(name, value)
    loss = loss_func(**func_parameters)
    return loss