未验证 提交 ffec9195 编写于 作者: B Bai Yifan 提交者: GitHub

soft_label_distiller fix, test=develop (#20645)

上级 003f369b
...@@ -264,11 +264,14 @@ class SoftLabelDistillerPass(object): ...@@ -264,11 +264,14 @@ class SoftLabelDistillerPass(object):
student_feature_map = ret_graph.var(self.student_feature_map)._var student_feature_map = ret_graph.var(self.student_feature_map)._var
teacher_feature_map = ret_graph.var(self.teacher_feature_map)._var teacher_feature_map = ret_graph.var(self.teacher_feature_map)._var
s_fea = student_feature_map / self.student_temperature s_fea = layers.softmax(student_feature_map /
t_fea = teacher_feature_map / self.teacher_temperature self.student_temperature)
t_fea = layers.softmax(teacher_feature_map /
self.teacher_temperature)
t_fea.stop_gradient = True t_fea.stop_gradient = True
ce_loss = layers.softmax_with_cross_entropy( ce_loss = layres.reduce_mean(
s_fea, t_fea, soft_label=True) layers.cross_entropy(
s_fea, t_fea, soft_label=True))
distillation_loss = ce_loss * self.distillation_loss_weight distillation_loss = ce_loss * self.distillation_loss_weight
student_loss = 0 student_loss = 0
if 'loss' in ret_graph.out_nodes: if 'loss' in ret_graph.out_nodes:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册