From a35619b829e7b5ddfe726115af382f590fa2c47d Mon Sep 17 00:00:00 2001
From: zhouzj <41366441+zzjjay@users.noreply.github.com>
Date: Tue, 27 Dec 2022 14:18:30 +0800
Subject: [PATCH] add skd distillation. (#1587)

* add skd distillation.

* update skd's test.
---
 paddleslim/dist/__init__.py         |  2 +-
 paddleslim/dist/single_distiller.py | 56 +++++++++++++++++++-
 tests/test_skd_loss.py              | 81 +++++++++++++++++++++++++++++
 3 files changed, 137 insertions(+), 2 deletions(-)
 create mode 100644 tests/test_skd_loss.py

diff --git a/paddleslim/dist/__init__.py b/paddleslim/dist/__init__.py
index de4b6196..46a02564 100755
--- a/paddleslim/dist/__init__.py
+++ b/paddleslim/dist/__init__.py
@@ -12,5 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .single_distiller import merge, fsp, l2, soft_label, loss, dkd
+from .single_distiller import merge, fsp, l2, soft_label, loss, dkd, skd
 from .dml import DML
diff --git a/paddleslim/dist/single_distiller.py b/paddleslim/dist/single_distiller.py
index fba5bb08..8b20706a 100644
--- a/paddleslim/dist/single_distiller.py
+++ b/paddleslim/dist/single_distiller.py
@@ -15,6 +15,7 @@
 import numpy as np
 import paddle
 from paddleslim.core import GraphWrapper
+import paddle.nn.functional as F
 
 
 def _find_var_from_program(program, var_name):
@@ -300,7 +301,10 @@ def soft_label(teacher_var_name,
                                                teacher_temperature)
     soft_label_loss = paddle.mean(
         paddle.nn.functional.cross_entropy(
-            input=student_var, label=teacher_var, soft_label=True))
+            input=student_var,
+            label=teacher_var,
+            soft_label=True,
+            use_softmax=False))
     return soft_label_loss
 
 
@@ -401,3 +405,53 @@ def dkd(teacher_var_name,
         temperature=temperature,
         alpha=alpha,
         beta=beta)
+
+
+def skd(teacher_var_name, student_var_name, program=None, multiplier=None):
+    """Combine variables from the student model and the teacher model
+    with the Spherical Knowledge Distillation loss (a.k.a. SKD loss).
+    Reference: https://github.com/forjiuzhou/Spherical-Knowledge-Distillation
+    Args:
+        teacher_var_name(str): The name of teacher_var.
+        student_var_name(str): The name of student_var.
+        program(Program): The input distiller program. If not specified,
+                          the default program will be used. Default: None
+        multiplier(float): The multiplier that recovers the norm of the normalized
+            logits to the original level. When it is None, it is computed from the
+            teacher's logits with paddle.std(teacher_var, axis=1). Default: None.
+
+    Returns:
+        Variable: skd distiller loss.
+    """
+    if program is None:
+        program = paddle.static.default_main_program()
+
+    student_var = program.global_block().var(student_var_name)
+    teacher_var = program.global_block().var(teacher_var_name)
+    teacher_var.stop_gradient = True
+
+    if multiplier is None:
+        multiplier = paddle.std(teacher_var, axis=1, keepdim=True)
+
+    logits_student = F.layer_norm(
+        student_var,
+        student_var.shape[1:],
+        weight=None,
+        bias=None,
+        epsilon=1e-7) * multiplier
+    logits_teacher = F.layer_norm(
+        teacher_var,
+        teacher_var.shape[1:],
+        weight=None,
+        bias=None,
+        epsilon=1e-7) * multiplier
+
+    student_out = F.softmax(logits_student, axis=1)
+    teacher_out = F.softmax(logits_teacher, axis=1)
+    skd_loss = paddle.mean(
+        F.cross_entropy(
+            input=student_out,
+            label=teacher_out,
+            soft_label=True,
+            use_softmax=False))
+    return skd_loss
diff --git a/tests/test_skd_loss.py b/tests/test_skd_loss.py
new file mode 100644
index 00000000..19a07b34
--- /dev/null
+++ b/tests/test_skd_loss.py
@@ -0,0 +1,81 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+sys.path.append("../")
+import unittest
+import paddle
+from paddleslim.dist import merge, skd
+from layers import conv_bn_layer
+from static_case import StaticCase
+
+
+class TestSKDLoss(StaticCase):
+    def test_skd_loss(self):
+        place = paddle.CPUPlace()
+        exe = paddle.static.Executor(place)
+
+        student_program = paddle.static.Program()
+        student_startup = paddle.static.Program()
+        with paddle.static.program_guard(student_program, student_startup):
+            with paddle.utils.unique_name.guard():
+                input = paddle.static.data(
+                    name="image", shape=[None, 3, 224, 224])
+                conv1 = conv_bn_layer(input, 8, 3, "conv1")
+                conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
+                student_predict = conv1 + conv2
+
+        teacher_program = paddle.static.Program()
+        teacher_startup = paddle.static.Program()
+        with paddle.static.program_guard(teacher_program, teacher_startup):
+            with paddle.utils.unique_name.guard():
+                input = paddle.static.data(
+                    name="image", shape=[None, 3, 224, 224])
+                conv1 = conv_bn_layer(input, 8, 3, "conv1")
+                conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
+                sum1 = conv1 + conv2
+                conv3 = conv_bn_layer(sum1, 8, 3, "conv3")
+                conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
+                sum2 = conv4 + sum1
+                conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
+                teacher_predict = conv_bn_layer(conv5, 8, 3, "conv6")
+
+        exe.run(teacher_startup)
+        exe.run(student_startup)
+
+        data_name_map = {'image': 'image'}
+        merge(teacher_program, student_program, data_name_map, place)
+        merged_ops = []
+        for block in student_program.blocks:
+            for op in block.ops:
+                merged_ops.append(op.type)
+        with paddle.static.program_guard(student_program, student_startup):
+            distill_loss = skd('teacher_' + teacher_predict.name,
+                               student_predict.name,
+                               program=None,
+                               multiplier=None)
+
+        loss_ops = []
+        for block in student_program.blocks:
+            for op in block.ops:
+                loss_ops.append(op.type)
+        print(f"ret: {set(loss_ops).difference(set(merged_ops))}")
+        self.assertTrue(set(merged_ops).difference(set(loss_ops)) == set())
+
+        self.assertTrue({
+            'softmax_with_cross_entropy', 'softmax', 'reduce_mean', 'layer_norm'
+        }.issubset(set(loss_ops).difference(set(merged_ops))))
+
+
+if __name__ == '__main__':
+    unittest.main()
-- 
GitLab
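
The skd helper added by this patch targets static-graph programs. For readers who want to see the same computation outside the patch, below is a minimal eager-mode (dygraph) sketch, assuming Paddle 2.x; it is not part of this commit, and "student_logits" / "teacher_logits" are placeholder [N, C] tensors standing in for the outputs of your own student and teacher networks.

import paddle
import paddle.nn.functional as F


def skd_loss_dygraph(student_logits, teacher_logits, multiplier=None):
    # Sketch of the Spherical Knowledge Distillation loss in dygraph mode,
    # mirroring the static-graph skd() helper added by this patch.
    teacher_logits.stop_gradient = True  # teacher outputs are constants
    if multiplier is None:
        # Per-sample std of the teacher logits, used to restore the norm
        # of the normalized logits to the teacher's scale.
        multiplier = paddle.std(teacher_logits, axis=1, keepdim=True)
    # Normalize each row of logits to zero mean / unit variance, then rescale.
    student_norm = F.layer_norm(
        student_logits, student_logits.shape[1:], weight=None, bias=None,
        epsilon=1e-7) * multiplier
    teacher_norm = F.layer_norm(
        teacher_logits, teacher_logits.shape[1:], weight=None, bias=None,
        epsilon=1e-7) * multiplier
    # Soft-label cross entropy between the two already-softmaxed distributions.
    return paddle.mean(
        F.cross_entropy(
            input=F.softmax(student_norm, axis=1),
            label=F.softmax(teacher_norm, axis=1),
            soft_label=True,
            use_softmax=False))


# Toy check: a batch of 4 samples with 10 classes.
student_logits = paddle.randn([4, 10])
teacher_logits = paddle.randn([4, 10])
print(skd_loss_dygraph(student_logits, teacher_logits))

The layer_norm call projects every sample's logits to zero mean and unit variance, and multiplying both networks by the teacher's per-sample standard deviation puts them back on a common, teacher-sized scale, so the following softmax acts like a consistently tempered softmax for student and teacher alike.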