# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
from paddleslim.core import GraphWrapper


def merge(teacher_program,
          student_program,
          data_name_map,
          place,
          scope=None,
          teacher_scope=None,
          name_prefix='teacher_',
          merge_feed=True):
    """Merge teacher program into student program and add a uniform prefix to the
    names of all vars in the teacher program.

    Args:
        teacher_program(Program): The input teacher model paddle program.
        student_program(Program): The input student model paddle program.
        data_name_map(dict): Mapping from teacher input interface names to
                             student input interface names, where the key of
                             the dict is the input name of teacher_program and
                             the value is the input name of student_program.
        place(CPUPlace|CUDAPlace(N)): The device on which the program runs.
        scope(Scope): The variable scope used by the program. If not specified,
                      the default global scope will be used. Default: None
        teacher_scope(Scope): The variable scope that holds the teacher
                              variables. If not specified, scope will be used.
                              Default: None
        name_prefix(str): Name prefix added for all vars of the teacher program.
                          Default: 'teacher_'
        merge_feed(bool): Whether to merge the feed op when merging the
                          programs. Default: True.

    Returns:
        None
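
    Examples:
        A minimal usage sketch. teacher_program and student_program are assumed
        to be already-built and initialized static-graph Programs, and the
        input names in data_name_map are placeholders:

        .. code-block:: python

            import paddle

            paddle.enable_static()
            place = paddle.CPUPlace()
            # feed the teacher input 'teacher_image' with the student input 'image'
            data_name_map = {'teacher_image': 'image'}
            merge(teacher_program, student_program, data_name_map, place)
            # After merging, distillation losses such as l2() or soft_label()
            # can be defined on pairs of 'teacher_*' and student variables.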
    """
    if scope is None:
        scope = paddle.static.global_scope()
    if teacher_scope is None:
        teacher_scope = scope
    teacher_program = teacher_program.clone(for_test=True)
    for teacher_var in teacher_program.list_vars():
Y
yangfukui 已提交
56
        skip_rename = False
C
ceci3 已提交
57 58
        if teacher_var.name != 'fetch' and (not merge_feed or
                                            teacher_var.name != 'feed'):
Y
yangfukui 已提交
59 60
            if teacher_var.name in data_name_map.keys():
                new_name = data_name_map[teacher_var.name]
                if new_name == teacher_var.name:
                    skip_rename = True
            else:
                new_name = name_prefix + teacher_var.name
            if not skip_rename:
                # scope var rename
                old_var = teacher_scope.var(teacher_var.name).get_tensor()
                renamed_var = scope.var(new_name).get_tensor()
                renamed_var.set(np.array(old_var), place)

                # program var rename
                renamed_var = teacher_program.global_block()._rename_var(
                    teacher_var.name, new_name)

    for teacher_var in teacher_program.list_vars():
        if teacher_var.name != 'fetch' and (not merge_feed or
                                            teacher_var.name != 'feed'):
            # student program add var
            new_var = student_program.global_block()._clone_variable(
                teacher_var, force_persistable=False)
            new_var.stop_gradient = True

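    # Copy the teacher ops into the student program's global block. Teacher
    # feed ops are skipped when merge_feed is True; fetch ops are always
    # skipped.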
    for block in teacher_program.blocks:
        for op in block.ops:
            if (not merge_feed or op.type != 'feed') and op.type != 'fetch':
                inputs = {}
                outputs = {}
                attrs = {}
                for input_name in op.input_names:
                    inputs[input_name] = [
                        block.var(in_var_name)
                        for in_var_name in op.input(input_name)
                    ]

                for output_name in op.output_names:
                    outputs[output_name] = [
                        block.var(out_var_name)
                        for out_var_name in op.output(output_name)
                    ]
                for attr_name in op.attr_names:
                    attrs[attr_name] = op.attr(attr_name)
                student_program.global_block().append_op(
                    type=op.type, inputs=inputs, outputs=outputs, attrs=attrs)

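    # Mark ops that take teacher variables as inputs with the `skip_quant`
    # attribute so that quantization passes can skip the merged teacher
    # sub-graph.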
    student_graph = GraphWrapper(student_program)
    for op in student_graph.ops():
        belongsto_teacher = False
        for inp in op.all_inputs():
            if 'teacher' in inp.name():
                belongsto_teacher = True
                break
        if belongsto_teacher:
            op._op._set_attr("skip_quant", True)


def fsp(teacher_var1_name,
        teacher_var2_name,
        student_var1_name,
        student_var2_name,
        program=None):
    """Combine variables from student model and teacher model by fsp-loss.

    Args:
        teacher_var1_name(str): The name of teacher_var1.
        teacher_var2_name(str): The name of teacher_var2. Except for the
            second dimension, all other dimensions should
            be consistent with teacher_var1.
        student_var1_name(str): The name of student_var1.
        student_var2_name(str): The name of student_var2. Except for the
            second dimension, all other dimensions should
            be consistent with student_var1.
        program(Program): The input distiller program. If not specified,
                          the default program will be used. Default: None

    Returns:
        Variable: fsp distiller loss.
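
    Examples:
        A minimal sketch. The variable names below are placeholders for
        feature maps already present in student_program after merge has been
        applied:

        .. code-block:: python

            distill_loss = fsp('teacher_conv1.tmp_1', 'teacher_conv2.tmp_1',
                               'conv1.tmp_1', 'conv2.tmp_1',
                               program=student_program)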
    """
    if program is None:
        program = paddle.static.default_main_program()
    teacher_var1 = program.global_block().var(teacher_var1_name)
    teacher_var2 = program.global_block().var(teacher_var2_name)
    student_var1 = program.global_block().var(student_var1_name)
    student_var2 = program.global_block().var(student_var2_name)
    teacher_fsp_matrix = paddle.fluid.layers.fsp_matrix(teacher_var1,
                                                        teacher_var2)
    student_fsp_matrix = paddle.fluid.layers.fsp_matrix(student_var1,
                                                        student_var2)
    fsp_loss = paddle.mean(
        paddle.nn.functional.square_error_cost(student_fsp_matrix,
                                               teacher_fsp_matrix))
    return fsp_loss


def l2(teacher_var_name, student_var_name, program=None):
    """Combine variables from student model and teacher model by l2-loss.

    Args:
        teacher_var_name(str): The name of teacher_var.
        student_var_name(str): The name of student_var.
        program(Program): The input distiller program. If not specified,
                          the default program will be used. Default: None

    Returns: 
        Variable: l2 distiller loss.
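
    Examples:
        A minimal sketch. The variable names below are placeholders for a pair
        of same-shaped variables in student_program after merge has been
        applied:

        .. code-block:: python

            distill_loss = l2('teacher_fc_0.tmp_0', 'fc_0.tmp_0',
                              program=student_program)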
    """
    if program is None:
        program = paddle.static.default_main_program()
    student_var = program.global_block().var(student_var_name)
    teacher_var = program.global_block().var(teacher_var_name)
    l2_loss = paddle.mean(
        paddle.nn.functional.square_error_cost(student_var, teacher_var))
    return l2_loss


def soft_label(teacher_var_name,
               student_var_name,
               program=None,
               teacher_temperature=1.,
               student_temperature=1.):
    """Combine variables from student model and teacher model by soft-label-loss.

    Args:
        teacher_var_name(str): The name of teacher_var.
        student_var_name(str): The name of student_var.
        program(Program): The input distiller program. If not specified,
                          the default program will be used. Default: None
        teacher_temperature(float): Temperature used to divide
            teacher_feature_map before softmax. Default: 1.0
        student_temperature(float): Temperature used to divide
            student_feature_map before softmax. Default: 1.0

    Returns:
        Variable: soft label distiller loss.
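
    Examples:
        A minimal sketch. The variable names below are placeholders for
        pre-softmax logits in student_program after merge has been applied:

        .. code-block:: python

            distill_loss = soft_label('teacher_fc_0.tmp_0', 'fc_0.tmp_0',
                                      program=student_program,
                                      teacher_temperature=2.,
                                      student_temperature=2.)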
    """
    if program is None:
        program = paddle.static.default_main_program()
    student_var = program.global_block().var(student_var_name)
    teacher_var = program.global_block().var(teacher_var_name)
    teacher_var.stop_gradient = True

    student_var = paddle.nn.functional.softmax(student_var /
                                               student_temperature)
    teacher_var = paddle.nn.functional.softmax(teacher_var /
                                               teacher_temperature)
    soft_label_loss = paddle.mean(
        paddle.fluid.layers.cross_entropy(
            student_var, teacher_var, soft_label=True))
    return soft_label_loss


def loss(loss_func, program=None, **kwargs):
    """Combine variables from student model and teacher model by a user-defined loss.

    Args:
B
Bai Yifan 已提交
215 216
        program(Program): The input distiller program. If not specified,
                          the default program will be used. Default: None
Y
yangfukui 已提交
217
        loss_func(function): The user self defined loss function. 
B
Bai Yifan 已提交
218 219 220

    Returns: 
        Variable: self defined distiller loss.
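
    Examples:
        A minimal sketch with a hypothetical loss function. The variable names
        below are placeholders for variables in student_program after merge has
        been applied:

        .. code-block:: python

            import paddle

            def my_l2(teacher_logits, student_logits):
                return paddle.mean(
                    paddle.nn.functional.square_error_cost(student_logits,
                                                           teacher_logits))

            distill_loss = loss(my_l2, program=student_program,
                                teacher_logits='teacher_fc_0.tmp_0',
                                student_logits='fc_0.tmp_0')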
    """
    if program is None:
        program = paddle.static.default_main_program()
    func_parameters = {}
    for name, value in kwargs.items():
        if isinstance(value, str):
            func_parameters.setdefault(name,
                                       program.global_block().var(value))
        else:
            func_parameters.setdefault(name, value)
    return loss_func(**func_parameters)