diff --git a/paddleslim/dist/__init__.py b/paddleslim/dist/__init__.py
index 23c97ad8b37900fb9cc899e1886460d5dbfd7d7b..de4b6196a4b25892e644ac51a28a95dc12fc4aed 100755
--- a/paddleslim/dist/__init__.py
+++ b/paddleslim/dist/__init__.py
@@ -12,5 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .single_distiller import merge, fsp, l2, soft_label, loss
+from .single_distiller import merge, fsp, l2, soft_label, loss, dkd
 from .dml import DML
diff --git a/paddleslim/dist/single_distiller.py b/paddleslim/dist/single_distiller.py
index 1300a2fe4690d159d7c7433ae4726e112f073635..8a658a6ae2ec162da811f25e6062b0927fbff27c 100644
--- a/paddleslim/dist/single_distiller.py
+++ b/paddleslim/dist/single_distiller.py
@@ -230,3 +230,82 @@
             func_parameters.setdefault(item[0], item[1])
     loss = loss_func(**func_parameters)
     return loss
+
+
+def _top_mask(x):
+    """Return an int32 mask with 1 at each row's maximum entry, 0 elsewhere."""
+    top_value, _ = paddle.topk(x, 1)
+    # NOTE(review): entries tied with the row maximum are all marked;
+    # confirm ties cannot occur for real logits or handle them explicitly.
+    return paddle.cast(x == top_value, "int32")
+
+
+def _cal_tc_nc_pred(x, top_mask):
+    """Calculate the predictions of target class and non-target class.
+    The predictions of target class is a binary distribution.
+    And after removing the target class, the softmax on the remaining
+    parts produces the non-target predictions.
+    """
+    pred = paddle.nn.functional.softmax(x)
+    fp_mask = paddle.cast(top_mask, "float32")
+    top_value = paddle.sum(fp_mask * pred, axis=1, keepdim=True)
+    tc_pred = paddle.concat([top_value, 1 - top_value], axis=1)
+    # Push the target-class logit far below the rest so it contributes
+    # (almost) nothing to the non-target softmax.
+    tmp = paddle.assign(x)
+    tmp = tmp + (-100000 * top_mask)
+    nc_pred = paddle.nn.functional.softmax(tmp)
+    return tc_pred, nc_pred
+
+
+def _dkd_loss(student_logits,
+              teacher_logits,
+              temperature=1.0,
+              alpha=1.0,
+              beta=1.0):
+    """Compute the decoupled (target/non-target class) distillation loss."""
+    mask = _top_mask(teacher_logits)
+    s_tc_pred, s_nc_pred = _cal_tc_nc_pred(student_logits / temperature, mask)
+    t_tc_pred, t_nc_pred = _cal_tc_nc_pred(teacher_logits / temperature, mask)
+    # NOTE(review): paddle.nn.functional.kl_div expects `input` in log-space;
+    # plain probabilities are passed here -- verify against mdistiller.
+    tc_loss = paddle.nn.functional.kl_div(
+        s_tc_pred, t_tc_pred, reduction='mean')
+    nc_loss = paddle.nn.functional.kl_div(
+        s_nc_pred, t_nc_pred, reduction='mean')
+    loss = alpha * tc_loss + beta * nc_loss
+    return loss * temperature**2
+
+
+def dkd(teacher_var_name,
+        student_var_name,
+        program=None,
+        temperature=1.0,
+        alpha=1.0,
+        beta=1.0):
+    """Combine variables from student model and teacher model
+    by Decoupled Knowledge Distillation loss (aka. dkd-loss).
+    Reference: https://github.com/megvii-research/mdistiller
+    Args:
+        teacher_var_name(str): The name of teacher_var.
+        student_var_name(str): The name of student_var.
+        program(Program): The input distiller program. If not specified,
+                          the default program will be used. Default: None
+        temperature(float): Temperature used to divide
+            teacher_feature_map before softmax. Default: 1.0
+        alpha(float): The weight of target class loss. Default: 1.0
+        beta(float): The weight of none-target class loss. Default: 1.0
+
+    Returns:
+        Variable: dkd distiller loss.
+    """
+    if program is None:
+        program = paddle.static.default_main_program()
+    student_var = program.global_block().var(student_var_name)
+    teacher_var = program.global_block().var(teacher_var_name)
+    return _dkd_loss(
+        student_var,
+        teacher_var,
+        temperature=temperature,
+        alpha=alpha,
+        beta=beta)
diff --git a/tests/test_dkd_loss.py b/tests/test_dkd_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab836f8ff53185964416cde830dfd9115fbeb13b
--- /dev/null
+++ b/tests/test_dkd_loss.py
@@ -0,0 +1,71 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+sys.path.append("../")
+import unittest
+import paddle
+from paddleslim.dist import merge, dkd
+from layers import conv_bn_layer
+from static_case import StaticCase
+
+
+class TestDKDLoss(StaticCase):
+    def test_dkd_loss(self):
+        input = paddle.static.data(name="image", shape=[None, 3, 224, 224])
+        conv1 = conv_bn_layer(input, 8, 3, "conv1")
+        conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
+        student_predict = conv1 + conv2
+        student_predict = paddle.fluid.layers.fc(student_predict, size=10)
+
+        teacher_main = paddle.static.Program()
+        teacher_startup = paddle.static.Program()
+        with paddle.static.program_guard(teacher_main, teacher_startup):
+            input = paddle.static.data(name="image", shape=[None, 3, 224, 224])
+            conv1 = conv_bn_layer(input, 8, 3, "conv1")
+            conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
+            sum1 = conv1 + conv2
+            conv3 = conv_bn_layer(sum1, 8, 3, "conv3")
+            conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
+            sum2 = conv4 + sum1
+            conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
+            teacher_predict = conv_bn_layer(conv5, 8, 3, "conv6")
+            teacher_predict = paddle.fluid.layers.fc(teacher_predict, size=10)
+
+        place = paddle.CPUPlace()
+        data_name_map = {'image': 'image'}
+        merge(teacher_main,
+              paddle.static.default_main_program(), data_name_map, place)
+
+        merged_ops = []
+        for block in paddle.static.default_main_program().blocks:
+            for op in block.ops:
+                merged_ops.append(op.type)
+
+        distill_loss = dkd("teacher_" + (teacher_predict.name),
+                           student_predict.name)
+        loss_ops = []
+        for block in paddle.static.default_main_program().blocks:
+            for op in block.ops:
+                loss_ops.append(op.type)
+        self.assertTrue(set(merged_ops).difference(set(loss_ops)) == set())
+        self.assertTrue(
+            set(loss_ops).difference(set(merged_ops)) == {
+                'kldiv_loss', 'assign', 'scale', 'concat', 'reduce_sum',
+                'equal', 'softmax', 'reduce_mean', 'cast', 'elementwise_mul',
+                'top_k_v2'
+            })
+
+
+if __name__ == '__main__':
+    unittest.main()