diff --git a/paddleslim/dygraph/dist/losses/__init__.py b/paddleslim/dygraph/dist/losses/__init__.py index 8f5c9599552a7b4ab9bf70954f3d29b52efafcee..e61e35274337b47d3e2ab0ecc298b541310666b9 100644 --- a/paddleslim/dygraph/dist/losses/__init__.py +++ b/paddleslim/dygraph/dist/losses/__init__.py @@ -11,3 +11,73 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import copy +import paddle +import paddle.nn as nn + +from . import basic_loss +from . import distillation_loss + +from .basic_loss import L1Loss +from .basic_loss import L2Loss +from .basic_loss import SmoothL1Loss +from .basic_loss import CELoss +from .basic_loss import DMLLoss +from .basic_loss import DistanceLoss +from .basic_loss import RKdAngle, RkdDistance + +from .distillation_loss import DistillationDistanceLoss +from .distillation_loss import DistillationDMLLoss +from .distillation_loss import DistillationRKDLoss + + +class CombinedLoss(nn.Layer): + """ + CombinedLoss: a weighted combination of loss functions. + Args: + loss_config_list: a config list used to build the loss functions. The demo below + calculates the dml loss between the Student output and the + Teacher output. The weight parameter is required for every loss. + - DistillationDMLLoss: + weight: 1.0 + act: "softmax" + model_name_pairs: + - ["Student", "Teacher"] + """ + + def __init__(self, loss_config_list=None): + super().__init__() + loss_config_list = copy.deepcopy(loss_config_list) + self.loss_func = [] + self.loss_weight = [] + assert isinstance(loss_config_list, list), ( + 'loss_config_list should be a list') + supported_loss_list = basic_loss.__all__ + distillation_loss.__all__ + for config in loss_config_list: + assert isinstance(config, + dict) and len(config) == 1, "yaml format error" + name = list(config)[0] + assert name in supported_loss_list, \ + "loss name must be in {} but got: {}".format(supported_loss_list, name) + param = config[name] + assert "weight" in param, "weight must be in param, but param just contains {}".format( + param.keys()) + self.loss_weight.append(param.pop("weight")) + self.loss_func.append(eval(name)(**param)) + + def forward(self, input, batch, **kargs): + loss_dict = {} + for idx, loss_func in enumerate(self.loss_func): + loss = loss_func(input, batch, **kargs) + weight = self.loss_weight[idx] + if isinstance(loss, paddle.Tensor): + loss = {"loss_{}_{}".format(str(loss), idx): loss * weight} + else: + loss = { + "{}_{}".format(key, idx): loss[key] * weight + for key in loss + } + loss_dict.update(loss) + loss_dict["loss"] = paddle.add_n(list(loss_dict.values())) + return loss_dict diff --git a/paddleslim/dygraph/dist/losses/basic_loss.py b/paddleslim/dygraph/dist/losses/basic_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..c9f0684d8a07d849ac788735a3fc4bd025b3276d --- /dev/null +++ b/paddleslim/dygraph/dist/losses/basic_loss.py @@ -0,0 +1,207 @@ +#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and +#limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from paddle.nn import L1Loss +from paddle.nn import MSELoss as L2Loss +from paddle.nn import SmoothL1Loss + +__all__ = [ + "CELoss", + "DMLLoss", + "DistanceLoss", + "RKdAngle", + "RkdDistance", +] + + +class CELoss(nn.Layer): + """ + CELoss: cross entropy loss + Args: + epsilon(float | None): label smooth epsilon. If it is None or not in range (0, 1), + then label smooth will not be used. + label_act(string | None): activation function applied to the label. It is used when the label + is logits rather than the ground-truth label. + axis(int): axis used to calculate cross entropy loss. + + """ + + def __init__(self, epsilon=None, label_act="softmax", axis=-1): + super().__init__() + if epsilon is not None and (epsilon <= 0 or epsilon >= 1): + epsilon = None + assert label_act in ["softmax", None] + self.epsilon = epsilon + self.label_act = label_act + self.axis = axis + + def _labelsmoothing(self, target, class_num): + if target.shape[-1] != class_num: + one_hot_target = F.one_hot(target, class_num) + else: + one_hot_target = target + soft_target = F.label_smooth(one_hot_target, epsilon=self.epsilon) + soft_target = paddle.reshape(soft_target, shape=[-1, class_num]) + return soft_target + + def forward(self, x, label): + assert len(x.shape) == len(label.shape), \ + "x and label shape length should be same but got {} for x.shape and {} for label.shape".format(x.shape, label.shape) + if self.epsilon is not None: + class_num = x.shape[-1] + label = self._labelsmoothing(label, class_num) + x = -F.log_softmax(x, axis=self.axis) + loss = paddle.sum(x * label, axis=self.axis) + else: + if label.shape[self.axis] == x.shape[self.axis]: + if self.label_act == "softmax": + label = F.softmax(label, axis=self.axis) + soft_label = True + else: + soft_label = False + loss = F.cross_entropy( + x, label=label, soft_label=soft_label, axis=self.axis) + loss = loss.mean() + return loss + + +class DMLLoss(nn.Layer): + """ + DMLLoss: deep mutual learning loss + Args: + act(string | None): activation function used to activate the input tensor + axis(int): axis used to build the activation function + """ + + def __init__(self, act=None, axis=-1): + super().__init__() + if act is not None: + assert act in ["softmax", "sigmoid"] + if act == "softmax": + self.act = nn.Softmax(axis=axis) + elif act == "sigmoid": + self.act = nn.Sigmoid() + else: + self.act = None + + def forward(self, out1, out2): + if self.act is not None: + out1 = self.act(out1) + out2 = self.act(out2) + + log_out1 = paddle.log(out1) + log_out2 = paddle.log(out2) + loss = (F.kl_div( + log_out1, out2, reduction='batchmean') + F.kl_div( + log_out2, out1, reduction='batchmean')) / 2.0 + return loss + + +class DistanceLoss(nn.Layer): + """ + DistanceLoss + Args: + mode: loss mode + kargs(dict): used to build the corresponding loss function, for more details, please + refer to: + L1Loss: https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/L1Loss_cn.html#l1loss + L2Loss: https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/MSELoss_cn.html#mseloss + SmoothL1Loss: https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/SmoothL1Loss_cn.html#smoothl1loss + """ + + def __init__(self, mode="l2", **kargs): + super().__init__() + assert mode in ["l1", "l2", "smooth_l1"] + if mode == "l1": + self.loss_func =
nn.L1Loss(**kargs) + elif mode == "l2": + self.loss_func = nn.MSELoss(**kargs) + elif mode == "smooth_l1": + self.loss_func = nn.SmoothL1Loss(**kargs) + + def forward(self, x, y): + return self.loss_func(x, y) + + +def pdist(e, squared=False, eps=1e-12): + e_square = e.pow(2).sum(axis=1) + prod = paddle.mm(e, e.t()) + res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clip( + min=eps) + + if not squared: + res = res.sqrt() + + return res + + +class RKdAngle(nn.Layer): + """ + RKdAngle loss, see https://arxiv.org/abs/1904.05068 + """ + + def __init__(self): + super().__init__() + + def forward(self, student, teacher): + # reshape for feature map distillation + bs = student.shape[0] + student = student.reshape([bs, -1]) + teacher = teacher.reshape([bs, -1]) + + td = (teacher.unsqueeze(0) - teacher.unsqueeze(1)) + norm_td = F.normalize(td, p=2, axis=2) + t_angle = paddle.bmm(norm_td, norm_td.transpose([0, 2, 1])).reshape( + [-1, 1]) + + sd = (student.unsqueeze(0) - student.unsqueeze(1)) + norm_sd = F.normalize(sd, p=2, axis=2) + s_angle = paddle.bmm(norm_sd, norm_sd.transpose([0, 2, 1])).reshape( + [-1, 1]) + loss = F.smooth_l1_loss(s_angle, t_angle, reduction='mean') + return loss + + +class RkdDistance(nn.Layer): + """ + RkdDistance loss, see https://arxiv.org/abs/1904.05068 + Args: + eps(float): epsilon for the pdist function + """ + + def __init__(self, eps=1e-12): + super().__init__() + self.eps = eps + + def forward(self, student, teacher): + bs = student.shape[0] + student = student.reshape([bs, -1]) + teacher = teacher.reshape([bs, -1]) + + t_d = pdist(teacher, squared=False, eps=self.eps) + mean_td = t_d.mean() + t_d = t_d / (mean_td + self.eps) + + d = pdist(student, squared=False, eps=self.eps) + mean_d = d.mean() + d = d / (mean_d + self.eps) + + loss = F.smooth_l1_loss(d, t_d, reduction="mean") + return loss diff --git a/paddleslim/dygraph/dist/losses/distillation_loss.py b/paddleslim/dygraph/dist/losses/distillation_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..7b97bcd519436bab7fab783427914e59fef19549 --- /dev/null +++ b/paddleslim/dygraph/dist/losses/distillation_loss.py @@ -0,0 +1,136 @@ +#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import paddle +import paddle.nn as nn + +from .basic_loss import DMLLoss +from .basic_loss import DistanceLoss +from .basic_loss import RkdDistance +from .basic_loss import RKdAngle + +__all__ = [ + "DistillationDMLLoss", + "DistillationDistanceLoss", + "DistillationRKDLoss", +] + + +class DistillationDMLLoss(DMLLoss): + """ + DistillationDMLLoss + Args: + model_name_pairs(list | tuple): model name pairs to extract submodel output. + act(string | None): activation function used to build dml loss. + axis(int): axis used to build activation function. + key(string | None): key of the tensor used to calculate loss if the submodel + output type is dict. + name(string): loss name. 
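+            Each name pair should match keys of the predicts dict passed to forward, e.g. ["Student", "Teacher"].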
+ """ + + def __init__(self, model_name_pairs=[], act=None, key=None, + name="loss_dml"): + super().__init__(act=act) + assert isinstance(model_name_pairs, list) + self.key = key + self.model_name_pairs = model_name_pairs + self.name = name + + def forward(self, predicts, batch): + loss_dict = dict() + for idx, pair in enumerate(self.model_name_pairs): + out1 = predicts[pair[0]] + out2 = predicts[pair[1]] + if self.key is not None: + out1 = out1[self.key] + out2 = out2[self.key] + loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1], + idx)] = super().forward(out1, out2) + return loss_dict + + +class DistillationDistanceLoss(DistanceLoss): + """ + DistillationDistanceLoss + Args: + mode: loss mode + model_name_pairs(list | tuple): model name pairs to extract submodel output. + key(string | None): key of the tensor used to calculate loss if the submodel output type is dict. + name(string): loss name. + kargs(dict): used to build corresponding loss function. + """ + + def __init__(self, + mode="l2", + model_name_pairs=[], + key=None, + name="loss_distance", + **kargs): + super().__init__(mode=mode, **kargs) + assert isinstance(model_name_pairs, list) + self.key = key + self.model_name_pairs = model_name_pairs + self.name = name + "_" + mode + + def forward(self, predicts, batch): + loss_dict = dict() + for idx, pair in enumerate(self.model_name_pairs): + out1 = predicts[pair[0]] + out2 = predicts[pair[1]] + if self.key is not None: + out1 = out1[self.key] + out2 = out2[self.key] + loss = super().forward(out1, out2) + loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1], + idx)] = loss + return loss_dict + + +class DistillationRKDLoss(nn.Layer): + """ + DistillationRKDLoss + Args: + model_name_pairs(list | tuple): model name pairs to extract submodel output. + key(string | None): key of the tensor used to calculate loss if the submodel output type is dict. + eps(float): epsilon for the pdist function for RkdDistance loss. + name(string): loss name. + """ + + def __init__(self, + model_name_pairs=[], + key=None, + eps=1e-12, + name="loss_rkd"): + super().__init__() + self.model_name_pairs = model_name_pairs + self.key = key + + self.rkd_angle_loss_func = RKdAngle() + self.rkd_dist_func = RkdDistance(eps=eps) + self.name = name + + def forward(self, predicts, batch): + loss_dict = dict() + for idx, pair in enumerate(self.model_name_pairs): + out1 = predicts[pair[0]] + out2 = predicts[pair[1]] + if self.key is not None: + out1 = out1[self.key] + out2 = out2[self.key] + loss_dict["{}_{}_{}_angle_{}".format(self.name, pair[0], pair[ + 1], idx)] = self.rkd_angle_loss_func(out1, out2) + + loss_dict["{}_{}_{}_dist_{}".format(self.name, pair[0], pair[ + 1], idx)] = self.rkd_dist_func(out1, out2) + return loss_dict diff --git a/tests/dygraph/test_distillation_loss.py b/tests/dygraph/test_distillation_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..a14d485a8f02033bae0c4c6c7ba2fb84d64fa89d --- /dev/null +++ b/tests/dygraph/test_distillation_loss.py @@ -0,0 +1,697 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. +import sys +sys.path.append("../../") + +import copy + +import unittest +import paddle + +# basic loss +from paddleslim.dygraph.dist.losses import CombinedLoss + +# basic loss +from paddleslim.dygraph.dist.losses import DistanceLoss +from paddleslim.dygraph.dist.losses import CELoss +from paddleslim.dygraph.dist.losses import DMLLoss +from paddleslim.dygraph.dist.losses import RkdDistance +from paddleslim.dygraph.dist.losses import RKdAngle + +# distillation loss +from paddleslim.dygraph.dist.losses import DistillationDistanceLoss +from paddleslim.dygraph.dist.losses import DistillationRKDLoss +from paddleslim.dygraph.dist.losses import DistillationDMLLoss + +import numpy as np + + +class TestDistanceLoss(unittest.TestCase): + """TestDistanceLoss + TestDistanceLoss contains: + 1. unittest of basic loss + 2. unittest of distillation loss + """ + + def np_distance_loss(self, x, y, mode="l2", reduction="none"): + assert reduction in ["none", "mean", "sum"] + if isinstance(x, paddle.Tensor): + x = x.numpy() + if isinstance(y, paddle.Tensor): + y = y.numpy() + if mode == "l2": + diff = np.square(x - y) + elif mode == "l1": + diff = np.abs(x - y) + elif mode == "smooth_l1": + diff = np.abs(x - y) + diff_square = 0.5 * np.square(diff) + diff = np.where(diff >= 1, diff - 0.5, diff_square) + + if reduction == "none": + out = diff + elif reduction == "mean": + out = np.mean(diff) + elif reduction == "sum": + out = np.sum(diff) + return out + + def dist_np_distance_loss( + self, + predicts, + mode="l2", + reduction="none", + model_name_pairs=(["", ""]), + key=None, + name="loss_distance", ): + loss_dict = dict() + for idx, pair in enumerate(model_name_pairs): + out1 = predicts[pair[0]] + out2 = predicts[pair[1]] + if key is not None: + out1 = out1[key] + out2 = out2[key] + loss = self.np_distance_loss( + out1, out2, mode=mode, reduction=reduction) + loss_dict["{}_{}_{}_{}_{}".format(name, mode, pair[0], pair[1], + idx)] = loss + + return loss_dict + + def test_basic_distance_loss(self): + shape = [10, 20] + x = paddle.rand(shape) + y = paddle.rand(shape) + modes = ["l1", "l2", "smooth_l1"] + reductions = ["none", "mean", "sum"] + devices = ["cpu"] + if paddle.is_compiled_with_cuda(): + devices.append("gpu") + for device in devices: + paddle.set_device(device) + for reduction in reductions: + for mode in modes: + np_result = self.np_distance_loss( + x, y, mode=mode, reduction=reduction) + loss_func = DistanceLoss(mode=mode, reduction=reduction) + pd_result = loss_func(x, y).numpy() + self.assertTrue(np.allclose(np_result, pd_result)) + + def test_distillation_distance_loss(self, ): + shape = [20, 10] + x_feat_name = "student" + y_feat_name = "teacher" + pairs = [[x_feat_name, y_feat_name]] + predicts = { + "student": paddle.rand(shape), + "teacher": paddle.rand(shape), + } + self.calc_distillation_distance_loss(predicts, pairs, key=None) + + predicts = { + "student": { + "feat": paddle.rand(shape), + }, + "teacher": { + "feat": paddle.rand(shape), + }, + } + self.calc_distillation_distance_loss(predicts, pairs, key="feat") + + def calc_distillation_distance_loss(self, predicts, pairs, key=None): + modes = ["l1", "l2", "smooth_l1"] + reductions = ["none", "mean", "sum"] + devices = ["cpu"] + if paddle.is_compiled_with_cuda(): + devices.append("gpu") + + for device in devices: + paddle.set_device(device) + for reduction in reductions: + for mode in modes: + loss_func = DistillationDistanceLoss( + 
mode=mode, + model_name_pairs=pairs, + key=key, + reduction=reduction) + np_result_dict = self.dist_np_distance_loss( + predicts, + mode=mode, + reduction=reduction, + model_name_pairs=pairs, + key=key) + pd_result_dict = loss_func(predicts, None) + for k in np_result_dict: + pd_result = pd_result_dict[k].numpy() + np_result = np_result_dict[k] + self.assertTrue(np.allclose(np_result, pd_result)) + + +class TestCELoss(unittest.TestCase): + def stable_softmax(self, x): + shiftx = (x - np.max(x)).clip(-64.) + exps = np.exp(shiftx) + return exps / np.sum(exps) + + def ref_softmax(self, x, axis=-1, dtype=None): + if isinstance(x, paddle.Tensor): + x = x.numpy() + x_t = x.copy() + if dtype is not None: + x_t = x_t.astype(dtype) + return np.apply_along_axis(self.stable_softmax, axis, x_t) + + def log_softmax(self, x, axis=-1): + softmax_out = np.apply_along_axis(self.stable_softmax, axis, x) + return np.log(softmax_out) + + def _cross_entropy_soft(self, softmax, label, axis, ignore_index=-1): + return (-label * np.log(softmax)).sum(axis=axis, keepdims=True) + + def np_cross_entropy_loss(self, + input, + label, + weight=None, + reduction='mean', + ignore_index=-100): + log_softmax_out = self.log_softmax(input) + input_shape = log_softmax_out.shape + N = input_shape[0] + out = np.zeros_like(label).astype(np.float64) + total_weight = 0 + ###1. compute softmax cross_entropy (with weight) + ### Note: only support hard labels. + for i in range(N): + cur_target = label[i] + if cur_target == ignore_index: + out[i] = 0 + continue + cur_weight = weight[cur_target] if weight is not None else 1 + total_weight += cur_weight + out[i] = -log_softmax_out[i][cur_target] * cur_weight + + ###2. deal with reduction + if reduction == 'sum': + return np.sum(out) + elif reduction == 'mean': + out = out.sum() / total_weight if total_weight != 0 else out.sum() + return out + elif reduction == 'none': + return out + + def np_cross_entropy_soft(self, + x, + label, + axis=-1, + weight=None, + reduction='mean', + ignore_index=-100): + if isinstance(x, paddle.Tensor): + x = x.numpy() + if isinstance(label, paddle.Tensor): + label = label.numpy() + softmax = self.ref_softmax(x, axis=axis) + #1.loss + loss = self._cross_entropy_soft(softmax, label, axis, ignore_index) + + if weight is None and reduction == 'none': + return loss + + #2.weight + weighted_loss = loss + total_weight = softmax.shape[0] # batch size + if weight is not None: + weighted_loss = np.zeros_like(loss).astype(np.float64) + total_weight = 0 + for i in range(total_weight): + cur_soft_label = label[i] + cur_weight = np.dot(weight, cur_soft_label) + total_weight += cur_weight + weighted_loss[i] = loss[i] * cur_weight + + #3.reduce + if reduction == 'none': + return weighted_loss + + elif reduction == 'mean': + weighted_loss_sum = np.sum(weighted_loss) + weighted_loss_mean = weighted_loss_sum / total_weight + return weighted_loss_mean + + else: + weighted_loss_sum = np.sum(weighted_loss) + return weighted_loss_sum + + def test_ce_loss_hard_label(self, ): + batch_size = 16 + class_num = 1000 + + devices = ["cpu"] + if paddle.is_compiled_with_cuda(): + devices.append("gpu") + for device in devices: + paddle.set_device(device) + x = paddle.rand([batch_size, class_num]) + label = paddle.randint(0, class_num, shape=[batch_size, 1]) + + loss_func = CELoss() + pd_loss = loss_func(x, label).numpy() + np_loss = self.np_cross_entropy_loss(x, label) + self.assertTrue(np.allclose(np_loss, pd_loss)) + + def test_ce_loss_soft_label(self, ): + batch_size = 32 + class_num = 
1000 + + devices = ["cpu"] + if paddle.is_compiled_with_cuda(): + devices.append("gpu") + for device in devices: + paddle.set_device(device) + x = paddle.rand([batch_size, class_num]) + label = paddle.rand([batch_size, class_num]) + label = paddle.nn.functional.softmax(label, axis=-1) + + loss_func = CELoss(label_act=None) + pd_loss = loss_func(x, label).numpy() + np_loss = self.np_cross_entropy_soft(x, label) + self.assertTrue(np.allclose(np_loss, pd_loss)) + + +class TestDMLLoss(unittest.TestCase): + """TestDMLLoss + TestDMLLoss contains: + 1. unittest of basic loss + 2. unittest of distillation loss + """ + + def stable_softmax(self, x): + shiftx = (x - np.max(x)).clip(-64.) + exps = np.exp(shiftx) + return exps / np.sum(exps) + + def ref_softmax(self, x, axis=-1, dtype=None): + if isinstance(x, paddle.Tensor): + x = x.numpy() + x_t = x.copy() + if dtype is not None: + x_t = x_t.astype(dtype) + return np.apply_along_axis(self.stable_softmax, axis, x_t) + + def kldiv_loss(self, x, target, reduction="batchmean"): + output = target * (np.log(target) - x) + loss = np.where(target >= 0, output, np.zeros_like(x)) + + if reduction == "batchmean": + if len(x.shape) > 0: + return loss.sum() / x.shape[0] + else: + return loss.sum() + if reduction == "mean": + return loss.mean() + if reduction == "sum": + return loss.sum() + return loss + + def np_dml_loss(self, x, target, act="softmax"): + if isinstance(x, paddle.Tensor): + x = x.numpy() + if isinstance(target, paddle.Tensor): + target = target.numpy() + soft_x = self.ref_softmax(x, axis=-1) + soft_target = self.ref_softmax(target, axis=-1) + + log_soft_x = np.log(soft_x) + log_soft_target = np.log(soft_target) + loss = (self.kldiv_loss(log_soft_x, soft_target) + self.kldiv_loss( + log_soft_target, soft_x)) / 2.0 + return loss + + def test_basic_dml_loss(self, ): + batch_size = 32 + class_num = 1000 + + devices = ["cpu"] + if paddle.is_compiled_with_cuda(): + devices.append("gpu") + for device in devices: + paddle.set_device(device) + x = paddle.rand([batch_size, class_num]) + target = paddle.rand([batch_size, class_num]) + + loss_func = DMLLoss(act="softmax") + pd_loss = loss_func(x, target).numpy() + np_loss = self.np_dml_loss(x, target) + self.assertTrue(np.allclose(np_loss, pd_loss)) + + def dist_np_dml_loss( + self, + predicts, + model_name_pairs=(["", ""]), + key=None, + name="loss_dml", ): + loss_dict = dict() + for idx, pair in enumerate(model_name_pairs): + out1 = predicts[pair[0]] + out2 = predicts[pair[1]] + if key is not None: + out1 = out1[key] + out2 = out2[key] + loss_dict["{}_{}_{}_{}".format(name, pair[0], pair[1], + idx)] = self.np_dml_loss(out1, out2) + return loss_dict + + def calc_distillation_dml_loss(self, predicts, pairs, key=None): + devices = ["cpu"] + if paddle.is_compiled_with_cuda(): + devices.append("gpu") + + for device in devices: + paddle.set_device(device) + loss_func = DistillationDMLLoss( + act="softmax", model_name_pairs=pairs, key=key) + np_result_dict = self.dist_np_dml_loss( + predicts, model_name_pairs=pairs, key=key) + pd_result_dict = loss_func(predicts, None) + for k in np_result_dict: + pd_result = pd_result_dict[k].numpy() + np_result = np_result_dict[k] + self.assertTrue(np.allclose(np_result, pd_result)) + + def test_distillation_dml_loss(self, ): + shape = [20, 10] + x_feat_name = "student" + y_feat_name = "teacher" + pairs = [[x_feat_name, y_feat_name]] + predicts = { + "student": paddle.rand(shape), + "teacher": paddle.rand(shape), + } + self.calc_distillation_dml_loss(predicts, pairs, key=None) 
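+        # repeat the same check with dict sub-model outputs, selecting the tensor via the `key` argument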
+ + predicts = { + "student": { + "feat": paddle.rand(shape), + }, + "teacher": { + "feat": paddle.rand(shape), + }, + } + self.calc_distillation_dml_loss(predicts, pairs, key="feat") + + +class TestRKDLoss(unittest.TestCase): + def pdist(self, e, squared=False, eps=1e-12): + e_square = np.power(e, 2).sum(axis=1) + prod = np.matmul(e, e.transpose()) + res = ( + np.expand_dims(e_square, 1) + np.expand_dims(e_square, 0) - 2 * prod + ).clip(eps, sys.float_info.max) + if not squared: + res = np.sqrt(res) + return res + + def p_normalize(self, x, axis=1, p=2, epsilon=1e-12, keepdims=True): + xp = np.power(np.abs(x), p) + s = np.sum(xp, axis=axis, keepdims=keepdims) + r = np.maximum(np.power(s, 1.0 / p), epsilon) + return x / r + + def np_smooth_l1_loss(self, x, y): + diff = np.abs(x - y) + diff_square = 0.5 * np.square(diff) + loss = np.where(diff >= 1, diff - 0.5, diff_square).mean() + return loss + + def np_rkd_distance(self, student, teacher, eps=1e-12): + if isinstance(student, paddle.Tensor): + student = student.numpy() + if isinstance(teacher, paddle.Tensor): + teacher = teacher.numpy() + bs = student.shape[0] + student = student.reshape([bs, -1]) + teacher = teacher.reshape([bs, -1]) + + t_d = self.pdist(teacher, squared=False) + mean_td = t_d.mean() + t_d = t_d / (mean_td + eps) + + d = self.pdist(student, squared=False) + mean_d = d.mean() + d = d / (mean_d + eps) + + loss = self.np_smooth_l1_loss(d, t_d) + return loss + + def np_rkd_angle(self, student, teacher): + if isinstance(student, paddle.Tensor): + student = student.numpy() + if isinstance(teacher, paddle.Tensor): + teacher = teacher.numpy() + + # reshape for feature map distillation + bs = student.shape[0] + student = student.reshape([bs, -1]) + teacher = teacher.reshape([bs, -1]) + + td = np.expand_dims(teacher, 0) - np.expand_dims(teacher, 1) + norm_td = self.p_normalize(td, axis=2, p=2) + t_angle = np.matmul(norm_td, norm_td.transpose([0, 2, 1])).reshape( + [-1, 1]) + + sd = np.expand_dims(student, 0) - np.expand_dims(student, 1) + norm_sd = self.p_normalize(sd, axis=2, p=2) + s_angle = np.matmul(norm_sd, norm_sd.transpose([0, 2, 1])).reshape( + [-1, 1]) + + loss = self.np_smooth_l1_loss(s_angle, t_angle) + return loss + + def test_rkd_distance_loss(self, ): + batch_size = 32 + feat_dim = 100 + + devices = ["cpu"] + if paddle.is_compiled_with_cuda(): + devices.append("gpu") + for device in devices: + paddle.set_device(device) + paddle.seed(0) + x = paddle.rand([batch_size, feat_dim]) + y = paddle.rand([batch_size, feat_dim]) + + loss_func = RkdDistance() + pd_loss = loss_func(x, y).numpy() + np_loss = self.np_rkd_distance(x, y) + # NOTE: sqrt is included and seed is set for stability + self.assertTrue( + np.allclose( + np_loss, pd_loss, rtol=1e-5, atol=1e-07)) + + def test_rkd_angle_loss(self, ): + batch_size = 32 + feat_dim = 100 + + devices = ["cpu"] + if paddle.is_compiled_with_cuda(): + devices.append("gpu") + for device in devices: + paddle.set_device(device) + paddle.seed(0) + x = paddle.rand([batch_size, feat_dim]) + y = paddle.rand([batch_size, feat_dim]) + + loss_func = RKdAngle() + pd_loss = loss_func(x, y).numpy() + np_loss = self.np_rkd_angle(x, y) + # NOTE: sqrt is included and seed is set for stability + self.assertTrue(np.allclose(np_loss, pd_loss)) + + def dist_np_rkd_loss( + self, + predicts, + model_name_pairs=(["", ""]), + key=None, + name="loss_rkd", ): + loss_dict = dict() + for idx, pair in enumerate(model_name_pairs): + out1 = predicts[pair[0]] + out2 = predicts[pair[1]] + if key is not None: + out1 = 
out1[key] + out2 = out2[key] + loss_dict["{}_{}_{}_angle_{}".format(name, pair[0], pair[ + 1], idx)] = self.np_rkd_angle(out1, out2) + + loss_dict["{}_{}_{}_dist_{}".format(name, pair[0], pair[ + 1], idx)] = self.np_rkd_distance(out1, out2) + return loss_dict + + def calc_distillation_rkd_loss(self, predicts, pairs, key=None): + devices = ["cpu"] + if paddle.is_compiled_with_cuda(): + devices.append("gpu") + + for device in devices: + paddle.set_device(device) + loss_func = DistillationRKDLoss(model_name_pairs=pairs, key=key) + np_result_dict = self.dist_np_rkd_loss( + predicts, model_name_pairs=pairs, key=key) + pd_result_dict = loss_func(predicts, None) + for k in np_result_dict: + pd_result = pd_result_dict[k].numpy() + np_result = np_result_dict[k] + self.assertTrue(np.allclose(np_result, pd_result, rtol=1e-5)) + + def test_distillation_rkd_loss(self, ): + shape = [32, 16] + x_feat_name = "student" + y_feat_name = "teacher" + pairs = [[x_feat_name, y_feat_name]] + paddle.seed(0) + predicts = { + "student": paddle.rand(shape), + "teacher": paddle.rand(shape), + } + self.calc_distillation_rkd_loss(predicts, pairs, key=None) + + predicts = { + "student": { + "feat": paddle.rand(shape), + }, + "teacher": { + "feat": paddle.rand(shape), + }, + } + self.calc_distillation_rkd_loss(predicts, pairs, key="feat") + + +class TestCombinedLoss(unittest.TestCase): + def stable_softmax(self, x): + shiftx = (x - np.max(x)).clip(-64.) + exps = np.exp(shiftx) + return exps / np.sum(exps) + + def ref_softmax(self, x, axis=-1, dtype=None): + if isinstance(x, paddle.Tensor): + x = x.numpy() + x_t = x.copy() + if dtype is not None: + x_t = x_t.astype(dtype) + return np.apply_along_axis(self.stable_softmax, axis, x_t) + + def kldiv_loss(self, x, target, reduction="batchmean"): + output = target * (np.log(target) - x) + loss = np.where(target >= 0, output, np.zeros_like(x)) + + if reduction == "batchmean": + if len(x.shape) > 0: + return loss.sum() / x.shape[0] + else: + return loss.sum() + if reduction == "mean": + return loss.mean() + if reduction == "sum": + return loss.sum() + return loss + + def np_dml_loss(self, x, target, act="softmax"): + if isinstance(x, paddle.Tensor): + x = x.numpy() + if isinstance(target, paddle.Tensor): + target = target.numpy() + soft_x = self.ref_softmax(x, axis=-1) + soft_target = self.ref_softmax(target, axis=-1) + + log_soft_x = np.log(soft_x) + log_soft_target = np.log(soft_target) + loss = (self.kldiv_loss(log_soft_x, soft_target) + self.kldiv_loss( + log_soft_target, soft_x)) / 2.0 + return loss + + def dist_np_dml_loss( + self, + predicts, + model_name_pairs=(["", ""]), + key=None, + act="softmax", + name="loss_dml", ): + loss_dict = dict() + for idx, pair in enumerate(model_name_pairs): + out1 = predicts[pair[0]] + out2 = predicts[pair[1]] + if key is not None: + out1 = out1[key] + out2 = out2[key] + loss_dict["{}_{}_{}_{}".format(name, pair[0], pair[1], + idx)] = self.np_dml_loss(out1, out2) + return loss_dict + + def np_combined_loss(self, predicts, loss_cfg_list): + # NOTE, dml is set as the list for combined loss + loss_dict = dict() + for idx, loss_func in enumerate(loss_cfg_list): + cfg = copy.deepcopy(loss_func["DistillationDMLLoss"]) + weight = cfg.pop("weight") + loss = self.dist_np_dml_loss(predicts, **cfg) + + if isinstance(loss, np.ndarray): + loss = {"loss_{}_{}".format(str(loss), idx): loss} + else: + loss = { + "{}_{}".format(key, idx): loss[key] * weight + for key in loss + } + loss_dict.update(loss) + loss_dict["loss"] = np.sum(list(loss_dict.values())) 
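+        # NOTE: mirrors CombinedLoss.forward; since dist_np_dml_loss returns a dict and the test config uses weight=1.0, the unweighted ndarray branch above does not affect the summed result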
+ + return loss_dict + + def test_combined_loss(self, ): + shape = [32, 16] + x_feat_name = "student" + y_feat_name = "teacher" + pairs = [[x_feat_name, y_feat_name]] + paddle.seed(0) + predicts = { + "student": paddle.rand(shape), + "teacher": paddle.rand(shape), + } + + devices = ["cpu"] + if paddle.is_compiled_with_cuda(): + devices.append("gpu") + + loss_cfg_list = [{ + "DistillationDMLLoss": { + "weight": 1.0, + "act": "softmax", + "model_name_pairs": pairs, + "key": None + } + }, ] + + for device in devices: + paddle.set_device(device) + loss_func = CombinedLoss(loss_config_list=loss_cfg_list) + pd_result_dict = loss_func(predicts, None) + np_result_dict = self.np_combined_loss(predicts, loss_cfg_list) + for k in pd_result_dict: + pd_result = pd_result_dict[k].numpy() + np_result = np_result_dict[k] + self.assertTrue(np.allclose(np_result, pd_result)) + + +if __name__ == '__main__': + unittest.main()
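For reference, a minimal usage sketch of the new API (not part of the patch): it mirrors the DistillationDMLLoss config exercised in tests/dygraph/test_distillation_loss.py; the "Student"/"Teacher" names and tensor shapes below are illustrative only.

import paddle
from paddleslim.dygraph.dist.losses import CombinedLoss

# one entry per sub-loss; "weight" is required and is popped before the loss is built
loss_cfg_list = [{
    "DistillationDMLLoss": {
        "weight": 1.0,
        "act": "softmax",
        "model_name_pairs": [["Student", "Teacher"]],
        "key": None,  # set to a dict key when the sub-model outputs are dicts
    }
}]
loss_func = CombinedLoss(loss_config_list=loss_cfg_list)

# predicts maps model names to logits; the second argument (batch) is forwarded
# to every sub-loss and is unused by the distillation losses above
predicts = {
    "Student": paddle.rand([32, 10]),
    "Teacher": paddle.rand([32, 10]),
}
loss_dict = loss_func(predicts, None)
print(loss_dict["loss"])  # "loss" is the weighted sum of all sub-losses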