Unverified · Commit 5b333406 authored by littletomatodonkey, committed by GitHub

fix kldiv when stop grad is true (#5643)

Parent db608932
@@ -95,9 +95,15 @@ class DMLLoss(nn.Layer):
         self.act = None
         self.use_log = use_log
         self.jskl_loss = KLJSLoss(mode="js")
 
+    def _kldiv(self, x, target):
+        eps = 1.0e-10
+        loss = target * (paddle.log(target + eps) - x)
+        # batch mean loss
+        loss = paddle.sum(loss) / loss.shape[0]
+        return loss
+
     def forward(self, out1, out2):
         if self.act is not None:
             out1 = self.act(out1)
@@ -106,9 +112,8 @@ class DMLLoss(nn.Layer):
             # for recognition distillation, log is needed for feature map
             log_out1 = paddle.log(out1)
             log_out2 = paddle.log(out2)
-            loss = (F.kl_div(
-                log_out1, out2, reduction='batchmean') + F.kl_div(
-                log_out2, out1, reduction='batchmean')) / 2.0
+            loss = (
+                self._kldiv(log_out1, out2) + self._kldiv(log_out2, out1)) / 2.0
         else:
             # for detection distillation log is not needed
             loss = self.jskl_loss(out1, out2)
...
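For reference, here is a minimal standalone sketch (not part of the commit) of how the new `_kldiv` helper replaces `F.kl_div(..., reduction='batchmean')` in the symmetric DML loss. The tensor names `out1`/`out2` mirror the diff above; the shapes, random inputs, and the teacher/student framing are illustrative assumptions only.

```python
import paddle
import paddle.nn.functional as F


def _kldiv(x, target):
    # x: log-probabilities, target: probabilities (same shape)
    eps = 1.0e-10
    loss = target * (paddle.log(target + eps) - x)
    # "batchmean"-style reduction: sum over all elements divided by batch size
    return paddle.sum(loss) / loss.shape[0]


# illustrative inputs: batch size 4, 10 classes (assumed shapes)
out1 = F.softmax(paddle.randn([4, 10]), axis=-1)
out2 = F.softmax(paddle.randn([4, 10]), axis=-1)
out2.stop_gradient = True  # e.g. a branch whose gradient is blocked

log_out1 = paddle.log(out1)
log_out2 = paddle.log(out2)

# symmetric loss, mirroring the updated forward() in the diff above
loss = (_kldiv(log_out1, out2) + _kldiv(log_out2, out1)) / 2.0
print(float(loss))
```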