Unverified commit dfe5d3f7, authored by littletomatodonkey, committed by GitHub

[distill] add distillation losses (#789)

Parent c9c0e83f
@@ -11,3 +11,73 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import paddle
import paddle.nn as nn
from . import basic_loss
from . import distillation_loss
from .basic_loss import L1Loss
from .basic_loss import L2Loss
from .basic_loss import SmoothL1Loss
from .basic_loss import CELoss
from .basic_loss import DMLLoss
from .basic_loss import DistanceLoss
from .basic_loss import RKdAngle, RkdDistance
from .distillation_loss import DistillationDistanceLoss
from .distillation_loss import DistillationDMLLoss
from .distillation_loss import DistillationRKDLoss
class CombinedLoss(nn.Layer):
"""
    CombinedLoss: a weighted combination of loss functions.
    Args:
        loss_config_list: a config list used to build the loss functions. The demo below
            calculates the dml loss between the Student output and the Teacher output;
            the weight parameter is required and sets the weight of that loss.
- DistillationDMLLoss:
weight: 1.0
act: "softmax"
model_name_pairs:
- ["Student", "Teacher"]
"""
def __init__(self, loss_config_list=None):
super().__init__()
loss_config_list = copy.deepcopy(loss_config_list)
self.loss_func = []
self.loss_weight = []
assert isinstance(loss_config_list, list), (
'operator config should be a list')
supported_loss_list = basic_loss.__all__ + distillation_loss.__all__
for config in loss_config_list:
assert isinstance(config,
dict) and len(config) == 1, "yaml format error"
name = list(config)[0]
            assert name in supported_loss_list, \
                "loss name must be in {} but got: {}".format(
                    supported_loss_list, name)
param = config[name]
assert "weight" in param, "weight must be in param, but param just contains {}".format(
param.keys())
self.loss_weight.append(param.pop("weight"))
self.loss_func.append(eval(name)(**param))
def forward(self, input, batch, **kargs):
loss_dict = {}
for idx, loss_func in enumerate(self.loss_func):
loss = loss_func(input, batch, **kargs)
weight = self.loss_weight[idx]
if isinstance(loss, paddle.Tensor):
loss = {"loss_{}_{}".format(str(loss), idx): loss * weight}
else:
loss = {
"{}_{}".format(key, idx): loss[key] * weight
for key in loss
}
loss_dict.update(loss)
loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
return loss_dict
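
The YAML demo in the docstring maps directly onto a Python config list. A minimal usage sketch (tensor shapes and the Student/Teacher keys are illustrative, not mandated by the API):

import paddle
from paddleslim.dygraph.dist.losses import CombinedLoss

loss_cfg_list = [{
    "DistillationDMLLoss": {
        "weight": 1.0,
        "act": "softmax",
        "model_name_pairs": [["Student", "Teacher"]],
    }
}]
combined_loss = CombinedLoss(loss_config_list=loss_cfg_list)

predicts = {
    "Student": paddle.rand([8, 10]),   # e.g. student logits
    "Teacher": paddle.rand([8, 10]),   # e.g. teacher logits
}
loss_dict = combined_loss(predicts, None)
# loss_dict holds the weighted per-pair losses plus their sum under the "loss" key.
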
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import L1Loss
from paddle.nn import MSELoss as L2Loss
from paddle.nn import SmoothL1Loss
__all__ = [
"CELoss",
"DMLLoss",
"DistanceLoss",
"RKdAngle",
"RkdDistance",
]
class CELoss(nn.Layer):
"""
CELoss: cross entropy loss
Args:
epsilon(float | None): label smooth epsilon. If it is None or not in range (0, 1),
then label smooth will not be used.
        label_act(string | None): activation applied to the label when the label is itself
            a logits tensor rather than a ground-truth index.
axis(int): axis used to calculate cross entropy loss.
"""
def __init__(self, epsilon=None, label_act="softmax", axis=-1):
super().__init__()
        if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
            epsilon = None
        assert label_act in ["softmax", None]
self.epsilon = epsilon
self.label_act = label_act
self.axis = axis
def _labelsmoothing(self, target, class_num):
if target.shape[-1] != class_num:
one_hot_target = F.one_hot(target, class_num)
else:
one_hot_target = target
soft_target = F.label_smooth(one_hot_target, epsilon=self.epsilon)
soft_target = paddle.reshape(soft_target, shape=[-1, class_num])
return soft_target
def forward(self, x, label):
assert len(x.shape) == len(label.shape), \
"x and label shape length should be same but got {} for x.shape and {} for label.shape".format(x.shape, label.shape)
if self.epsilon is not None:
class_num = x.shape[-1]
label = self._labelsmoothing(label, class_num)
x = -F.log_softmax(x, axis=self.axis)
loss = paddle.sum(x * label, axis=self.axis)
else:
if label.shape[self.axis] == x.shape[self.axis]:
if self.label_act == "softmax":
label = F.softmax(label, axis=self.axis)
soft_label = True
else:
soft_label = False
loss = F.cross_entropy(
x, label=label, soft_label=soft_label, axis=self.axis)
loss = loss.mean()
return loss
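
A brief sketch of the three CELoss modes described in the docstring (shapes are illustrative; CELoss is assumed importable from paddleslim.dygraph.dist.losses, as in the tests further below):

import paddle
import paddle.nn.functional as F
from paddleslim.dygraph.dist.losses import CELoss

logits = paddle.rand([4, 10])
hard_label = paddle.randint(0, 10, shape=[4, 1])
soft_label = F.softmax(paddle.rand([4, 10]), axis=-1)

loss_hard = CELoss()(logits, hard_label)                 # plain cross entropy on hard labels
loss_soft = CELoss(label_act=None)(logits, soft_label)   # soft labels that are already probabilities
loss_smooth = CELoss(epsilon=0.1)(logits, hard_label)    # one-hot encode, then label smoothing
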
class DMLLoss(nn.Layer):
"""
DMLLoss
Args:
act(string | None): activation function used to activate the input tensor
axis(int): axis used to build activation function
"""
def __init__(self, act=None, axis=-1):
super().__init__()
if act is not None:
assert act in ["softmax", "sigmoid"]
if act == "softmax":
self.act = nn.Softmax(axis=axis)
elif act == "sigmoid":
self.act = nn.Sigmoid()
else:
self.act = None
def forward(self, out1, out2):
if self.act is not None:
out1 = self.act(out1)
out2 = self.act(out2)
log_out1 = paddle.log(out1)
log_out2 = paddle.log(out2)
loss = (F.kl_div(
log_out1, out2, reduction='batchmean') + F.kl_div(
log_out2, out1, reduction='batchmean')) / 2.0
return loss
class DistanceLoss(nn.Layer):
"""
DistanceLoss
Args:
mode: loss mode
kargs(dict): used to build corresponding loss function, for more details, please
refer to:
L1loss: https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/L1Loss_cn.html#l1loss
L2Loss: https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/MSELoss_cn.html#mseloss
SmoothL1Loss: https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/SmoothL1Loss_cn.html#smoothl1loss
"""
def __init__(self, mode="l2", **kargs):
super().__init__()
assert mode in ["l1", "l2", "smooth_l1"]
if mode == "l1":
self.loss_func = nn.L1Loss(**kargs)
elif mode == "l2":
self.loss_func = nn.MSELoss(**kargs)
elif mode == "smooth_l1":
self.loss_func = nn.SmoothL1Loss(**kargs)
def forward(self, x, y):
return self.loss_func(x, y)
def pdist(e, squared=False, eps=1e-12):
e_square = e.pow(2).sum(axis=1)
prod = paddle.mm(e, e.t())
res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clip(
min=eps)
if not squared:
res = res.sqrt()
return res
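
pdist builds the pairwise Euclidean distance matrix via the identity ||a - b||^2 = ||a||^2 + ||b||^2 - 2*a.b, clipping at eps before the square root. An equivalent numpy sketch (shapes illustrative) for cross-checking:

import numpy as np

e = np.random.rand(4, 8).astype("float32")
e_square = np.square(e).sum(axis=1)
res = e_square[:, None] + e_square[None, :] - 2 * np.matmul(e, e.T)
res = np.sqrt(np.clip(res, 1e-12, None))   # res[i, j] ~= ||e[i] - e[j]||_2
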
class RKdAngle(nn.Layer):
"""
RKdAngle loss, see https://arxiv.org/abs/1904.05068
"""
def __init__(self):
super().__init__()
def forward(self, student, teacher):
# reshape for feature map distillation
bs = student.shape[0]
student = student.reshape([bs, -1])
teacher = teacher.reshape([bs, -1])
td = (teacher.unsqueeze(0) - teacher.unsqueeze(1))
norm_td = F.normalize(td, p=2, axis=2)
t_angle = paddle.bmm(norm_td, norm_td.transpose([0, 2, 1])).reshape(
[-1, 1])
sd = (student.unsqueeze(0) - student.unsqueeze(1))
norm_sd = F.normalize(sd, p=2, axis=2)
s_angle = paddle.bmm(norm_sd, norm_sd.transpose([0, 2, 1])).reshape(
[-1, 1])
loss = F.smooth_l1_loss(s_angle, t_angle, reduction='mean')
return loss
class RkdDistance(nn.Layer):
"""
RkdDistance loss, see https://arxiv.org/abs/1904.05068
Args:
eps(float): epsilon for the pdist function
"""
def __init__(self, eps=1e-12):
super().__init__()
self.eps = eps
def forward(self, student, teacher):
bs = student.shape[0]
student = student.reshape([bs, -1])
teacher = teacher.reshape([bs, -1])
t_d = pdist(teacher, squared=False, eps=self.eps)
mean_td = t_d.mean()
t_d = t_d / (mean_td + self.eps)
d = pdist(student, squared=False, eps=self.eps)
mean_d = d.mean()
d = d / (mean_d + self.eps)
loss = F.smooth_l1_loss(d, t_d, reduction="mean")
return loss
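
Because both relational terms flatten their inputs to [batch_size, -1] and only compare batch-wise relations, the student and teacher features do not need the same channel count, only the same batch size. A minimal sketch (shapes illustrative):

import paddle
from paddleslim.dygraph.dist.losses import RkdDistance, RKdAngle

student_feat = paddle.rand([8, 64, 7, 7])    # e.g. a student conv feature map
teacher_feat = paddle.rand([8, 128, 7, 7])   # teacher features; only the batch dim must match

dist_loss = RkdDistance()(student_feat, teacher_feat)
angle_loss = RKdAngle()(student_feat, teacher_feat)
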
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle
import paddle.nn as nn
from .basic_loss import DMLLoss
from .basic_loss import DistanceLoss
from .basic_loss import RkdDistance
from .basic_loss import RKdAngle
__all__ = [
"DistillationDMLLoss",
"DistillationDistanceLoss",
"DistillationRKDLoss",
]
class DistillationDMLLoss(DMLLoss):
"""
DistillationDMLLoss
Args:
model_name_pairs(list | tuple): model name pairs to extract submodel output.
act(string | None): activation function used to build dml loss.
axis(int): axis used to build activation function.
key(string | None): key of the tensor used to calculate loss if the submodel
output type is dict.
name(string): loss name.
"""
def __init__(self, model_name_pairs=[], act=None, key=None,
name="loss_dml"):
super().__init__(act=act)
assert isinstance(model_name_pairs, list)
self.key = key
self.model_name_pairs = model_name_pairs
self.name = name
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]]
out2 = predicts[pair[1]]
if self.key is not None:
out1 = out1[self.key]
out2 = out2[self.key]
loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1],
idx)] = super().forward(out1, out2)
return loss_dict
class DistillationDistanceLoss(DistanceLoss):
"""
DistillationDistanceLoss
Args:
mode: loss mode
model_name_pairs(list | tuple): model name pairs to extract submodel output.
        key(string | None): key of the tensor used to calculate loss if the submodel output type is dict.
name(string): loss name.
kargs(dict): used to build corresponding loss function.
"""
def __init__(self,
mode="l2",
model_name_pairs=[],
key=None,
name="loss_distance",
**kargs):
super().__init__(mode=mode, **kargs)
assert isinstance(model_name_pairs, list)
self.key = key
self.model_name_pairs = model_name_pairs
self.name = name + "_" + mode
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]]
out2 = predicts[pair[1]]
if self.key is not None:
out1 = out1[self.key]
out2 = out2[self.key]
loss = super().forward(out1, out2)
loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1],
idx)] = loss
return loss_dict
class DistillationRKDLoss(nn.Layer):
"""
DistillationRKDLoss
Args:
model_name_pairs(list | tuple): model name pairs to extract submodel output.
        key(string | None): key of the tensor used to calculate loss if the submodel output type is dict.
eps(float): epsilon for the pdist function for RkdDistance loss.
name(string): loss name.
"""
def __init__(self,
model_name_pairs=[],
key=None,
eps=1e-12,
name="loss_rkd"):
super().__init__()
self.model_name_pairs = model_name_pairs
self.key = key
self.rkd_angle_loss_func = RKdAngle()
self.rkd_dist_func = RkdDistance(eps=eps)
self.name = name
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]]
out2 = predicts[pair[1]]
if self.key is not None:
out1 = out1[self.key]
out2 = out2[self.key]
loss_dict["{}_{}_{}_angle_{}".format(self.name, pair[0], pair[
1], idx)] = self.rkd_angle_loss_func(out1, out2)
loss_dict["{}_{}_{}_dist_{}".format(self.name, pair[0], pair[
1], idx)] = self.rkd_dist_func(out1, out2)
return loss_dict
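
Each distillation wrapper returns a dict keyed by loss name, model pair, and pair index (e.g. loss_dml_Student_Teacher_0), which is what CombinedLoss later weights and sums. A minimal sketch of DistillationDMLLoss on dict-style sub-model outputs (the "logits" key and model names are illustrative):

import paddle
from paddleslim.dygraph.dist.losses import DistillationDMLLoss

predicts = {
    "Student": {"logits": paddle.rand([8, 10])},
    "Teacher": {"logits": paddle.rand([8, 10])},
}
loss_func = DistillationDMLLoss(
    model_name_pairs=[["Student", "Teacher"]], act="softmax", key="logits")
loss_dict = loss_func(predicts, None)   # {"loss_dml_Student_Teacher_0": Tensor}
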
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../../")
import copy
import unittest
import paddle
# combined loss
from paddleslim.dygraph.dist.losses import CombinedLoss
# basic loss
from paddleslim.dygraph.dist.losses import DistanceLoss
from paddleslim.dygraph.dist.losses import CELoss
from paddleslim.dygraph.dist.losses import DMLLoss
from paddleslim.dygraph.dist.losses import RkdDistance
from paddleslim.dygraph.dist.losses import RKdAngle
# distillation loss
from paddleslim.dygraph.dist.losses import DistillationDistanceLoss
from paddleslim.dygraph.dist.losses import DistillationRKDLoss
from paddleslim.dygraph.dist.losses import DistillationDMLLoss
import numpy as np
class TestDistanceLoss(unittest.TestCase):
"""TestDistanceLoss
TestDistanceLoss contains:
1. unittest of basic loss
2. unittest of distillation loss
"""
def np_distance_loss(self, x, y, mode="l2", reduction="none"):
assert reduction in ["none", "mean", "sum"]
if isinstance(x, paddle.Tensor):
x = x.numpy()
if isinstance(y, paddle.Tensor):
y = y.numpy()
if mode == "l2":
diff = np.square(x - y)
elif mode == "l1":
diff = np.abs(x - y)
elif mode == "smooth_l1":
diff = np.abs(x - y)
diff_square = 0.5 * np.square(diff)
diff = np.where(diff >= 1, diff - 0.5, diff_square)
if reduction == "none":
out = diff
elif reduction == "mean":
out = np.mean(diff)
elif reduction == "sum":
out = np.sum(diff)
return out
def dist_np_distance_loss(
self,
predicts,
mode="l2",
reduction="none",
model_name_pairs=(["", ""]),
key=None,
name="loss_distance", ):
loss_dict = dict()
for idx, pair in enumerate(model_name_pairs):
out1 = predicts[pair[0]]
out2 = predicts[pair[1]]
if key is not None:
out1 = out1[key]
out2 = out2[key]
loss = self.np_distance_loss(
out1, out2, mode=mode, reduction=reduction)
loss_dict["{}_{}_{}_{}_{}".format(name, mode, pair[0], pair[1],
idx)] = loss
return loss_dict
def test_basic_distance_loss(self):
shape = [10, 20]
x = paddle.rand(shape)
y = paddle.rand(shape)
modes = ["l1", "l2", "smooth_l1"]
reductions = ["none", "mean", "sum"]
devices = ["cpu"]
if paddle.is_compiled_with_cuda():
devices.append("gpu")
for device in devices:
paddle.set_device(device)
for reduction in reductions:
for mode in modes:
np_result = self.np_distance_loss(
x, y, mode=mode, reduction=reduction)
loss_func = DistanceLoss(mode=mode, reduction=reduction)
pd_result = loss_func(x, y).numpy()
self.assertTrue(np.allclose(np_result, pd_result))
def test_distillation_distance_loss(self, ):
shape = [20, 10]
x_feat_name = "student"
y_feat_name = "teacher"
pairs = [[x_feat_name, y_feat_name]]
predicts = {
"student": paddle.rand(shape),
"teacher": paddle.rand(shape),
}
self.calc_distillation_distance_loss(predicts, pairs, key=None)
predicts = {
"student": {
"feat": paddle.rand(shape),
},
"teacher": {
"feat": paddle.rand(shape),
},
}
self.calc_distillation_distance_loss(predicts, pairs, key="feat")
def calc_distillation_distance_loss(self, predicts, pairs, key=None):
modes = ["l1", "l2", "smooth_l1"]
reductions = ["none", "mean", "sum"]
devices = ["cpu"]
if paddle.is_compiled_with_cuda():
devices.append("gpu")
for device in devices:
paddle.set_device(device)
for reduction in reductions:
for mode in modes:
loss_func = DistillationDistanceLoss(
mode=mode,
model_name_pairs=pairs,
key=key,
reduction=reduction)
np_result_dict = self.dist_np_distance_loss(
predicts,
mode=mode,
reduction=reduction,
model_name_pairs=pairs,
key=key)
pd_result_dict = loss_func(predicts, None)
for k in np_result_dict:
pd_result = pd_result_dict[k].numpy()
np_result = np_result_dict[k]
self.assertTrue(np.allclose(np_result, pd_result))
class TestCELoss(unittest.TestCase):
def stable_softmax(self, x):
shiftx = (x - np.max(x)).clip(-64.)
exps = np.exp(shiftx)
return exps / np.sum(exps)
def ref_softmax(self, x, axis=-1, dtype=None):
if isinstance(x, paddle.Tensor):
x = x.numpy()
x_t = x.copy()
if dtype is not None:
x_t = x_t.astype(dtype)
return np.apply_along_axis(self.stable_softmax, axis, x_t)
def log_softmax(self, x, axis=-1):
softmax_out = np.apply_along_axis(self.stable_softmax, axis, x)
return np.log(softmax_out)
def _cross_entropy_soft(self, softmax, label, axis, ignore_index=-1):
return (-label * np.log(softmax)).sum(axis=axis, keepdims=True)
def np_cross_entropy_loss(self,
input,
label,
weight=None,
reduction='mean',
ignore_index=-100):
log_softmax_out = self.log_softmax(input)
input_shape = log_softmax_out.shape
N = input_shape[0]
out = np.zeros_like(label).astype(np.float64)
total_weight = 0
###1. compute softmax cross_entropy (with weight)
### Note: only support hard labels.
for i in range(N):
cur_target = label[i]
if cur_target == ignore_index:
out[i] = 0
continue
cur_weight = weight[cur_target] if weight is not None else 1
total_weight += cur_weight
out[i] = -log_softmax_out[i][cur_target] * cur_weight
###2. deal with reduction
if reduction == 'sum':
return np.sum(out)
elif reduction == 'mean':
out = out.sum() / total_weight if total_weight != 0 else out.sum()
return out
elif reduction == 'none':
return out
def np_cross_entropy_soft(self,
x,
label,
axis=-1,
weight=None,
reduction='mean',
ignore_index=-100):
if isinstance(x, paddle.Tensor):
x = x.numpy()
if isinstance(label, paddle.Tensor):
label = label.numpy()
softmax = self.ref_softmax(x, axis=axis)
#1.loss
loss = self._cross_entropy_soft(softmax, label, axis, ignore_index)
if weight is None and reduction == 'none':
return loss
#2.weight
weighted_loss = loss
total_weight = softmax.shape[0] # batch size
if weight is not None:
weighted_loss = np.zeros_like(loss).astype(np.float64)
total_weight = 0
            for i in range(softmax.shape[0]):
cur_soft_label = label[i]
cur_weight = np.dot(weight, cur_soft_label)
total_weight += cur_weight
weighted_loss[i] = loss[i] * cur_weight
#3.reduce
if reduction == 'none':
return weighted_loss
elif reduction == 'mean':
weighted_loss_sum = np.sum(weighted_loss)
weighted_loss_mean = weighted_loss_sum / total_weight
return weighted_loss_mean
else:
weighted_loss_sum = np.sum(weighted_loss)
return weighted_loss_sum
def test_ce_loss_hard_label(self, ):
batch_size = 16
class_num = 1000
devices = ["cpu"]
if paddle.is_compiled_with_cuda():
devices.append("gpu")
for device in devices:
paddle.set_device(device)
x = paddle.rand([batch_size, class_num])
label = paddle.randint(0, class_num, shape=[batch_size, 1])
loss_func = CELoss()
pd_loss = loss_func(x, label).numpy()
np_loss = self.np_cross_entropy_loss(x, label)
self.assertTrue(np.allclose(np_loss, pd_loss))
def test_ce_loss_soft_label(self, ):
batch_size = 32
class_num = 1000
devices = ["cpu"]
if paddle.is_compiled_with_cuda():
devices.append("gpu")
for device in devices:
paddle.set_device(device)
x = paddle.rand([batch_size, class_num])
label = paddle.rand([batch_size, class_num])
label = paddle.nn.functional.softmax(label, axis=-1)
loss_func = CELoss(label_act=None)
pd_loss = loss_func(x, label).numpy()
np_loss = self.np_cross_entropy_soft(x, label)
self.assertTrue(np.allclose(np_loss, pd_loss))
class TestDMLLoss(unittest.TestCase):
"""TestDMLLoss
TestDMLLoss contains:
1. unittest of basic loss
2. unittest of distillation loss
"""
def stable_softmax(self, x):
shiftx = (x - np.max(x)).clip(-64.)
exps = np.exp(shiftx)
return exps / np.sum(exps)
def ref_softmax(self, x, axis=-1, dtype=None):
if isinstance(x, paddle.Tensor):
x = x.numpy()
x_t = x.copy()
if dtype is not None:
x_t = x_t.astype(dtype)
return np.apply_along_axis(self.stable_softmax, axis, x_t)
def kldiv_loss(self, x, target, reduction="batchmean"):
output = target * (np.log(target) - x)
loss = np.where(target >= 0, output, np.zeros_like(x))
if reduction == "batchmean":
if len(x.shape) > 0:
return loss.sum() / x.shape[0]
else:
return loss.sum()
if reduction == "mean":
return loss.mean()
if reduction == "sum":
return loss.sum()
return loss
def np_dml_loss(self, x, target, act="softmax"):
if isinstance(x, paddle.Tensor):
x = x.numpy()
if isinstance(target, paddle.Tensor):
target = target.numpy()
soft_x = self.ref_softmax(x, axis=-1)
soft_target = self.ref_softmax(target, axis=-1)
log_soft_x = np.log(soft_x)
log_soft_target = np.log(soft_target)
loss = (self.kldiv_loss(log_soft_x, soft_target) + self.kldiv_loss(
log_soft_target, soft_x)) / 2.0
return loss
def test_basic_dml_loss(self, ):
batch_size = 32
class_num = 1000
devices = ["cpu"]
if paddle.is_compiled_with_cuda():
devices.append("gpu")
for device in devices:
paddle.set_device(device)
x = paddle.rand([batch_size, class_num])
target = paddle.rand([batch_size, class_num])
loss_func = DMLLoss(act="softmax")
pd_loss = loss_func(x, target).numpy()
np_loss = self.np_dml_loss(x, target)
self.assertTrue(np.allclose(np_loss, pd_loss))
def dist_np_dml_loss(
self,
predicts,
model_name_pairs=(["", ""]),
key=None,
name="loss_dml", ):
loss_dict = dict()
for idx, pair in enumerate(model_name_pairs):
out1 = predicts[pair[0]]
out2 = predicts[pair[1]]
if key is not None:
out1 = out1[key]
out2 = out2[key]
loss_dict["{}_{}_{}_{}".format(name, pair[0], pair[1],
idx)] = self.np_dml_loss(out1, out2)
return loss_dict
def calc_distillation_dml_loss(self, predicts, pairs, key=None):
devices = ["cpu"]
if paddle.is_compiled_with_cuda():
devices.append("gpu")
for device in devices:
paddle.set_device(device)
loss_func = DistillationDMLLoss(
act="softmax", model_name_pairs=pairs, key=key)
np_result_dict = self.dist_np_dml_loss(
predicts, model_name_pairs=pairs, key=key)
pd_result_dict = loss_func(predicts, None)
for k in np_result_dict:
pd_result = pd_result_dict[k].numpy()
np_result = np_result_dict[k]
self.assertTrue(np.allclose(np_result, pd_result))
def test_distillation_dml_loss(self, ):
shape = [20, 10]
x_feat_name = "student"
y_feat_name = "teacher"
pairs = [[x_feat_name, y_feat_name]]
predicts = {
"student": paddle.rand(shape),
"teacher": paddle.rand(shape),
}
self.calc_distillation_dml_loss(predicts, pairs, key=None)
predicts = {
"student": {
"feat": paddle.rand(shape),
},
"teacher": {
"feat": paddle.rand(shape),
},
}
self.calc_distillation_dml_loss(predicts, pairs, key="feat")
class TestRKDLoss(unittest.TestCase):
def pdist(self, e, squared=False, eps=1e-12):
e_square = np.power(e, 2).sum(axis=1)
prod = np.matmul(e, e.transpose())
res = (
np.expand_dims(e_square, 1) + np.expand_dims(e_square, 0) - 2 * prod
).clip(eps, sys.float_info.max)
if not squared:
res = np.sqrt(res)
return res
def p_normalize(self, x, axis=1, p=2, epsilon=1e-12, keepdims=True):
xp = np.power(np.abs(x), p)
s = np.sum(xp, axis=axis, keepdims=keepdims)
r = np.maximum(np.power(s, 1.0 / p), epsilon)
return x / r
def np_smooth_l1_loss(self, x, y):
diff = np.abs(x - y)
diff_square = 0.5 * np.square(diff)
loss = np.where(diff >= 1, diff - 0.5, diff_square).mean()
return loss
def np_rkd_distance(self, student, teacher, eps=1e-12):
if isinstance(student, paddle.Tensor):
student = student.numpy()
if isinstance(teacher, paddle.Tensor):
teacher = teacher.numpy()
bs = student.shape[0]
student = student.reshape([bs, -1])
teacher = teacher.reshape([bs, -1])
t_d = self.pdist(teacher, squared=False)
mean_td = t_d.mean()
t_d = t_d / (mean_td + eps)
d = self.pdist(student, squared=False)
mean_d = d.mean()
d = d / (mean_d + eps)
loss = self.np_smooth_l1_loss(d, t_d)
return loss
def np_rkd_angle(self, student, teacher):
if isinstance(student, paddle.Tensor):
student = student.numpy()
if isinstance(teacher, paddle.Tensor):
teacher = teacher.numpy()
# reshape for feature map distillation
bs = student.shape[0]
student = student.reshape([bs, -1])
teacher = teacher.reshape([bs, -1])
td = np.expand_dims(teacher, 0) - np.expand_dims(teacher, 1)
norm_td = self.p_normalize(td, axis=2, p=2)
t_angle = np.matmul(norm_td, norm_td.transpose([0, 2, 1])).reshape(
[-1, 1])
sd = np.expand_dims(student, 0) - np.expand_dims(student, 1)
norm_sd = self.p_normalize(sd, axis=2, p=2)
s_angle = np.matmul(norm_sd, norm_sd.transpose([0, 2, 1])).reshape(
[-1, 1])
loss = self.np_smooth_l1_loss(s_angle, t_angle)
return loss
def test_rkd_distance_loss(self, ):
batch_size = 32
feat_dim = 100
devices = ["cpu"]
if paddle.is_compiled_with_cuda():
devices.append("gpu")
for device in devices:
paddle.set_device(device)
paddle.seed(0)
x = paddle.rand([batch_size, feat_dim])
y = paddle.rand([batch_size, feat_dim])
loss_func = RkdDistance()
pd_loss = loss_func(x, y).numpy()
np_loss = self.np_rkd_distance(x, y)
# NOTE: sqrt is included and seed is set for stability
self.assertTrue(
np.allclose(
np_loss, pd_loss, rtol=1e-5, atol=1e-07))
def test_rkd_angle_loss(self, ):
batch_size = 32
feat_dim = 100
devices = ["cpu"]
if paddle.is_compiled_with_cuda():
devices.append("gpu")
for device in devices:
paddle.set_device(device)
paddle.seed(0)
x = paddle.rand([batch_size, feat_dim])
y = paddle.rand([batch_size, feat_dim])
loss_func = RKdAngle()
pd_loss = loss_func(x, y).numpy()
np_loss = self.np_rkd_angle(x, y)
# NOTE: sqrt is included and seed is set for stability
self.assertTrue(np.allclose(np_loss, pd_loss))
def dist_np_rkd_loss(
self,
predicts,
model_name_pairs=(["", ""]),
key=None,
name="loss_rkd", ):
loss_dict = dict()
for idx, pair in enumerate(model_name_pairs):
out1 = predicts[pair[0]]
out2 = predicts[pair[1]]
if key is not None:
out1 = out1[key]
out2 = out2[key]
loss_dict["{}_{}_{}_angle_{}".format(name, pair[0], pair[
1], idx)] = self.np_rkd_angle(out1, out2)
loss_dict["{}_{}_{}_dist_{}".format(name, pair[0], pair[
1], idx)] = self.np_rkd_distance(out1, out2)
return loss_dict
def calc_distillation_rkd_loss(self, predicts, pairs, key=None):
devices = ["cpu"]
if paddle.is_compiled_with_cuda():
devices.append("gpu")
for device in devices:
paddle.set_device(device)
loss_func = DistillationRKDLoss(model_name_pairs=pairs, key=key)
np_result_dict = self.dist_np_rkd_loss(
predicts, model_name_pairs=pairs, key=key)
pd_result_dict = loss_func(predicts, None)
for k in np_result_dict:
pd_result = pd_result_dict[k].numpy()
np_result = np_result_dict[k]
self.assertTrue(np.allclose(np_result, pd_result, rtol=1e-5))
def test_distillation_rkd_loss(self, ):
shape = [32, 16]
x_feat_name = "student"
y_feat_name = "teacher"
pairs = [[x_feat_name, y_feat_name]]
paddle.seed(0)
predicts = {
"student": paddle.rand(shape),
"teacher": paddle.rand(shape),
}
self.calc_distillation_rkd_loss(predicts, pairs, key=None)
predicts = {
"student": {
"feat": paddle.rand(shape),
},
"teacher": {
"feat": paddle.rand(shape),
},
}
self.calc_distillation_rkd_loss(predicts, pairs, key="feat")
class TestCombinedLoss(unittest.TestCase):
def stable_softmax(self, x):
shiftx = (x - np.max(x)).clip(-64.)
exps = np.exp(shiftx)
return exps / np.sum(exps)
def ref_softmax(self, x, axis=-1, dtype=None):
if isinstance(x, paddle.Tensor):
x = x.numpy()
x_t = x.copy()
if dtype is not None:
x_t = x_t.astype(dtype)
return np.apply_along_axis(self.stable_softmax, axis, x_t)
def kldiv_loss(self, x, target, reduction="batchmean"):
output = target * (np.log(target) - x)
loss = np.where(target >= 0, output, np.zeros_like(x))
if reduction == "batchmean":
if len(x.shape) > 0:
return loss.sum() / x.shape[0]
else:
return loss.sum()
if reduction == "mean":
return loss.mean()
if reduction == "sum":
return loss.sum()
return loss
def np_dml_loss(self, x, target, act="softmax"):
if isinstance(x, paddle.Tensor):
x = x.numpy()
if isinstance(target, paddle.Tensor):
target = target.numpy()
soft_x = self.ref_softmax(x, axis=-1)
soft_target = self.ref_softmax(target, axis=-1)
log_soft_x = np.log(soft_x)
log_soft_target = np.log(soft_target)
loss = (self.kldiv_loss(log_soft_x, soft_target) + self.kldiv_loss(
log_soft_target, soft_x)) / 2.0
return loss
def dist_np_dml_loss(
self,
predicts,
model_name_pairs=(["", ""]),
key=None,
act="softmax",
name="loss_dml", ):
loss_dict = dict()
for idx, pair in enumerate(model_name_pairs):
out1 = predicts[pair[0]]
out2 = predicts[pair[1]]
if key is not None:
out1 = out1[key]
out2 = out2[key]
loss_dict["{}_{}_{}_{}".format(name, pair[0], pair[1],
idx)] = self.np_dml_loss(out1, out2)
return loss_dict
def np_combined_loss(self, predicts, loss_cfg_list):
        # NOTE: this reference only supports DistillationDMLLoss entries in the config list
loss_dict = dict()
for idx, loss_func in enumerate(loss_cfg_list):
cfg = copy.deepcopy(loss_func["DistillationDMLLoss"])
weight = cfg.pop("weight")
loss = self.dist_np_dml_loss(predicts, **cfg)
if isinstance(loss, np.ndarray):
loss = {"loss_{}_{}".format(str(loss), idx): loss}
else:
loss = {
"{}_{}".format(key, idx): loss[key] * weight
for key in loss
}
loss_dict.update(loss)
loss_dict["loss"] = np.sum(list(loss_dict.values()))
return loss_dict
def test_combined_loss(self, ):
shape = [32, 16]
x_feat_name = "student"
y_feat_name = "teacher"
pairs = [[x_feat_name, y_feat_name]]
paddle.seed(0)
predicts = {
"student": paddle.rand(shape),
"teacher": paddle.rand(shape),
}
devices = ["cpu"]
if paddle.is_compiled_with_cuda():
devices.append("gpu")
loss_cfg_list = [{
"DistillationDMLLoss": {
"weight": 1.0,
"act": "softmax",
"model_name_pairs": pairs,
"key": None
}
}, ]
for device in devices:
paddle.set_device(device)
loss_func = CombinedLoss(loss_config_list=loss_cfg_list)
pd_result_dict = loss_func(predicts, None)
np_result_dict = self.np_combined_loss(predicts, loss_cfg_list)
for k in pd_result_dict:
pd_result = pd_result_dict[k].numpy()
np_result = np_result_dict[k]
self.assertTrue(np.allclose(np_result, pd_result))
if __name__ == '__main__':
unittest.main()