Commit e2c6bada authored by Bai Yifan, committed by whs

Support dispensable student_loss in PaddleSlim distillation (#19824)

* support_dispensable_student_loss, test=develop

* add distillation test, test=develop

* fix distillation test non-convergence problem, test=develop

* fix test_distillation failure, test=develop
Parent 3f87464e
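The pattern this commit applies in every distiller pass below is to treat the student (ground-truth) loss as optional: each pass checks whether the wrapped graph exposes a 'loss' out node and, when it does not, trains on the distillation loss alone. A minimal sketch of that guard, lifted from the hunks that follow (ret_graph and distillation_loss are assumed to exist exactly as they do inside each pass):

    # Guard introduced by this commit: the student loss is dispensable.
    student_loss = 0
    if 'loss' in ret_graph.out_nodes:
        # A student objective exists, so combine it with the distillation term.
        student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var
    loss = distillation_loss + student_loss
    ret_graph.out_nodes['loss'] = loss.name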
@@ -64,7 +64,8 @@ class DistillationStrategy(Strategy):
             var.stop_gradient = True
         graph = context.train_graph.clone()
         graph.merge(teacher)
-        graph.out_nodes['student_loss'] = graph.out_nodes['loss']
+        if 'loss' in graph.out_nodes:
+            graph.out_nodes['student_loss'] = graph.out_nodes['loss']
         # step 2
         for distiller in self.distillers:
...
@@ -88,13 +88,15 @@ class L2DistillerPass(object):
                 layers.square(student_feature_map - teacher_feature_map))
             distillation_loss = l2loss * self.distillation_loss_weight
-            student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var
+            student_loss = 0
+            if 'loss' in ret_graph.out_nodes:
+                student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var
             loss = distillation_loss + student_loss
+            ret_graph.out_nodes['loss'] = loss.name
             ret_graph.out_nodes[
                 'l2loss_' + self.student_feature_map + "_" +
                 self.teacher_feature_map] = distillation_loss.name
-            ret_graph.out_nodes['loss'] = loss.name
         return ret_graph
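For reference, the L2 distillation term assembled above can be reproduced outside the pass. A standalone, hedged sketch using the same fluid layers; the variable names, shapes and weight are illustrative only, not taken from the test models:

    import paddle.fluid as fluid

    # Illustrative feature maps; in the real pass they come from the merged
    # student/teacher graph via ret_graph.var(...).
    student_fm = fluid.layers.data(name='student_fm', shape=[10], dtype='float32')
    teacher_fm = fluid.layers.data(name='teacher_fm', shape=[10], dtype='float32')
    teacher_fm.stop_gradient = True

    distillation_loss_weight = 1.0
    l2loss = fluid.layers.reduce_mean(
        fluid.layers.square(student_fm - teacher_fm))
    distillation_loss = l2loss * distillation_loss_weight
    # With no student loss in the graph, this term alone is the training objective.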
@@ -176,12 +178,14 @@ class FSPDistillerPass(object):
                 losses.append(l2_loss)
             distillation_loss = layers.sum(
                 losses) * self.distillation_loss_weight
-            student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var
+            student_loss = 0
+            if 'loss' in ret_graph.out_nodes:
+                student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var
             loss = distillation_loss + student_loss
+            ret_graph.out_nodes['loss'] = loss.name
             ret_graph.out_nodes[
                 'fsp_distillation_loss'] = distillation_loss.name
-            ret_graph.out_nodes['loss'] = loss.name
         return ret_graph

     def _fsp_matrix(self, fea_map_0, fea_map_1):
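The hunk ends at the signature of _fsp_matrix, whose body is unchanged and therefore not shown. For context, an FSP (flow of solution procedure) matrix is the channel-wise Gram matrix between two feature maps of the same spatial size; a minimal sketch of such a helper, assuming the fluid.layers.fsp_matrix op available in this Paddle release:

    import paddle.fluid as fluid

    def fsp_matrix(fea_map_0, fea_map_1):
        # [N, C0, H, W] and [N, C1, H, W] feature maps -> [N, C0, C1] Gram matrix.
        return fluid.layers.fsp_matrix(fea_map_0, fea_map_1)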
@@ -261,16 +265,18 @@ class SoftLabelDistillerPass(object):
             student_feature_map = ret_graph.var(self.student_feature_map)._var
             teacher_feature_map = ret_graph.var(self.teacher_feature_map)._var
             s_fea = student_feature_map / self.student_temperature
-            t_fea = teacher_feature_map / self.distillation_loss_weight
+            t_fea = teacher_feature_map / self.teacher_temperature
             t_fea.stop_gradient = True
             ce_loss = layers.softmax_with_cross_entropy(
                 s_fea, t_fea, soft_label=True)
             distillation_loss = ce_loss * self.distillation_loss_weight
-            student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var
+            student_loss = 0
+            if 'loss' in ret_graph.out_nodes:
+                student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var
             loss = distillation_loss + student_loss
+            ret_graph.out_nodes['loss'] = loss.name
             ret_graph.out_nodes[
                 'soft_label_loss_' + self.student_feature_map + "_" +
                 self.teacher_feature_map] = distillation_loss.name
-            ret_graph.out_nodes['loss'] = loss.name
         return ret_graph
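Besides making the student loss optional, this hunk fixes a scaling bug: the teacher feature map was previously divided by distillation_loss_weight instead of teacher_temperature. A standalone, hedged sketch of the corrected soft-label loss; the names, shapes and constants are illustrative, and the teacher map is assumed to already be a probability distribution (e.g. a softmax output):

    import paddle.fluid as fluid

    student_feature_map = fluid.layers.data(name='student_fm', shape=[10], dtype='float32')
    teacher_feature_map = fluid.layers.data(name='teacher_fm', shape=[10], dtype='float32')

    student_temperature, teacher_temperature = 1.0, 1.0
    s_fea = student_feature_map / student_temperature
    t_fea = teacher_feature_map / teacher_temperature  # the corrected divisor
    t_fea.stop_gradient = True
    ce_loss = fluid.layers.softmax_with_cross_entropy(
        s_fea, t_fea, soft_label=True)
    distillation_loss = ce_loss * 0.001  # distillation_loss_weight from the config below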
@@ -410,6 +410,8 @@ class GraphWrapper(object):
             target_name = graph.out_nodes['loss']
         elif 'cost' in graph.out_nodes:
             target_name = graph.out_nodes['cost']
+        else:
+            return None
         target = graph.var(target_name)._var
         # The learning rate variable may be created in other program.
         # Update information in optimizer to make
...
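With this change, get_optimize_graph() now returns None when the graph exposes neither a 'loss' nor a 'cost' out node, so callers are expected to guard for that case. A hedged usage sketch mirroring the new unit test further down; graph, opt, place and scope are assumed to be set up as in that test:

    train_graph = graph.get_optimize_graph(
        opt, place, scope, no_grad_var_names=['image'])
    if train_graph is None:
        # No 'loss' or 'cost' out node, e.g. an eval-only graph: nothing to optimize.
        raise ValueError("graph defines no 'loss' or 'cost' out node")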
@@ -32,11 +32,6 @@ function(inference_qat_int8_test target model_dir data_dir test_script use_mkldn
              --acc_diff_threshold 0.1)
 endfunction()

-# NOTE: TODOOOOOOOOOOO
-# temporarily disable test_distillation_strategy since it always failed on a specified machine with 4 GPUs
-# Need to figure out the root cause and then add it back
-list(REMOVE_ITEM TEST_OPS test_distillation_strategy)
-
 if(WIN32)
     list(REMOVE_ITEM TEST_OPS test_light_nas)
 endif()
...
@@ -30,15 +30,15 @@ distillers:
         distillation_loss_weight: 1
     l2_distiller:
         class: 'L2Distiller'
-        teacher_feature_map: 'teacher.tmp_2'
-        student_feature_map: 'student.tmp_2'
+        teacher_feature_map: 'teacher.tmp_1'
+        student_feature_map: 'student.tmp_1'
         distillation_loss_weight: 1
     soft_label_distiller:
         class: 'SoftLabelDistiller'
         student_temperature: 1.0
         teacher_temperature: 1.0
-        teacher_feature_map: 'teacher.tmp_1'
-        student_feature_map: 'student.tmp_1'
+        teacher_feature_map: 'teacher.tmp_2'
+        student_feature_map: 'student.tmp_2'
         distillation_loss_weight: 0.001
 strategies:
     distillation_strategy:
...
@@ -139,6 +139,17 @@ class TestGraphWrapper(unittest.TestCase):
             feed={'image': image,
                   'label': label})

+    def test_get_optimize_graph_without_loss(self):
+        self.build_program()
+        self.eval_graph.out_nodes = {}
+        place = fluid.CPUPlace()
+        if fluid.core.is_compiled_with_cuda():
+            place = fluid.CUDAPlace(0)
+        opt = fluid.optimizer.SGD(learning_rate=0.001)
+        train_graph = self.eval_graph.get_optimize_graph(
+            opt, place, self.scope, no_grad_var_names=['image'])
+        self.assertEquals(train_graph, None)
+
     def test_flops(self):
         self.build_program()
         self.assertEquals(self.train_graph.flops(), 354624)
...