From a479ca67a461db9794758a7c6aa2359afc299f26 Mon Sep 17 00:00:00 2001 From: LDOUBLEV Date: Fri, 1 Jul 2022 18:23:07 +0800 Subject: [PATCH] fix kljs loss --- .../det/ch_PP-OCRv2/ch_PP-OCRv2_det_dml.yml | 2 +- .../det/ch_PP-OCRv3/ch_PP-OCRv3_det_dml.yml | 2 +- ppocr/losses/basic_loss.py | 20 ++++++++++++++----- 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_dml.yml b/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_dml.yml index f3ad9664..8b160f63 100644 --- a/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_dml.yml +++ b/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_dml.yml @@ -65,7 +65,7 @@ Loss: - ["Student", "Teacher"] maps_name: "thrink_maps" weight: 1.0 - act: "softmax" + # act: None model_name_pairs: ["Student", "Teacher"] key: maps - DistillationDBLoss: diff --git a/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_dml.yml b/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_dml.yml index 3f30ada1..c85fc4b7 100644 --- a/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_dml.yml +++ b/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_dml.yml @@ -60,7 +60,7 @@ Loss: - ["Student", "Student2"] maps_name: "thrink_maps" weight: 1.0 - act: "softmax" + # act: None model_name_pairs: ["Student", "Student2"] key: maps - DistillationDBLoss: diff --git a/ppocr/losses/basic_loss.py b/ppocr/losses/basic_loss.py index 2df96ea2..a0ab10fb 100644 --- a/ppocr/losses/basic_loss.py +++ b/ppocr/losses/basic_loss.py @@ -57,17 +57,27 @@ class CELoss(nn.Layer): class KLJSLoss(object): def __init__(self, mode='kl'): assert mode in ['kl', 'js', 'KL', 'JS' - ], "mode can only be one of ['kl', 'js', 'KL', 'JS']" + ], "mode can only be one of ['kl', 'KL', 'js', 'JS']" self.mode = mode def __call__(self, p1, p2, reduction="mean"): - loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5)) - - if self.mode.lower() == "js": + if self.mode.lower() == 'kl': + loss = paddle.multiply(p2, + paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5)) loss += paddle.multiply( p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5)) loss *= 0.5 + elif self.mode.lower() == "js": + loss = paddle.multiply( + p2, paddle.log((2 * p2 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5)) + loss += paddle.multiply( + p1, paddle.log((2 * p1 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5)) + loss *= 0.5 + else: + raise ValueError( + "The mode.lower() if KLJSLoss should be one of ['kl', 'js']") + if reduction == "mean": loss = paddle.mean(loss, axis=[1, 2]) elif reduction == "none" or reduction is None: @@ -95,7 +105,7 @@ class DMLLoss(nn.Layer): self.act = None self.use_log = use_log - self.jskl_loss = KLJSLoss(mode="js") + self.jskl_loss = KLJSLoss(mode="kl") def _kldiv(self, x, target): eps = 1.0e-10 -- GitLab