single_distiller.py 7.8 KB
Newer Older
Y
yangfukui 已提交
1 2 3 4 5
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Y
yangfukui 已提交
15 16 17 18 19 20 21 22
import numpy as np
import paddle.fluid as fluid


def merge(teacher_program,
          student_program,
          data_name_map,
          place,
          scope=None,
          name_prefix='teacher_'):
    """
    Merge teacher program into student program and add a uniform prefix to the
    names of all vars in teacher program.

    The teacher program is cloned with ``for_test=True`` before renaming, so
    the caller's teacher program object is not renamed. For every teacher var
    other than ``feed``/``fetch``:
      * it is renamed to ``data_name_map[name]`` if present, otherwise to
        ``name_prefix + name`` (a rename to the same name is skipped);
      * the tensor held in ``scope`` under the old name is copied to the new
        name on ``place``;
      * it is cloned into the student program's global block with
        ``stop_gradient = True``.
    Finally every teacher op other than ``feed``/``fetch`` is appended to the
    student program's global block with its inputs, outputs and attrs.

    Args:
        teacher_program(Program): The input teacher model paddle program
        student_program(Program): The input student model paddle program.
                                  It is modified in place.
        data_name_map(dict): Mapping from teacher var name (key) to student
                             var name (value). Mapping a name to itself
                             leaves that teacher var un-renamed.
        place(fluid.CPUPlace()|fluid.CUDAPlace(N)): This parameter represents
                                                    paddle run on which device.
        scope(Scope): The scope holding the teacher tensors; renamed copies
                      are written to this same scope.
                      default: None, meaning fluid.global_scope() at call time.
        name_prefix(str): Name prefix added for all vars of the teacher program.
    """
    if scope is None:
        # Resolved per call instead of as a default argument, which would be
        # frozen at import time and ignore fluid.scope_guard.
        scope = fluid.global_scope()
    skip_names = ('feed', 'fetch')
    teacher_program = teacher_program.clone(for_test=True)
    teacher_block = teacher_program.global_block()

    # Snapshot the vars first: renaming mutates the block's var table.
    for teacher_var in list(teacher_program.list_vars()):
        old_name = teacher_var.name
        if old_name in skip_names:
            continue
        new_name = data_name_map.get(old_name, name_prefix + old_name)
        if new_name == old_name:
            # Nothing to rename (explicit identity mapping or empty prefix).
            continue
        # scope var rename: copy the tensor under its new name
        old_tensor = scope.var(old_name).get_tensor()
        new_tensor = scope.var(new_name).get_tensor()
        new_tensor.set(np.array(old_tensor), place)
        # program var rename
        teacher_block._rename_var(old_name, new_name)

    student_block = student_program.global_block()
    for teacher_var in teacher_program.list_vars():
        if teacher_var.name in skip_names:
            continue
        # student program add var; teacher vars are never trained
        new_var = student_block._clone_variable(
            teacher_var, force_persistable=False)
        new_var.stop_gradient = True

    for block in teacher_program.blocks:
        for op in block.ops:
            if op.type in skip_names:
                continue
            inputs = {
                slot: [block.var(var_name) for var_name in op.input(slot)]
                for slot in op.input_names
            }
            outputs = {
                slot: [block.var(var_name) for var_name in op.output(slot)]
                for slot in op.output_names
            }
            attrs = {attr_name: op.attr(attr_name)
                     for attr_name in op.attr_names}
            student_block.append_op(
                type=op.type, inputs=inputs, outputs=outputs, attrs=attrs)


def fsp_loss(teacher_var1_name,
             teacher_var2_name,
             student_var1_name,
             student_var2_name,
             program=None):
    """
    Combine variables from student model and teacher model by fsp-loss.

    The loss is the mean of the element-wise squared difference between the
    student FSP matrix and the teacher FSP matrix.

    Args:
        teacher_var1_name(str): The name of teacher_var1.
        teacher_var2_name(str): The name of teacher_var2. Except for the
            second dimension, all other dimensions should
            be consistent with teacher_var1.
        student_var1_name(str): The name of student_var1.
        student_var2_name(str): The name of student_var2. Except for the
            second dimension, all other dimensions should
            be consistent with student_var1.
        program(Program): The input distiller program. All four vars are
                          looked up in its global block and the loss ops are
                          appended to it.
                          default: None, meaning fluid.default_main_program()
                          at call time.
    Return(Variable): fsp distiller loss.
    """
    if program is None:
        program = fluid.default_main_program()
    block = program.global_block()
    teacher_var1 = block.var(teacher_var1_name)
    teacher_var2 = block.var(teacher_var2_name)
    student_var1 = block.var(student_var1_name)
    student_var2 = block.var(student_var2_name)
    # fluid.layers append to the *default* main program; guard so the ops
    # land in `program`, where the input vars live.
    with fluid.program_guard(program):
        teacher_fsp_matrix = fluid.layers.fsp_matrix(teacher_var1,
                                                     teacher_var2)
        student_fsp_matrix = fluid.layers.fsp_matrix(student_var1,
                                                     student_var2)
        distill_loss = fluid.layers.reduce_mean(
            fluid.layers.square(student_fsp_matrix - teacher_fsp_matrix))
    return distill_loss


def l2_loss(teacher_var_name,
            student_var_name,
            program=None):
    """
    Combine variables from student model and teacher model by l2-loss.

    The loss is the mean of the element-wise squared difference between the
    student var and the teacher var.

    Args:
        teacher_var_name(str): The name of teacher_var.
        student_var_name(str): The name of student_var.
        program(Program): The input distiller program. Both vars are looked
                          up in its global block and the loss ops are
                          appended to it.
                          default: None, meaning fluid.default_main_program()
                          at call time.
    Return(Variable): l2 distiller loss.
    """
    if program is None:
        program = fluid.default_main_program()
    block = program.global_block()
    student_var = block.var(student_var_name)
    teacher_var = block.var(teacher_var_name)
    # fluid.layers append to the *default* main program; guard so the ops
    # land in `program`, where the input vars live.
    with fluid.program_guard(program):
        distill_loss = fluid.layers.reduce_mean(
            fluid.layers.square(student_var - teacher_var))
    return distill_loss


def soft_label_loss(teacher_var_name,
                    student_var_name,
                    program=None,
                    teacher_temperature=1.,
                    student_temperature=1.):
    """
    Combine variables from student model and teacher model by soft-label-loss.

    Both vars are divided by their temperature and passed through softmax.
    The teacher distribution has stop_gradient set. The loss is the mean
    soft-label cross entropy of the student distribution against the teacher
    distribution.

    Args:
        teacher_var_name(str): The name of teacher_var.
        student_var_name(str): The name of student_var.
        program(Program): The input distiller program. Both vars are looked
                          up in its global block and the loss ops are
                          appended to it.
                          default: None, meaning fluid.default_main_program()
                          at call time.
        teacher_temperature(float): Temperature used to divide
            teacher_feature_map before softmax. default: 1.0
        student_temperature(float): Temperature used to divide
            student_feature_map before softmax. default: 1.0
    Return(Variable): soft label distiller loss.
    """
    if program is None:
        program = fluid.default_main_program()
    block = program.global_block()
    student_var = block.var(student_var_name)
    teacher_var = block.var(teacher_var_name)
    # fluid.layers append to the *default* main program; guard so the ops
    # land in `program`, where the input vars live.
    with fluid.program_guard(program):
        student_prob = fluid.layers.softmax(student_var / student_temperature)
        teacher_prob = fluid.layers.softmax(teacher_var / teacher_temperature)
        # The teacher distribution is a fixed target.
        teacher_prob.stop_gradient = True
        distill_loss = fluid.layers.reduce_mean(
            fluid.layers.cross_entropy(
                student_prob, teacher_prob, soft_label=True))
    return distill_loss


def loss(loss_func, program=None, **kwargs):
    """
    Combine variables from student model and teacher model by self defined loss.

    Every keyword argument whose value is a ``str`` is treated as a var name
    and replaced by that var from the program's global block. All other
    values are passed through unchanged. ``loss_func`` is then called with
    the resulting keyword arguments.

    Args:
        loss_func(function): The user self defined loss function.
        program(Program): The input distiller program. String kwargs are
                          resolved against its global block, and any fluid
                          layers created by loss_func are appended to it.
                          default: None, meaning fluid.default_main_program()
                          at call time.
        **kwargs: Arguments forwarded to loss_func (str values are resolved
                  to vars as described above).
    Return(Variable): self defined distiller loss.
    """
    if program is None:
        program = fluid.default_main_program()
    block = program.global_block()
    func_parameters = {
        name: block.var(value) if isinstance(value, str) else value
        for name, value in kwargs.items()
    }
    # fluid.layers append to the *default* main program; guard so the ops
    # built by loss_func land in `program`, where the input vars live.
    with fluid.program_guard(program):
        return loss_func(**func_parameters)