Unverified · Commit 73c4f2b7 · Authored by whs, committed by GitHub

Fix distillation for soft label. (#16538)

test=develop
Parent 3e6aa498
@@ -19,7 +19,7 @@ from .... import Program
from .... import program_guard
from .... import regularizer

-__all__ = ['FSPDistiller', 'L2Distiller']
+__all__ = ['FSPDistiller', 'L2Distiller', 'SoftLabelDistiller']

class L2Distiller(object):

@@ -186,3 +186,91 @@ class FSPDistillerPass(object):
    def _fsp_matrix(self, fea_map_0, fea_map_1):
        return layers.fsp_matrix(fea_map_0, fea_map_1)
class SoftLabelDistiller(object):
"""
Combine two layers from student net and teacher net by softmax_with_cross_entropy loss.
And add the loss into the total loss using for distillation training.
"""
def __init__(self,
student_feature_map=None,
teacher_feature_map=None,
student_temperature=1.0,
teacher_temperature=1.0,
distillation_loss_weight=1):
"""
Args:
student_feature_map(str): The name of feature map from student network.
teacher_feature_map(str): The name of feature map from teacher network.
It's shape should be the same with student network.
student_temperature(float): Temperature used to divide student_feature_map before softmax_with_cross_entropy. default: 1.0
teacher_temperature(float): Temperature used to divide teacher_feature_map before softmax_with_cross_entropy. default: 1.0
distillation_loss_weight(float): The weight of the l2-loss.
"""
self.student_feature_map = student_feature_map
self.teacher_feature_map = teacher_feature_map
self.distillation_loss_weight = distillation_loss_weight
self.student_temperature = student_temperature
self.teacher_temperature = teacher_temperature
def distiller_loss(self, graph):
"""
Modify graph inplace to add softmax_with_cross_entropy loss.
Args:
graph(GraphWrapper): The graph to be modified.
Returns:
GraphWrapper: The modified graph.
"""
distiller_pass = SoftLabelDistillerPass(
self.student_feature_map, self.teacher_feature_map,
self.student_temperature, self.teacher_temperature,
self.distillation_loss_weight)
dis_graph = distiller_pass.apply(graph)
return dis_graph
class SoftLabelDistillerPass(object):
def __init__(self,
student_feature_map,
teacher_feature_map,
student_temperature,
teacher_temperature,
distillation_loss_weight=1):
"""
Args:
student_feature_map(str): The name of feature map from student network.
teacher_feature_map(str): The name of feature map from teacher network.
It's shape should be the same with student network.
student_temperature(float): Temperature used to divide student_feature_map before softmax_with_cross_entropy.
teacher_temperature(float): Temperature used to divide teacher_feature_map before softmax_with_cross_entropy.
distillation_loss_weight(float): The weight of the l2-loss.
"""
self.student_feature_map = student_feature_map
self.teacher_feature_map = teacher_feature_map
self.student_temperature = student_temperature
self.teacher_temperature = teacher_temperature
self.distillation_loss_weight = distillation_loss_weight
    def apply(self, graph):
        ret_graph = graph
        with program_guard(ret_graph.program):
            student_feature_map = ret_graph.var(self.student_feature_map)._var
            teacher_feature_map = ret_graph.var(self.teacher_feature_map)._var
            # Scale both feature maps by their temperatures before the softmax.
            s_fea = student_feature_map / self.student_temperature
            t_fea = teacher_feature_map / self.teacher_temperature
            # The teacher's output is a fixed soft label; no gradient should
            # flow back into the teacher network.
            t_fea.stop_gradient = True
            ce_loss = layers.softmax_with_cross_entropy(
                s_fea, t_fea, soft_label=True)
            distillation_loss = ce_loss * self.distillation_loss_weight
            # Add the weighted soft-label loss to the student's original loss.
            student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var
            loss = distillation_loss + student_loss

            ret_graph.out_nodes[
                'soft_label_loss_' + self.student_feature_map + "_" +
                self.teacher_feature_map] = distillation_loss.name
            ret_graph.out_nodes['loss'] = loss.name
        return ret_graph
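For reference, here is a minimal NumPy sketch (illustrative, not part of the patch) of the temperature-scaled soft-label loss this pass is intended to add. The function names are hypothetical; the explicit softmax on the teacher side reflects the fact that softmax_with_cross_entropy with soft_label=True expects the label input to be a probability distribution.

```python
import numpy as np

def _softmax(x):
    # Numerically stable softmax along the last axis.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def soft_label_loss(s_logits, t_logits, s_temp=1.0, t_temp=1.0, weight=1.0):
    s_prob = _softmax(s_logits / s_temp)      # student's predicted distribution
    t_prob = _softmax(t_logits / t_temp)      # teacher's soft labels (held fixed)
    ce = -(t_prob * np.log(s_prob)).sum(-1)   # per-sample cross-entropy
    return weight * ce.mean()

# A higher teacher temperature softens the target distribution, exposing the
# relative similarity among the non-argmax classes.
s = np.array([[0.2, 0.1, -0.3]])
t = np.array([[2.0, 0.5, -1.0]])
print(soft_label_loss(s, t, t_temp=2.0, weight=0.001))
```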
@@ -33,10 +33,17 @@ distillers:
        teacher_feature_map: 'teacher.tmp_2'
        student_feature_map: 'student.tmp_2'
        distillation_loss_weight: 1
    soft_label_distiller:
        class: 'SoftLabelDistiller'
        student_temperature: 1.0
        teacher_temperature: 1.0
        teacher_feature_map: 'teacher.tmp_1'
        student_feature_map: 'student.tmp_1'
        distillation_loss_weight: 0.001
strategies:
    distillation_strategy:
        class: 'DistillationStrategy'
-       distillers: ['fsp_distiller', 'l2_distiller']
+       distillers: ['fsp_distiller', 'l2_distiller', 'soft_label_distiller']
        start_epoch: 0
        end_epoch: 1
compressor:
......
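For readers wiring the new distiller up in Python rather than through the yaml, a hedged sketch follows. The import path is inferred from the file's location in contrib/slim and may differ; `graph` stands in for a slim GraphWrapper, built elsewhere, whose out_nodes contain 'loss'.

```python
from paddle.fluid.contrib.slim.distillation.distiller import SoftLabelDistiller

def add_soft_label_loss(graph):
    # Mirrors the soft_label_distiller entry in the yaml above.
    distiller = SoftLabelDistiller(
        student_feature_map='student.tmp_1',
        teacher_feature_map='teacher.tmp_1',
        student_temperature=1.0,
        teacher_temperature=1.0,
        distillation_loss_weight=0.001)
    # distiller_loss() rewrites the graph so that out_nodes['loss'] is the
    # student loss plus the weighted soft-label term, and also exposes the
    # distillation term under a 'soft_label_loss_...' key.
    return distiller.distiller_loss(graph)
```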