From e2c6bada36d51b028f5aa8d8da275c2bf6bb8b56 Mon Sep 17 00:00:00 2001 From: Bai Yifan Date: Wed, 18 Sep 2019 15:08:24 +0800 Subject: [PATCH] Support dispensable student_loss in PaddleSlim distillation (#19824) * support_dispensable_student_loss, test=develop * add distillation test, test=develop * fix distillation test non convergence problem, test=develop * fix test_distillation fail problem, test=develop --- .../distillation/distillation_strategy.py | 3 ++- .../contrib/slim/distillation/distiller.py | 20 ++++++++++++------- .../fluid/contrib/slim/graph/graph_wrapper.py | 2 ++ .../fluid/contrib/slim/tests/CMakeLists.txt | 5 ----- .../slim/tests/distillation/compress.yaml | 8 ++++---- .../contrib/slim/tests/test_graph_wrapper.py | 11 ++++++++++ ....py => test_slim_distillation_strategy.py} | 0 7 files changed, 32 insertions(+), 17 deletions(-) rename python/paddle/fluid/contrib/slim/tests/{test_distillation_strategy.py => test_slim_distillation_strategy.py} (100%) diff --git a/python/paddle/fluid/contrib/slim/distillation/distillation_strategy.py b/python/paddle/fluid/contrib/slim/distillation/distillation_strategy.py index 42389079f8..c54e5dc5b5 100644 --- a/python/paddle/fluid/contrib/slim/distillation/distillation_strategy.py +++ b/python/paddle/fluid/contrib/slim/distillation/distillation_strategy.py @@ -64,7 +64,8 @@ class DistillationStrategy(Strategy): var.stop_gradient = True graph = context.train_graph.clone() graph.merge(teacher) - graph.out_nodes['student_loss'] = graph.out_nodes['loss'] + if 'loss' in graph.out_nodes: + graph.out_nodes['student_loss'] = graph.out_nodes['loss'] # step 2 for distiller in self.distillers: diff --git a/python/paddle/fluid/contrib/slim/distillation/distiller.py b/python/paddle/fluid/contrib/slim/distillation/distiller.py index 3dccfa7e98..eda7954a2f 100644 --- a/python/paddle/fluid/contrib/slim/distillation/distiller.py +++ b/python/paddle/fluid/contrib/slim/distillation/distiller.py @@ -88,13 +88,15 @@ class L2DistillerPass(object): layers.square(student_feature_map - teacher_feature_map)) distillation_loss = l2loss * self.distillation_loss_weight - student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var + student_loss = 0 + if 'loss' in ret_graph.out_nodes: + student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var loss = distillation_loss + student_loss + ret_graph.out_nodes['loss'] = loss.name ret_graph.out_nodes[ 'l2loss_' + self.student_feature_map + "_" + self.teacher_feature_map] = distillation_loss.name - ret_graph.out_nodes['loss'] = loss.name return ret_graph @@ -176,12 +178,14 @@ class FSPDistillerPass(object): losses.append(l2_loss) distillation_loss = layers.sum( losses) * self.distillation_loss_weight - student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var + student_loss = 0 + if 'loss' in ret_graph.out_nodes: + student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var loss = distillation_loss + student_loss + ret_graph.out_nodes['loss'] = loss.name ret_graph.out_nodes[ 'fsp_distillation_loss'] = distillation_loss.name - ret_graph.out_nodes['loss'] = loss.name return ret_graph def _fsp_matrix(self, fea_map_0, fea_map_1): @@ -261,16 +265,18 @@ class SoftLabelDistillerPass(object): student_feature_map = ret_graph.var(self.student_feature_map)._var teacher_feature_map = ret_graph.var(self.teacher_feature_map)._var s_fea = student_feature_map / self.student_temperature - t_fea = teacher_feature_map / self.distillation_loss_weight + t_fea = teacher_feature_map / self.teacher_temperature t_fea.stop_gradient = True ce_loss = layers.softmax_with_cross_entropy( s_fea, t_fea, soft_label=True) distillation_loss = ce_loss * self.distillation_loss_weight - student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var + student_loss = 0 + if 'loss' in ret_graph.out_nodes: + student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var loss = distillation_loss + student_loss + ret_graph.out_nodes['loss'] = loss.name ret_graph.out_nodes[ 'soft_label_loss_' + self.student_feature_map + "_" + self.teacher_feature_map] = distillation_loss.name - ret_graph.out_nodes['loss'] = loss.name return ret_graph diff --git a/python/paddle/fluid/contrib/slim/graph/graph_wrapper.py b/python/paddle/fluid/contrib/slim/graph/graph_wrapper.py index fd248fc6f6..3ed07a287b 100644 --- a/python/paddle/fluid/contrib/slim/graph/graph_wrapper.py +++ b/python/paddle/fluid/contrib/slim/graph/graph_wrapper.py @@ -410,6 +410,8 @@ class GraphWrapper(object): target_name = graph.out_nodes['loss'] elif 'cost' in graph.out_nodes: target_name = graph.out_nodes['cost'] + else: + return None target = graph.var(target_name)._var # The learning rate variable may be created in other program. # Update information in optimizer to make diff --git a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt index ecd717881f..037c6716d9 100644 --- a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt +++ b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt @@ -32,11 +32,6 @@ function(inference_qat_int8_test target model_dir data_dir test_script use_mkldn --acc_diff_threshold 0.1) endfunction() -# NOTE: TODOOOOOOOOOOO -# temporarily disable test_distillation_strategy since it always failed on a specified machine with 4 GPUs -# Need to figure out the root cause and then add it back -list(REMOVE_ITEM TEST_OPS test_distillation_strategy) - if(WIN32) list(REMOVE_ITEM TEST_OPS test_light_nas) endif() diff --git a/python/paddle/fluid/contrib/slim/tests/distillation/compress.yaml b/python/paddle/fluid/contrib/slim/tests/distillation/compress.yaml index 07ccb7a21d..0d3d10b865 100644 --- a/python/paddle/fluid/contrib/slim/tests/distillation/compress.yaml +++ b/python/paddle/fluid/contrib/slim/tests/distillation/compress.yaml @@ -30,15 +30,15 @@ distillers: distillation_loss_weight: 1 l2_distiller: class: 'L2Distiller' - teacher_feature_map: 'teacher.tmp_2' - student_feature_map: 'student.tmp_2' + teacher_feature_map: 'teacher.tmp_1' + student_feature_map: 'student.tmp_1' distillation_loss_weight: 1 soft_label_distiller: class: 'SoftLabelDistiller' student_temperature: 1.0 teacher_temperature: 1.0 - teacher_feature_map: 'teacher.tmp_1' - student_feature_map: 'student.tmp_1' + teacher_feature_map: 'teacher.tmp_2' + student_feature_map: 'student.tmp_2' distillation_loss_weight: 0.001 strategies: distillation_strategy: diff --git a/python/paddle/fluid/contrib/slim/tests/test_graph_wrapper.py b/python/paddle/fluid/contrib/slim/tests/test_graph_wrapper.py index 7d190ce016..5340f36196 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_graph_wrapper.py +++ b/python/paddle/fluid/contrib/slim/tests/test_graph_wrapper.py @@ -139,6 +139,17 @@ class TestGraphWrapper(unittest.TestCase): feed={'image': image, 'label': label}) + def test_get_optimize_graph_without_loss(self): + self.build_program() + self.eval_graph.out_nodes = {} + place = fluid.CPUPlace() + if fluid.core.is_compiled_with_cuda(): + place = fluid.CUDAPlace(0) + opt = fluid.optimizer.SGD(learning_rate=0.001) + train_graph = self.eval_graph.get_optimize_graph( + opt, place, self.scope, no_grad_var_names=['image']) + self.assertEquals(train_graph, None) + def test_flops(self): self.build_program() self.assertEquals(self.train_graph.flops(), 354624) diff --git a/python/paddle/fluid/contrib/slim/tests/test_distillation_strategy.py b/python/paddle/fluid/contrib/slim/tests/test_slim_distillation_strategy.py similarity index 100% rename from python/paddle/fluid/contrib/slim/tests/test_distillation_strategy.py rename to python/paddle/fluid/contrib/slim/tests/test_slim_distillation_strategy.py -- GitLab