From 5b3334064795a5597b86d9ad9af99de2c5c36ec0 Mon Sep 17 00:00:00 2001
From: littletomatodonkey
Date: Mon, 7 Mar 2022 16:04:30 +0800
Subject: [PATCH] fix kldiv when stop grad is true (#5643)

---
 ppocr/losses/basic_loss.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/ppocr/losses/basic_loss.py b/ppocr/losses/basic_loss.py
index fc64c133..b19ce57d 100644
--- a/ppocr/losses/basic_loss.py
+++ b/ppocr/losses/basic_loss.py
@@ -95,9 +95,15 @@ class DMLLoss(nn.Layer):
             self.act = None
 
         self.use_log = use_log
-
         self.jskl_loss = KLJSLoss(mode="js")
 
+    def _kldiv(self, x, target):
+        eps = 1.0e-10
+        loss = target * (paddle.log(target + eps) - x)
+        # batch mean loss
+        loss = paddle.sum(loss) / loss.shape[0]
+        return loss
+
     def forward(self, out1, out2):
         if self.act is not None:
             out1 = self.act(out1)
@@ -106,9 +112,8 @@ class DMLLoss(nn.Layer):
             # for recognition distillation, log is needed for feature map
             log_out1 = paddle.log(out1)
             log_out2 = paddle.log(out2)
-            loss = (F.kl_div(
-                log_out1, out2, reduction='batchmean') + F.kl_div(
-                    log_out2, out1, reduction='batchmean')) / 2.0
+            loss = (
+                self._kldiv(log_out1, out2) + self._kldiv(log_out2, out1)) / 2.0
         else:
             # for detection distillation log is not needed
             loss = self.jskl_loss(out1, out2)
--
GitLab
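
Note (not part of the patch): the hunks above swap F.kl_div(..., reduction='batchmean')
for a hand-rolled _kldiv helper, presumably to sidestep the behavior of F.kl_div when the
target tensor has stop_gradient set, as the subject line suggests. The sketch below only
illustrates what the new helper computes; the standalone kldiv function, the tensor shapes,
and the eps value are assumptions for illustration and are not code from the repository.

# Illustrative sketch only -- not from the PaddleOCR repository.
# Mirrors the patched _kldiv: KL(target || exp(x)) with an eps-smoothed log
# and a "batchmean" reduction (sum over all elements divided by batch size).
import paddle
import paddle.nn.functional as F


def kldiv(x, target, eps=1.0e-10):
    # x: log-probabilities, target: probabilities (in distillation the target
    # side is typically detached, so no gradient flows through it)
    loss = target * (paddle.log(target + eps) - x)
    return paddle.sum(loss) / loss.shape[0]


p = F.softmax(paddle.rand([4, 10]), axis=-1)
q = F.softmax(paddle.rand([4, 10]), axis=-1)

manual = kldiv(paddle.log(q), p)
builtin = F.kl_div(paddle.log(q), p, reduction='batchmean')
# The two values should agree up to the eps smoothing term.
print(manual.numpy(), builtin.numpy())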