Unverified commit 15ef0c7d, authored by cc, committed by GitHub

Refine distillation for Segmentation (#879)

* refine distillation

* up

* add test

* fix coverage

* fix unit test error
Parent 30fd1248
@@ -16,7 +16,7 @@ import numpy as np
 import collections
 from collections import namedtuple
 import paddle.nn as nn
-from .losses import *
+from . import losses

 __all__ = ['Distill', 'AdaptorBase']

@@ -39,6 +39,8 @@ class LayerConfig:
             self.loss_function = 'DistillationDMLLoss'
         elif loss_function in ['rkl']:
             self.loss_function = 'DistillationRKDLoss'
+        elif hasattr(losses, loss_function):
+            self.loss_function = loss_function
         else:
             raise NotImplementedError("loss function is not support!!!")
         self.weight = weight
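With the new `hasattr(losses, loss_function)` branch, a config may name any loss class exported by the losses package (including the segmentation losses added below), not just the short aliases. A quick sketch of the lookup it performs; the loop and the names are illustrative, not part of the commit:

from paddleslim.dygraph.dist import losses

# names are resolved the same way LayerConfig now does
for name in ['SegPairWiseLoss', 'SegChannelwiseLoss', 'NoSuchLoss']:
    if hasattr(losses, name):
        print(name, '->', getattr(losses, name))  # usable as loss_function
    else:
        print(name, '-> neither an alias nor exported; LayerConfig raises NotImplementedError')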
@@ -59,11 +61,12 @@ class AdaptorBase:
     def _add_distill_hook(self, outs, mapping_layers_name, layers_type):
         """
-        Get output by name.
+        Get output by layer name.
         outs(dict): save the middle outputs of model according to the name.
         mapping_layers(list): name of middle layers.
         layers_type(list): type of the middle layers to calculate distill loss.
         """
         ### TODO: support DP model
         for idx, (n, m) in enumerate(self.model.named_sublayers()):
             if n in mapping_layers_name:
...@@ -80,6 +83,8 @@ class Distill(nn.Layer): ...@@ -80,6 +83,8 @@ class Distill(nn.Layer):
def __init__(self, distill_configs, student_models, teacher_models, def __init__(self, distill_configs, student_models, teacher_models,
adaptors_S, adaptors_T): adaptors_S, adaptors_T):
super(Distill, self).__init__() super(Distill, self).__init__()
assert student_models.training, "The student model should be eval mode."
self._distill_configs = distill_configs self._distill_configs = distill_configs
self._student_models = student_models self._student_models = student_models
self._teacher_models = teacher_models self._teacher_models = teacher_models
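The new assertion pins down the mode contract: the student must be in train mode, while a teacher left in eval mode gets its tensors detached (see forward and _process_outputs below). A minimal setup sketch with stand-in layers:

import paddle.nn as nn

student = nn.Linear(8, 8)   # stand-in for a real student network
teacher = nn.Linear(8, 8)   # stand-in for a real teacher network
student.train()             # required: Distill asserts student_models.training
teacher.eval()              # teacher tensors are then detached from the graph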
@@ -93,6 +98,7 @@ class Distill(nn.Layer):
             self.configs.append(LayerConfig(**c).__dict__)
         self.distill_idx = self._get_distill_idx()
         self._loss_config_list = []
         for c in self.configs:
             loss_config = {}

@@ -105,24 +111,42 @@ class Distill(nn.Layer):
                 loss_config[str(c['loss_function'])][
                     'model_name_pairs'] = [['student', 'teacher']]
             self._loss_config_list.append(loss_config)
-        self._prepare_loss()
+        # use self._loss_config_list to create all loss objects
+        self.distill_loss = losses.CombinedLoss(self._loss_config_list)

+    def _prepare_outputs(self):
+        """
+        Add hooks to get the output tensors of the target layers.
+        Returns:
+            stu_outs_dict(dict): the name and tensor for the student model,
+                such as {'hidden_0': tensor_0, ..}
+            tea_outs_dict(dict): the name and tensor for the teacher model,
+                such as {'hidden_0': tensor_0, ..}
+        """
+        stu_outs_dict = collections.OrderedDict()
+        tea_outs_dict = collections.OrderedDict()
+        stu_outs_dict = self._prepare_hook(self._adaptors_S, stu_outs_dict)
+        tea_outs_dict = self._prepare_hook(self._adaptors_T, tea_outs_dict)
+        return stu_outs_dict, tea_outs_dict

     def _prepare_hook(self, adaptors, outs_dict):
+        """
+        Add hooks.
+        """
         mapping_layers = adaptors.mapping_layers()
         for layer_type, layer in mapping_layers.items():
             if isinstance(layer, str):
                 adaptors._add_distill_hook(outs_dict, [layer], [layer_type])
         return outs_dict

-    def _get_model_intermediate_output(self, adaptors, outs_dict):
-        mapping_layers = adaptors.mapping_layers()
-        for layer_type, layer in mapping_layers.items():
-            if isinstance(layer, str):
-                continue
-            outs_dict[layer_type] = layer
-        return outs_dict

     def _get_distill_idx(self):
+        """
+        For each feature_type, get the feature index in the student and teacher models.
+        Returns:
+            distill_idx(dict): the feature index for each feature_type,
+                such as {'hidden': [[0, 0], [1, 1]], 'out': [[0, 0]]}
+        """
         distill_idx = {}
         for config in self._distill_configs:
             if config['feature_type'] not in distill_idx:
@@ -135,42 +159,13 @@ class Distill(nn.Layer):
                 ])
         return distill_idx

-    def _prepare_loss(self):
-        self.distill_loss = CombinedLoss(self._loss_config_list)

-    def _prepare_outputs(self):
-        stu_outs_dict = collections.OrderedDict()
-        tea_outs_dict = collections.OrderedDict()
-        stu_outs_dict = self._prepare_hook(self._adaptors_S, stu_outs_dict)
-        tea_outs_dict = self._prepare_hook(self._adaptors_T, tea_outs_dict)
-        return stu_outs_dict, tea_outs_dict

-    def _post_outputs(self):
-        final_keys = []
-        for key, value in self.stu_outs_dict.items():
-            if len(key.split('_')) == 1:
-                final_keys.append(key)
-        ### TODO: support list of student models and teacher_models
-        final_distill_dict = {
-            "student": collections.OrderedDict(),
-            "teacher": collections.OrderedDict()
-        }
-        for feature_type, dist_idx in self.distill_idx.items():
-            for idx, idx_list in enumerate(dist_idx):
-                sidx, tidx = idx_list[0], idx_list[1]
-                final_distill_dict['student'][feature_type + '_' + str(
-                    sidx) + '_' + str(tidx)] = self.stu_outs_dict[
-                        feature_type + '_' + str(sidx)]
-                final_distill_dict['teacher'][feature_type + '_' + str(
-                    sidx) + '_' + str(tidx)] = self.tea_outs_dict[
-                        feature_type + '_' + str(tidx)]
-        return final_distill_dict

     def forward(self, *inputs, **kwargs):
         stu_batch_outs = self._student_models.forward(*inputs, **kwargs)
         tea_batch_outs = self._teacher_models.forward(*inputs, **kwargs)
+        if not self._teacher_models.training:
+            tea_batch_outs = [i.detach() for i in tea_batch_outs]
+        # get all target tensors
         if self._adaptors_S.add_tensor == False:
             self._adaptors_S.add_tensor = True
         if self._adaptors_T.add_tensor == False:
@@ -179,8 +174,50 @@ class Distill(nn.Layer):
             self._adaptors_S, self.stu_outs_dict)
         self.tea_outs_dict = self._get_model_intermediate_output(
             self._adaptors_T, self.tea_outs_dict)
-        distill_inputs = self._post_outputs()
+        distill_inputs = self._process_outputs()

         ### batch is None just for now
         distill_outputs = self.distill_loss(distill_inputs, None)
         distill_loss = distill_outputs['loss']
         return stu_batch_outs, tea_batch_outs, distill_loss

+    def _get_model_intermediate_output(self, adaptors, outs_dict):
+        """
+        Use the adaptor to get the target tensors.
+        Returns:
+            outs_dict(dict): the name and tensor for the target model,
+                such as {'hidden_0': tensor_0, ..}
+        """
+        mapping_layers = adaptors.mapping_layers()
+        for layer_type, layer in mapping_layers.items():
+            if isinstance(layer, str):
+                continue
+            outs_dict[layer_type] = layer
+        return outs_dict

+    def _process_outputs(self):
+        """
+        Process the target tensors to adapt them for the loss.
+        """
+        ### TODO: support list of student models and teacher_models
+        final_distill_dict = {
+            "student": collections.OrderedDict(),
+            "teacher": collections.OrderedDict()
+        }
+        for feature_type, dist_idx in self.distill_idx.items():
+            for idx, idx_list in enumerate(dist_idx):
+                sidx, tidx = idx_list[0], idx_list[1]
+                stu_out = self.stu_outs_dict[feature_type + '_' + str(sidx)]
+                tea_out = self.tea_outs_dict[feature_type + '_' + str(tidx)]
+                if not self._student_models.training:
+                    stu_out = stu_out.detach()
+                if not self._teacher_models.training:
+                    tea_out = tea_out.detach()
+                name_str = feature_type + '_' + str(sidx) + '_' + str(tidx)
+                final_distill_dict['student'][name_str] = stu_out
+                final_distill_dict['teacher'][name_str] = tea_out
+        return final_distill_dict
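For reference, the key convention that ties _prepare_outputs to _process_outputs: hooked tensors are stored as '<feature_type>_<idx>' and re-keyed per student/teacher pair as '<feature_type>_<sidx>_<tidx>'. A self-contained sketch using the distill_idx example from the docstring above:

# Pure-Python sketch of the key naming; values are placeholders.
distill_idx = {'hidden': [[0, 0], [1, 1]], 'out': [[0, 0]]}

for feature_type, pairs in distill_idx.items():
    for sidx, tidx in pairs:
        pair_key = '{}_{}_{}'.format(feature_type, sidx, tidx)
        print('{:10s} <- student {}, teacher {}'.format(
            pair_key,
            feature_type + '_' + str(sidx),
            feature_type + '_' + str(tidx)))
# hidden_0_0 <- student hidden_0, teacher hidden_0
# hidden_1_1 <- student hidden_1, teacher hidden_1
# out_0_0    <- student out_0, teacher out_0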
@@ -30,6 +30,7 @@ from .basic_loss import RKdAngle, RkdDistance
 from .distillation_loss import DistillationDistanceLoss
 from .distillation_loss import DistillationDMLLoss
 from .distillation_loss import DistillationRKDLoss
+from .distillation_loss import SegPairWiseLoss, SegChannelwiseLoss

 class CombinedLoss(nn.Layer):

@@ -44,6 +45,8 @@ class CombinedLoss(nn.Layer):
                    act: "softmax"
                    model_name_pairs:
                    - ["Student", "Teacher"]
+        Another example is {'DistillationDistanceLoss': {'weight': 1.0,
+            'key': 'hidden_0_0', 'model_name_pairs': [['student', 'teacher']]}}
     """

     def __init__(self, loss_config_list=None):

@@ -79,5 +82,8 @@ class CombinedLoss(nn.Layer):
                 for key in loss
             }
             loss_dict.update(loss)
-        loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
+        if loss_dict == {}:
+            loss_dict["loss"] = paddle.to_tensor(0.)
+        else:
+            loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
         return loss_dict
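A hedged usage sketch of CombinedLoss with one of the new segmentation losses. The config shape and the 'weight'/'key' fields follow the docstring example above; the exact weight plumbing inside CombinedLoss is assumed, and the inputs are random placeholders:

import paddle
from paddleslim.dygraph.dist.losses import CombinedLoss

config = [{
    'SegPairWiseLoss': {
        'weight': 1.0,
        'key': 'hidden_0_0',
        'model_name_pairs': [['student', 'teacher']],
    }
}]
loss_fn = CombinedLoss(config)

predicts = {
    'student': {'hidden_0_0': paddle.rand([1, 3, 10, 10])},
    'teacher': {'hidden_0_0': paddle.rand([1, 3, 10, 10])},
}
out = loss_fn(predicts, None)  # batch is None just for now, as in Distill.forward
print(out['loss'])             # the new guard returns paddle.to_tensor(0.) if no sub-loss fired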
@@ -21,11 +21,7 @@ from paddle.nn import MSELoss as L2Loss
 from paddle.nn import SmoothL1Loss

 __all__ = [
-    "CELoss",
-    "DMLLoss",
-    "DistanceLoss",
-    "RKdAngle",
-    "RkdDistance",
+    "CELoss", "DMLLoss", "DistanceLoss", "RKdAngle", "RkdDistance", "KLLoss"
 ]

@@ -114,6 +110,49 @@ class DMLLoss(nn.Layer):
         return loss

+class KLLoss(nn.Layer):
+    """
+    KLLoss.
+    Args:
+        act(string | None): activation function used for the input and label tensor.
+            It supports None, softmax and sigmoid. Default: softmax.
+        axis(int): the axis for the act. Default: -1.
+        reduction(str): the reduction param for F.kl_div. Default: mean.
+    """
+
+    def __init__(self, act='softmax', axis=-1, reduction='mean'):
+        super().__init__()
+        assert act in ['softmax', 'sigmoid', None]
+        self.reduction = reduction
+        if act == 'softmax':
+            self.act = nn.Softmax(axis=axis)
+        elif act == 'sigmoid':
+            self.act = nn.Sigmoid()
+        else:
+            self.act = None
+
+    def forward(self, input, label):
+        """
+        Args:
+            input(Tensor): The input tensor.
+            label(Tensor): The label tensor. The shape of label is the same as input.
+        Returns:
+            Tensor: The kl loss.
+        """
+        assert input.shape == label.shape, \
+            "The shape of label should be the same as input."
+        if self.act is not None:
+            input = self.act(input)
+            label = self.act(label)
+        log_input = paddle.log(input)
+        loss = F.kl_div(log_input, label, reduction=self.reduction)
+        return loss

 class DistanceLoss(nn.Layer):
     """
     DistanceLoss
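What KLLoss computes, written out directly; a sketch with random stand-in logits, relying on the fact that paddle's F.kl_div takes log-probabilities as input and probabilities as label:

import paddle
import paddle.nn.functional as F

x = paddle.rand([2, 5])    # stand-in student logits
y = paddle.rand([2, 5])    # stand-in teacher logits
p = F.softmax(x, axis=-1)  # the act='softmax' branch above
q = F.softmax(y, axis=-1)
loss = F.kl_div(paddle.log(p), q, reduction='mean')  # equivalent to KLLoss()(x, y)
print(loss)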
@@ -19,11 +19,14 @@ from .basic_loss import DMLLoss
 from .basic_loss import DistanceLoss
 from .basic_loss import RkdDistance
 from .basic_loss import RKdAngle
+from .basic_loss import KLLoss

 __all__ = [
     "DistillationDMLLoss",
     "DistillationDistanceLoss",
     "DistillationRKDLoss",
+    "SegPairWiseLoss",
+    "SegChannelwiseLoss",
 ]

@@ -66,7 +69,9 @@ class DistillationDistanceLoss(DistanceLoss):
     Args:
         mode: loss mode
         model_name_pairs(list | tuple): model name pairs to extract submodel output.
+            such as [['student', 'teacher']]
         key(string | None): key of the tensor used to calculate loss if the submodel.
+            such as 'hidden_0_0'
         name(string): loss name.
         kargs(dict): used to build corresponding loss function.
     """

@@ -134,3 +139,86 @@ class DistillationRKDLoss(nn.Layer):
             loss_dict["{}_{}_{}_dist_{}".format(self.name, pair[0], pair[
                 1], idx)] = self.rkd_dist_func(out1, out2)
         return loss_dict

+class SegPairWiseLoss(DistanceLoss):
+    """
+    Segmentation pairwise loss, see https://arxiv.org/pdf/1903.04197.pdf
+    Args:
+        model_name_pairs(list | tuple): model name pairs to extract submodel output.
+        key(string): key of the tensor used to calculate loss if the submodel
+            output type is dict.
+        mode(string, optional): loss mode. It supports l1, l2 and smooth_l1. Default: l2.
+        reduction(string, optional): the reduction param for the distance loss. Default: mean.
+        name(string, optional): loss name. Default: seg_pair_wise_loss.
+    """
+
+    def __init__(self,
+                 model_name_pairs=[],
+                 key=None,
+                 mode="l2",
+                 reduction="mean",
+                 name="seg_pair_wise_loss"):
+        super().__init__(mode=mode, reduction=reduction)
+        assert isinstance(model_name_pairs, list)
+        assert key is not None
+        self.key = key
+        self.model_name_pairs = model_name_pairs
+        self.name = name
+        self.pool1 = nn.AdaptiveAvgPool2D(output_size=[2, 2])
+        self.pool2 = nn.AdaptiveAvgPool2D(output_size=[2, 2])
+
+    def forward(self, predicts, batch):
+        loss_dict = dict()
+        for idx, pair in enumerate(self.model_name_pairs):
+            out1 = predicts[pair[0]][self.key]
+            out2 = predicts[pair[1]][self.key]
+            pool1 = self.pool1(out1)
+            pool2 = self.pool2(out2)
+            loss_name = "{}_{}_{}_{}".format(self.name, pair[0], pair[1], idx)
+            loss_dict[loss_name] = super().forward(pool1, pool2)
+        return loss_dict

+class SegChannelwiseLoss(KLLoss):
+    """
+    Segmentation channel-wise loss, see `Channel-wise Distillation for Semantic Segmentation`.
+    Args:
+        model_name_pairs(list | tuple): model name pairs to extract submodel output.
+        key(string): key of the tensor used to calculate loss if the submodel
+            output type is dict.
+        act(string, optional): activation function used for the input and label tensor.
+            Default: softmax.
+        axis(int, optional): the axis for the act. Default: -1.
+        reduction(str, optional): the reduction param for F.kl_div. Default: mean.
+        name(string, optional): loss name. Default: seg_ch_wise_loss.
+    """
+
+    def __init__(self,
+                 model_name_pairs=[],
+                 key=None,
+                 act='softmax',
+                 axis=-1,
+                 reduction="mean",
+                 name="seg_ch_wise_loss"):
+        super().__init__(act, axis, reduction)
+        assert isinstance(model_name_pairs, list)
+        assert key is not None
+        self.model_name_pairs = model_name_pairs
+        self.key = key
+        self.name = name
+
+    def forward(self, predicts, batch):
+        loss_dict = dict()
+        for idx, pair in enumerate(self.model_name_pairs):
+            out1 = predicts[pair[0]][self.key]
+            out2 = predicts[pair[1]][self.key]
+            loss_name = "{}_{}_{}_{}".format(self.name, pair[0], pair[1], idx)
+            loss_dict[loss_name] = super().forward(out1, out2)
+        return loss_dict
@@ -18,6 +18,7 @@ import copy
 import unittest

 import paddle
+import paddle.nn.functional as F

 # basic loss
 from paddleslim.dygraph.dist.losses import CombinedLoss

@@ -33,6 +34,8 @@ from paddleslim.dygraph.dist.losses import RKdAngle
 from paddleslim.dygraph.dist.losses import DistillationDistanceLoss
 from paddleslim.dygraph.dist.losses import DistillationRKDLoss
 from paddleslim.dygraph.dist.losses import DistillationDMLLoss
+from paddleslim.dygraph.dist.losses import SegPairWiseLoss
+from paddleslim.dygraph.dist.losses import SegChannelwiseLoss

 import numpy as np

@@ -693,5 +696,95 @@ class TestCombinedLoss(unittest.TestCase):
         self.assertTrue(np.allclose(np_result, pd_result))

+class TestSegPairWiseLoss(unittest.TestCase):
+    def calculate_gt_loss(self, x, y):
+        pool_x = F.adaptive_avg_pool2d(x, [2, 2])
+        pool_y = F.adaptive_avg_pool2d(y, [2, 2])
+        loss = F.mse_loss(pool_x, pool_y)
+        return loss
+
+    def test_seg_pair_wise_loss(self):
+        shape = [1, 3, 10, 10]
+        x = paddle.rand(shape)
+        y = paddle.rand(shape)
+        model_name_pairs = [['student', 'teacher']]
+        key = 'hidden_0_0'
+        inputs = {
+            model_name_pairs[0][0]: {
+                key: x
+            },
+            model_name_pairs[0][1]: {
+                key: y
+            }
+        }
+
+        devices = ["cpu"]
+        if paddle.is_compiled_with_cuda():
+            devices.append("gpu")
+        for device in devices:
+            paddle.set_device(device)
+            loss_func = SegPairWiseLoss(model_name_pairs, key)
+            pd_loss_dict = loss_func(inputs, None)
+            pd_loss = pd_loss_dict['seg_pair_wise_loss_student_teacher_0']
+            gt_loss = self.calculate_gt_loss(x, y)
+            self.assertTrue(np.allclose(pd_loss.numpy(), gt_loss.numpy()))

+class TestSegChannelWiseLoss(unittest.TestCase):
+    def init(self):
+        self.act_name = None
+        self.act_func = None
+
+    def calculate_gt_loss(self, x, y, act=None):
+        if act is not None:
+            x = act(x)
+            y = act(y)
+        x = paddle.log(x)
+        loss = F.kl_div(x, y)
+        return loss
+
+    def test_seg_channel_wise_loss(self):
+        self.init()
+        shape = [1, 3, 10, 10]
+        x = paddle.rand(shape)
+        y = paddle.rand(shape)
+        model_name_pairs = [['student', 'teacher']]
+        key = 'hidden_0_0'
+        inputs = {
+            model_name_pairs[0][0]: {
+                key: x
+            },
+            model_name_pairs[0][1]: {
+                key: y
+            }
+        }
+
+        devices = ["cpu"]
+        if paddle.is_compiled_with_cuda():
+            devices.append("gpu")
+        for device in devices:
+            paddle.set_device(device)
+            loss_func = SegChannelwiseLoss(model_name_pairs, key,
+                                           self.act_name)
+            pd_loss_dict = loss_func(inputs, None)
+            pd_loss = pd_loss_dict['seg_ch_wise_loss_student_teacher_0']
+            gt_loss = self.calculate_gt_loss(x, y, self.act_func)
+            self.assertTrue(np.allclose(pd_loss.numpy(), gt_loss.numpy()))

+class TestSegChannelWiseLoss1(TestSegChannelWiseLoss):
+    def init(self):
+        self.act_name = "softmax"
+        self.act_func = F.softmax

+class TestSegChannelWiseLoss2(TestSegChannelWiseLoss):
+    def init(self):
+        self.act_name = "sigmoid"
+        self.act_func = F.sigmoid

 if __name__ == '__main__':
     unittest.main()