From e083f14972353c49d836216446dc3ea9808379b3 Mon Sep 17 00:00:00 2001
From: Bai Yifan
Date: Wed, 16 Oct 2019 21:00:49 +0800
Subject: [PATCH] soft_label_distiller fix, test=release/1.6 (#20657)

---
 .../fluid/contrib/slim/distillation/distiller.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/python/paddle/fluid/contrib/slim/distillation/distiller.py b/python/paddle/fluid/contrib/slim/distillation/distiller.py
index eda7954a2f..f08e0bcfef 100644
--- a/python/paddle/fluid/contrib/slim/distillation/distiller.py
+++ b/python/paddle/fluid/contrib/slim/distillation/distiller.py
@@ -264,11 +264,14 @@ class SoftLabelDistillerPass(object):
         student_feature_map = ret_graph.var(self.student_feature_map)._var
         teacher_feature_map = ret_graph.var(self.teacher_feature_map)._var
-        s_fea = student_feature_map / self.student_temperature
-        t_fea = teacher_feature_map / self.teacher_temperature
+        s_fea = layers.softmax(student_feature_map /
+                               self.student_temperature)
+        t_fea = layers.softmax(teacher_feature_map /
+                               self.teacher_temperature)
         t_fea.stop_gradient = True
-        ce_loss = layers.softmax_with_cross_entropy(
-            s_fea, t_fea, soft_label=True)
+        ce_loss = layers.reduce_mean(
+            layers.cross_entropy(
+                s_fea, t_fea, soft_label=True))
         distillation_loss = ce_loss * self.distillation_loss_weight
         student_loss = 0
         if 'loss' in ret_graph.out_nodes:
             student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var
-- 
GitLab
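
Reading the diff, the fix addresses two apparent problems in SoftLabelDistillerPass: the old code passed raw temperature-scaled logits as the soft label to softmax_with_cross_entropy, so the teacher target was never normalized into a probability distribution, and the per-example cross-entropy was never reduced to a scalar loss. The new code applies softmax to both feature maps and averages with reduce_mean. Below is a minimal standalone sketch of the corrected loss, assuming the Paddle 1.6-era fluid API; the input shapes, temperature values, and variable names are illustrative assumptions, not part of the patch:

    import paddle.fluid as fluid
    import paddle.fluid.layers as layers

    student_temperature = 2.0       # assumed hyperparameter
    teacher_temperature = 2.0       # assumed hyperparameter
    distillation_loss_weight = 1.0  # assumed hyperparameter

    # Hypothetical logit variables standing in for the student and
    # teacher feature maps pulled from the graph in the patch.
    student_logits = fluid.layers.data(
        name='student_logits', shape=[10], dtype='float32')
    teacher_logits = fluid.layers.data(
        name='teacher_logits', shape=[10], dtype='float32')

    # Temperature-scaled softmax turns both sets of logits into
    # probability distributions, mirroring the fixed code.
    s_fea = layers.softmax(student_logits / student_temperature)
    t_fea = layers.softmax(teacher_logits / teacher_temperature)
    t_fea.stop_gradient = True  # gradients must not flow to the teacher

    # cross_entropy with soft_label=True expects its label argument to
    # already be a distribution; reduce_mean collapses the per-example
    # losses into the scalar the optimizer needs.
    ce_loss = layers.reduce_mean(
        layers.cross_entropy(s_fea, t_fea, soft_label=True))
    distillation_loss = ce_loss * distillation_loss_weight

Note that only the student branch receives gradients: stop_gradient on the teacher distribution keeps the distillation signal one-directional, which is the standard soft-label distillation setup.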