Unverified commit 15ef0c7d authored by cc, committed by GitHub

Refine distillation for Segmentation (#879)

* refine distillation

* up

* add test

* fix coverage

* fix unit test error
Parent 30fd1248
@@ -16,7 +16,7 @@ import numpy as np
import collections
from collections import namedtuple
import paddle.nn as nn
from .losses import *
from . import losses
__all__ = ['Distill', 'AdaptorBase']
@@ -39,6 +39,8 @@ class LayerConfig:
self.loss_function = 'DistillationDMLLoss'
elif loss_function in ['rkl']:
self.loss_function = 'DistillationRKDLoss'
elif hasattr(losses, loss_function):
self.loss_function = loss_function
else:
raise NotImplementedError("loss function is not supported!")
self.weight = weight
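
A minimal sketch of how this dispatch behaves, assuming LayerConfig accepts feature_type, s_feature_idx, t_feature_idx, loss_function and weight keyword arguments (the constructor signature is elided by the hunk above):

# Shorthand names resolve to concrete loss classes; any name that exists
# in the losses module, e.g. 'SegPairWiseLoss', now passes through
# unchanged thanks to the hasattr() branch.
cfg_short = LayerConfig(feature_type='hidden', s_feature_idx=0,
                        t_feature_idx=0, loss_function='rkl')
assert cfg_short.loss_function == 'DistillationRKDLoss'

cfg_full = LayerConfig(feature_type='hidden', s_feature_idx=0,
                       t_feature_idx=0, loss_function='SegPairWiseLoss')
assert cfg_full.loss_function == 'SegPairWiseLoss'
# Any other name still raises NotImplementedError.
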
@@ -59,11 +61,12 @@ class AdaptorBase:
def _add_distill_hook(self, outs, mapping_layers_name, layers_type):
"""
Get output by name.
Get output by layer name.
outs(dict): saves the intermediate outputs of the model, keyed by layer name.
mapping_layers_name(list): names of the intermediate layers.
layers_type(list): types of the intermediate layers used to compute the distillation loss.
"""
### TODO: support DP model
for idx, (n, m) in enumerate(self.model.named_sublayers()):
if n in mapping_layers_name:
@@ -80,6 +83,8 @@ class Distill(nn.Layer):
def __init__(self, distill_configs, student_models, teacher_models,
adaptors_S, adaptors_T):
super(Distill, self).__init__()
assert student_models.training, "The student model should be in training mode."
self._distill_configs = distill_configs
self._student_models = student_models
self._teacher_models = teacher_models
@@ -93,6 +98,7 @@
self.configs.append(LayerConfig(**c).__dict__)
self.distill_idx = self._get_distill_idx()
self._loss_config_list = []
for c in self.configs:
loss_config = {}
@@ -105,24 +111,42 @@
loss_config[str(c['loss_function'])][
'model_name_pairs'] = [['student', 'teacher']]
self._loss_config_list.append(loss_config)
self._prepare_loss()
# use self._loss_config_list to create all loss objects
self.distill_loss = losses.CombinedLoss(self._loss_config_list)
def _prepare_outputs(self):
"""
Add hooks to get the output tensors of the target layers.
Returns:
stu_outs_dict(dict): the name and tensor for the student model,
such as {'hidden_0': tensor_0, ..}
tea_outs_dict(dict): the name and tensor for the teacher model,
such as {'hidden_0': tensor_0, ..}
"""
stu_outs_dict = collections.OrderedDict()
tea_outs_dict = collections.OrderedDict()
stu_outs_dict = self._prepare_hook(self._adaptors_S, stu_outs_dict)
tea_outs_dict = self._prepare_hook(self._adaptors_T, tea_outs_dict)
return stu_outs_dict, tea_outs_dict
def _prepare_hook(self, adaptors, outs_dict):
"""
Add hook.
"""
mapping_layers = adaptors.mapping_layers()
for layer_type, layer in mapping_layers.items():
if isinstance(layer, str):
adaptors._add_distill_hook(outs_dict, [layer], [layer_type])
return outs_dict
def _get_model_intermediate_output(self, adaptors, outs_dict):
mapping_layers = adaptors.mapping_layers()
for layer_type, layer in mapping_layers.items():
if isinstance(layer, str):
continue
outs_dict[layer_type] = layer
return outs_dict
def _get_distill_idx(self):
"""
For each feature_type, get the feature index in the student and teacher models.
Returns:
distill_idx(dict): the feature index for each feature_type,
such as {'hidden': [[0, 0], [1, 1]], 'out': [[0, 0]]}
"""
distill_idx = {}
for config in self._distill_configs:
if config['feature_type'] not in distill_idx:
@@ -135,42 +159,13 @@
])
return distill_idx
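
A worked example of the index map this method builds, assuming each config entry carries 's_feature_idx' and 't_feature_idx' keys (the append logic is elided by the hunk above, so the exact key names are an assumption):

# Hypothetical configs and the distill_idx they would produce.
distill_configs = [
    {'feature_type': 'hidden', 's_feature_idx': 0, 't_feature_idx': 0},
    {'feature_type': 'hidden', 's_feature_idx': 1, 't_feature_idx': 1},
    {'feature_type': 'out', 's_feature_idx': 0, 't_feature_idx': 0},
]
# Expected result, matching the docstring above:
# {'hidden': [[0, 0], [1, 1]], 'out': [[0, 0]]}
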
def _prepare_loss(self):
self.distill_loss = CombinedLoss(self._loss_config_list)
def _prepare_outputs(self):
stu_outs_dict = collections.OrderedDict()
tea_outs_dict = collections.OrderedDict()
stu_outs_dict = self._prepare_hook(self._adaptors_S, stu_outs_dict)
tea_outs_dict = self._prepare_hook(self._adaptors_T, tea_outs_dict)
return stu_outs_dict, tea_outs_dict
def _post_outputs(self):
final_keys = []
for key, value in self.stu_outs_dict.items():
if len(key.split('_')) == 1:
final_keys.append(key)
### TODO: support list of student models and teacher_models
final_distill_dict = {
"student": collections.OrderedDict(),
"teacher": collections.OrderedDict()
}
for feature_type, dist_idx in self.distill_idx.items():
for idx, idx_list in enumerate(dist_idx):
sidx, tidx = idx_list[0], idx_list[1]
final_distill_dict['student'][feature_type + '_' + str(
sidx) + '_' + str(tidx)] = self.stu_outs_dict[
feature_type + '_' + str(sidx)]
final_distill_dict['teacher'][feature_type + '_' + str(
sidx) + '_' + str(tidx)] = self.tea_outs_dict[
feature_type + '_' + str(tidx)]
return final_distill_dict
def forward(self, *inputs, **kwargs):
stu_batch_outs = self._student_models.forward(*inputs, **kwargs)
tea_batch_outs = self._teacher_models.forward(*inputs, **kwargs)
if not self._teacher_models.training:
tea_batch_outs = [i.detach() for i in tea_batch_outs]
# get all target tensor
if not self._adaptors_S.add_tensor:
self._adaptors_S.add_tensor = True
if not self._adaptors_T.add_tensor:
@@ -179,8 +174,50 @@
self._adaptors_S, self.stu_outs_dict)
self.tea_outs_dict = self._get_model_intermediate_output(
self._adaptors_T, self.tea_outs_dict)
distill_inputs = self._post_outputs()
distill_inputs = self._process_outputs()
### the batch argument is None for now
distill_outputs = self.distill_loss(distill_inputs, None)
distill_loss = distill_outputs['loss']
return stu_batch_outs, tea_batch_outs, distill_loss
def _get_model_intermediate_output(self, adaptors, outs_dict):
"""
Use the adaptor to get the target tensors.
Returns:
outs_dict(dict): the name and tensor for the target model,
such as {'hidden_0': tensor_0, ..}
"""
mapping_layers = adaptors.mapping_layers()
for layer_type, layer in mapping_layers.items():
if isinstance(layer, str):
continue
outs_dict[layer_type] = layer
return outs_dict
def _process_outputs(self):
"""
Process the target tensors into the inputs expected by the distillation loss.
"""
### TODO: support list of student models and teacher_models
final_distill_dict = {
"student": collections.OrderedDict(),
"teacher": collections.OrderedDict()
}
for feature_type, dist_idx in self.distill_idx.items():
for idx, idx_list in enumerate(dist_idx):
sidx, tidx = idx_list[0], idx_list[1]
stu_out = self.stu_outs_dict[feature_type + '_' + str(sidx)]
tea_out = self.tea_outs_dict[feature_type + '_' + str(tidx)]
if not self._student_models.training:
stu_out = stu_out.detach()
if not self._teacher_models.training:
tea_out = tea_out.detach()
name_str = feature_type + '_' + str(sidx) + '_' + str(tidx)
final_distill_dict['student'][name_str] = stu_out
final_distill_dict['teacher'][name_str] = tea_out
return final_distill_dict
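
A standalone sketch of the re-keying above: a student feature at index sidx and a teacher feature at index tidx end up under one shared key '<feature_type>_<sidx>_<tidx>', which is what the loss configs' 'key' field addresses. This mirror uses plain strings in place of tensors:

import collections

stu_outs_dict = {'hidden_0': 'stu_feat_0'}
tea_outs_dict = {'hidden_1': 'tea_feat_1'}
distill_idx = {'hidden': [[0, 1]]}

final = {'student': collections.OrderedDict(),
         'teacher': collections.OrderedDict()}
for feature_type, pairs in distill_idx.items():
    for sidx, tidx in pairs:
        name = '{}_{}_{}'.format(feature_type, sidx, tidx)
        final['student'][name] = stu_outs_dict['{}_{}'.format(feature_type, sidx)]
        final['teacher'][name] = tea_outs_dict['{}_{}'.format(feature_type, tidx)]
assert final['student']['hidden_0_1'] == 'stu_feat_0'
assert final['teacher']['hidden_0_1'] == 'tea_feat_1'
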
@@ -30,6 +30,7 @@ from .basic_loss import RKdAngle, RkdDistance
from .distillation_loss import DistillationDistanceLoss
from .distillation_loss import DistillationDMLLoss
from .distillation_loss import DistillationRKDLoss
from .distillation_loss import SegPairWiseLoss, SegChannelwiseLoss
class CombinedLoss(nn.Layer):
@@ -44,6 +45,8 @@ class CombinedLoss(nn.Layer):
act: "softmax"
model_name_pairs:
- ["Student", "Teacher"]
Another example is {'DistillationDistanceLoss': {'weight': 1.0,
'key': 'hidden_0_0', 'model_name_pairs': [['student', 'teacher']]}}.
"""
def __init__(self, loss_config_list=None):
@@ -79,5 +82,8 @@
for key in loss
}
loss_dict.update(loss)
loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
if loss_dict == {}:
loss_dict["loss"] = paddle.to_tensor(0.)
else:
loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
return loss_dict
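
A hedged usage sketch of CombinedLoss with the dict-style config from the docstring above, assuming 'weight' is consumed by CombinedLoss itself and that DistillationDistanceLoss defaults to an l2 distance (both elided here):

import paddle

loss_cfg = [{'DistillationDistanceLoss': {'weight': 1.0, 'key': 'hidden_0_0',
                                          'model_name_pairs': [['student', 'teacher']]}}]
combined = CombinedLoss(loss_cfg)
predicts = {'student': {'hidden_0_0': paddle.rand([1, 8])},
            'teacher': {'hidden_0_0': paddle.rand([1, 8])}}
out = combined(predicts, None)  # the batch argument is unused here, as in Distill
print(float(out['loss']))       # weighted sum of all component losses
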
@@ -21,11 +21,7 @@ from paddle.nn import MSELoss as L2Loss
from paddle.nn import SmoothL1Loss
__all__ = [
"CELoss",
"DMLLoss",
"DistanceLoss",
"RKdAngle",
"RkdDistance",
"CELoss", "DMLLoss", "DistanceLoss", "RKdAngle", "RkdDistance", "KLLoss"
]
@@ -114,6 +110,49 @@ class DMLLoss(nn.Layer):
return loss
class KLLoss(nn.Layer):
"""
KLLoss.
Args:
act(string | None): activation function used for the input and label tensor.
It supports None, softmax and sigmoid. Default: softmax.
axis(int): the axis for the act. Default: -1.
reduction(str): the reduction params for F.kl_div. Default: mean.
"""
def __init__(self, act='softmax', axis=-1, reduction='mean'):
super().__init__()
assert act in ['softmax', 'sigmoid', None]
self.reduction = reduction
if act == 'softmax':
self.act = nn.Softmax(axis=axis)
elif act == 'sigmoid':
self.act = nn.Sigmoid()
else:
self.act = None
def forward(self, input, label):
"""
Args:
input(Tensor): The input tensor.
label(Tensor): The label tensor. The shape of label is the same as input.
Returns:
Tensor: The kl loss.
"""
assert input.shape == label.shape, \
"The shape of label should be the same as input."
if self.act is not None:
input = self.act(input)
label = self.act(label)
log_input = paddle.log(input)
loss = F.kl_div(log_input, label, reduction=self.reduction)
return loss
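
A numeric sketch of what the forward above computes with act='softmax': both tensors are normalized, and F.kl_div(log(p_s), p_t, reduction='mean') is the element-wise mean of p_t * (log p_t - log p_s):

import paddle
import paddle.nn.functional as F

x, y = paddle.rand([2, 5]), paddle.rand([2, 5])
kl = KLLoss(act='softmax')(x, y)

# Manual reference; note 'mean' averages over every element, not per batch.
p_s, p_t = F.softmax(x, axis=-1), F.softmax(y, axis=-1)
ref = (p_t * (paddle.log(p_t) - paddle.log(p_s))).mean()
print(float(kl), float(ref))  # should agree up to floating-point error
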
class DistanceLoss(nn.Layer):
"""
DistanceLoss
@@ -19,11 +19,14 @@ from .basic_loss import DMLLoss
from .basic_loss import DistanceLoss
from .basic_loss import RkdDistance
from .basic_loss import RKdAngle
from .basic_loss import KLLoss
__all__ = [
"DistillationDMLLoss",
"DistillationDistanceLoss",
"DistillationRKDLoss",
"SegPairWiseLoss",
"SegChannelwiseLoss",
]
@@ -66,7 +69,9 @@ class DistillationDistanceLoss(DistanceLoss):
Args:
mode: loss mode
model_name_pairs(list | tuple): model name pairs to extract submodel output,
such as [['student', 'teacher']].
key(string | None): key of the tensor used to calculate loss if the submodel
output type is dict, such as 'hidden_0_0'.
name(string): loss name.
kargs(dict): used to build corresponding loss function.
"""
@@ -134,3 +139,86 @@ class DistillationRKDLoss(nn.Layer):
loss_dict["{}_{}_{}_dist_{}".format(self.name, pair[0], pair[
1], idx)] = self.rkd_dist_func(out1, out2)
return loss_dict
class SegPairWiseLoss(DistanceLoss):
"""
Segmentation pairwise loss, see https://arxiv.org/pdf/1903.04197.pdf
Args:
model_name_pairs(list | tuple): model name pairs to extract submodel output.
key(string): key of the tensor used to calculate loss if the submodel
output type is dict.
mode(string, optional): loss mode. It supports l1, l2 and smooth_l1. Default: l2.
reduction(string, optional): the reduction param for the distance loss. Default: mean.
name(string, optional): loss name. Default: seg_pair_wise_loss.
"""
def __init__(self,
model_name_pairs=[],
key=None,
mode="l2",
reduction="mean",
name="seg_pair_wise_loss"):
super().__init__(mode=mode, reduction=reduction)
assert isinstance(model_name_pairs, list)
assert key is not None
self.key = key
self.model_name_pairs = model_name_pairs
self.name = name
self.pool1 = nn.AdaptiveAvgPool2D(output_size=[2, 2])
self.pool2 = nn.AdaptiveAvgPool2D(output_size=[2, 2])
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]][self.key]
out2 = predicts[pair[1]][self.key]
pool1 = self.pool1(out1)
pool2 = self.pool2(out2)
loss_name = "{}_{}_{}_{}".format(self.name, pair[0], pair[1], idx)
loss_dict[loss_name] = super().forward(pool1, pool2)
return loss_dict
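
A usage sketch matching the unit test added below: both feature maps are pooled to 2x2 by adaptive average pooling, then compared with the configured distance (l2 by default):

import paddle

loss_fn = SegPairWiseLoss(model_name_pairs=[['student', 'teacher']],
                          key='hidden_0_0')
predicts = {'student': {'hidden_0_0': paddle.rand([1, 3, 10, 10])},
            'teacher': {'hidden_0_0': paddle.rand([1, 3, 10, 10])}}
print(loss_fn(predicts, None)['seg_pair_wise_loss_student_teacher_0'])
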
class SegChannelwiseLoss(KLLoss):
"""
Segmentation channel-wise loss, see `Channel-wise Distillation for Semantic Segmentation`.
Args:
model_name_pairs(list | tuple): model name pairs to extract submodel output.
key(string): key of the tensor used to calculate loss if the submodel
output type is dict.
act(string, optional): activation function used for the input and label tensor.
Default: softmax.
axis(int, optional): the axis for the act. Default: -1.
reduction(str, optional): the reduction params for F.kl_div. Default: mean.
name(string, optional): loss name. Default: seg_ch_wise_loss.
"""
def __init__(self,
model_name_pairs=[],
key=None,
act='softmax',
axis=-1,
reduction="mean",
name="seg_ch_wise_loss"):
super().__init__(act, axis, reduction)
assert isinstance(model_name_pairs, list)
assert key is not None
self.model_name_pairs = model_name_pairs
self.key = key
self.name = name
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]][self.key]
out2 = predicts[pair[1]][self.key]
loss_name = "{}_{}_{}_{}".format(self.name, pair[0], pair[1], idx)
loss_dict[loss_name] = super().forward(out1, out2)
return loss_dict
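
Likewise for the channel-wise loss, a sketch mirroring the test below; with the defaults act='softmax' and axis=-1, the softmax is taken over the last dimension of each feature map:

import paddle

loss_fn = SegChannelwiseLoss(model_name_pairs=[['student', 'teacher']],
                             key='hidden_0_0')
predicts = {'student': {'hidden_0_0': paddle.rand([1, 3, 10, 10])},
            'teacher': {'hidden_0_0': paddle.rand([1, 3, 10, 10])}}
print(loss_fn(predicts, None)['seg_ch_wise_loss_student_teacher_0'])
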
@@ -18,6 +18,7 @@ import copy
import unittest
import paddle
import paddle.nn.functional as F
# basic loss
from paddleslim.dygraph.dist.losses import CombinedLoss
@@ -33,6 +34,8 @@ from paddleslim.dygraph.dist.losses import RKdAngle
from paddleslim.dygraph.dist.losses import DistillationDistanceLoss
from paddleslim.dygraph.dist.losses import DistillationRKDLoss
from paddleslim.dygraph.dist.losses import DistillationDMLLoss
from paddleslim.dygraph.dist.losses import SegPairWiseLoss
from paddleslim.dygraph.dist.losses import SegChannelwiseLoss
import numpy as np
@@ -693,5 +696,95 @@ class TestCombinedLoss(unittest.TestCase):
self.assertTrue(np.allclose(np_result, pd_result))
class TestSegPairWiseLoss(unittest.TestCase):
def calculate_gt_loss(self, x, y):
pool_x = F.adaptive_avg_pool2d(x, [2, 2])
pool_y = F.adaptive_avg_pool2d(y, [2, 2])
loss = F.mse_loss(pool_x, pool_y)
return loss
def test_seg_pair_wise_loss(self):
shape = [1, 3, 10, 10]
x = paddle.rand(shape)
y = paddle.rand(shape)
model_name_pairs = [['student', 'teacher']]
key = 'hidden_0_0'
inputs = {
model_name_pairs[0][0]: {
key: x
},
model_name_pairs[0][1]: {
key: y
}
}
devices = ["cpu"]
if paddle.is_compiled_with_cuda():
devices.append("gpu")
for device in devices:
paddle.set_device(device)
loss_func = SegPairWiseLoss(model_name_pairs, key)
pd_loss_dict = loss_func(inputs, None)
pd_loss = pd_loss_dict['seg_pair_wise_loss_student_teacher_0']
gt_loss = self.calculate_gt_loss(x, y)
self.assertTrue(np.allclose(pd_loss.numpy(), gt_loss.numpy()))
class TestSegChannelWiseLoss(unittest.TestCase):
def init(self):
self.act_name = None
self.act_func = None
def calculate_gt_loss(self, x, y, act=None):
if act is not None:
x = act(x)
y = act(y)
x = paddle.log(x)
loss = F.kl_div(x, y)
return loss
def test_seg_channel_wise_loss(self):
self.init()
shape = [1, 3, 10, 10]
x = paddle.rand(shape)
y = paddle.rand(shape)
model_name_pairs = [['student', 'teacher']]
key = 'hidden_0_0'
inputs = {
model_name_pairs[0][0]: {
key: x
},
model_name_pairs[0][1]: {
key: y
}
}
devices = ["cpu"]
if paddle.is_compiled_with_cuda():
devices.append("gpu")
for device in devices:
paddle.set_device(device)
loss_func = SegChannelwiseLoss(model_name_pairs, key, self.act_name)
pd_loss_dict = loss_func(inputs, None)
pd_loss = pd_loss_dict['seg_ch_wise_loss_student_teacher_0']
gt_loss = self.calculate_gt_loss(x, y, self.act_func)
self.assertTrue(np.allclose(pd_loss.numpy(), gt_loss.numpy()))
class TestSegChannelWiseLoss1(TestSegChannelWiseLoss):
def init(self):
self.act_name = "softmax"
self.act_func = F.softmax
class TestSegChannelWiseLoss2(TestSegChannelWiseLoss):
def init(self):
self.act_name = "sigmoid"
self.act_func = F.sigmoid
if __name__ == '__main__':
unittest.main()