提交 e2c6bada 编写于 作者: B Bai Yifan 提交者: whs

Support dispensable student_loss in PaddleSlim distillation (#19824)

* support_dispensable_student_loss, test=develop

* add distillation test, test=develop

* fix distillation test non convergence problem, test=develop

* fix test_distillation fail problem, test=develop
上级 3f87464e
......@@ -64,7 +64,8 @@ class DistillationStrategy(Strategy):
var.stop_gradient = True
graph = context.train_graph.clone()
graph.merge(teacher)
graph.out_nodes['student_loss'] = graph.out_nodes['loss']
if 'loss' in graph.out_nodes:
graph.out_nodes['student_loss'] = graph.out_nodes['loss']
# step 2
for distiller in self.distillers:
......
......@@ -88,13 +88,15 @@ class L2DistillerPass(object):
layers.square(student_feature_map - teacher_feature_map))
distillation_loss = l2loss * self.distillation_loss_weight
student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var
student_loss = 0
if 'loss' in ret_graph.out_nodes:
student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var
loss = distillation_loss + student_loss
ret_graph.out_nodes['loss'] = loss.name
ret_graph.out_nodes[
'l2loss_' + self.student_feature_map + "_" +
self.teacher_feature_map] = distillation_loss.name
ret_graph.out_nodes['loss'] = loss.name
return ret_graph
......@@ -176,12 +178,14 @@ class FSPDistillerPass(object):
losses.append(l2_loss)
distillation_loss = layers.sum(
losses) * self.distillation_loss_weight
student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var
student_loss = 0
if 'loss' in ret_graph.out_nodes:
student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var
loss = distillation_loss + student_loss
ret_graph.out_nodes['loss'] = loss.name
ret_graph.out_nodes[
'fsp_distillation_loss'] = distillation_loss.name
ret_graph.out_nodes['loss'] = loss.name
return ret_graph
def _fsp_matrix(self, fea_map_0, fea_map_1):
......@@ -261,16 +265,18 @@ class SoftLabelDistillerPass(object):
student_feature_map = ret_graph.var(self.student_feature_map)._var
teacher_feature_map = ret_graph.var(self.teacher_feature_map)._var
s_fea = student_feature_map / self.student_temperature
t_fea = teacher_feature_map / self.distillation_loss_weight
t_fea = teacher_feature_map / self.teacher_temperature
t_fea.stop_gradient = True
ce_loss = layers.softmax_with_cross_entropy(
s_fea, t_fea, soft_label=True)
distillation_loss = ce_loss * self.distillation_loss_weight
student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var
student_loss = 0
if 'loss' in ret_graph.out_nodes:
student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var
loss = distillation_loss + student_loss
ret_graph.out_nodes['loss'] = loss.name
ret_graph.out_nodes[
'soft_label_loss_' + self.student_feature_map + "_" +
self.teacher_feature_map] = distillation_loss.name
ret_graph.out_nodes['loss'] = loss.name
return ret_graph
......@@ -410,6 +410,8 @@ class GraphWrapper(object):
target_name = graph.out_nodes['loss']
elif 'cost' in graph.out_nodes:
target_name = graph.out_nodes['cost']
else:
return None
target = graph.var(target_name)._var
# The learning rate variable may be created in other program.
# Update information in optimizer to make
......
......@@ -32,11 +32,6 @@ function(inference_qat_int8_test target model_dir data_dir test_script use_mkldn
--acc_diff_threshold 0.1)
endfunction()
# NOTE: TODOOOOOOOOOOO
# temporarily disable test_distillation_strategy since it always failed on a specified machine with 4 GPUs
# Need to figure out the root cause and then add it back
list(REMOVE_ITEM TEST_OPS test_distillation_strategy)
if(WIN32)
list(REMOVE_ITEM TEST_OPS test_light_nas)
endif()
......
......@@ -30,15 +30,15 @@ distillers:
distillation_loss_weight: 1
l2_distiller:
class: 'L2Distiller'
teacher_feature_map: 'teacher.tmp_2'
student_feature_map: 'student.tmp_2'
teacher_feature_map: 'teacher.tmp_1'
student_feature_map: 'student.tmp_1'
distillation_loss_weight: 1
soft_label_distiller:
class: 'SoftLabelDistiller'
student_temperature: 1.0
teacher_temperature: 1.0
teacher_feature_map: 'teacher.tmp_1'
student_feature_map: 'student.tmp_1'
teacher_feature_map: 'teacher.tmp_2'
student_feature_map: 'student.tmp_2'
distillation_loss_weight: 0.001
strategies:
distillation_strategy:
......
......@@ -139,6 +139,17 @@ class TestGraphWrapper(unittest.TestCase):
feed={'image': image,
'label': label})
def test_get_optimize_graph_without_loss(self):
self.build_program()
self.eval_graph.out_nodes = {}
place = fluid.CPUPlace()
if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
opt = fluid.optimizer.SGD(learning_rate=0.001)
train_graph = self.eval_graph.get_optimize_graph(
opt, place, self.scope, no_grad_var_names=['image'])
self.assertEquals(train_graph, None)
def test_flops(self):
self.build_program()
self.assertEquals(self.train_graph.flops(), 354624)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册