diff --git a/python/paddle/fluid/contrib/slim/distillation/distiller.py b/python/paddle/fluid/contrib/slim/distillation/distiller.py
index eda7954a2f1d8e3364a14b0d6ccb81fcbf5d489f..f08e0bcfefc01c4dc4663843e55faaf869e32cbb 100644
--- a/python/paddle/fluid/contrib/slim/distillation/distiller.py
+++ b/python/paddle/fluid/contrib/slim/distillation/distiller.py
@@ -264,11 +264,14 @@ class SoftLabelDistillerPass(object):
             student_feature_map = ret_graph.var(self.student_feature_map)._var
             teacher_feature_map = ret_graph.var(self.teacher_feature_map)._var
-            s_fea = student_feature_map / self.student_temperature
-            t_fea = teacher_feature_map / self.teacher_temperature
+            s_fea = layers.softmax(student_feature_map /
+                                   self.student_temperature)
+            t_fea = layers.softmax(teacher_feature_map /
+                                   self.teacher_temperature)
             t_fea.stop_gradient = True
-            ce_loss = layers.softmax_with_cross_entropy(
-                s_fea, t_fea, soft_label=True)
+            ce_loss = layers.reduce_mean(
+                layers.cross_entropy(
+                    s_fea, t_fea, soft_label=True))
             distillation_loss = ce_loss * self.distillation_loss_weight
             student_loss = 0
             if 'loss' in ret_graph.out_nodes:
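
For reference, below is a minimal NumPy sketch (not part of the patch; the function name, signature, and default values are hypothetical) of what the rewritten loss amounts to: both feature maps are temperature-scaled and passed through softmax, the teacher distribution is treated as a fixed soft label, and the per-sample cross entropy is averaged and scaled by the distillation weight.

    import numpy as np

    def soft_label_distillation_loss(student_logits, teacher_logits,
                                     student_temperature=1.0,
                                     teacher_temperature=1.0,
                                     loss_weight=1.0):
        """Plain-NumPy sketch of the patched computation: softmax over
        temperature-scaled logits, soft-label cross entropy, then a batch
        mean scaled by the distillation loss weight."""
        def softmax(x):
            e = np.exp(x - x.max(axis=-1, keepdims=True))
            return e / e.sum(axis=-1, keepdims=True)

        s_fea = softmax(student_logits / student_temperature)
        # In the Fluid graph t_fea.stop_gradient = True keeps the teacher
        # branch out of backpropagation; NumPy has no autograd, so this is
        # only noted here.
        t_fea = softmax(teacher_logits / teacher_temperature)
        # cross_entropy(input, label, soft_label=True) computes
        # -sum(label * log(input)) per sample; reduce_mean averages it.
        ce = -np.sum(t_fea * np.log(s_fea + 1e-12), axis=-1)
        return loss_weight * ce.mean()

    # Example usage: a batch of 4 samples over 10 classes.
    student = np.random.randn(4, 10)
    teacher = np.random.randn(4, 10)
    print(soft_label_distillation_loss(student, teacher,
                                       student_temperature=2.0,
                                       teacher_temperature=2.0))

The effect of stop_gradient = True on t_fea in the Fluid graph is that only the student branch receives the distillation gradient, which matches treating the teacher softmax as a constant soft label in the sketch above.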