From 97cb2d378eef9cc73447edfe23531f45c377ec46 Mon Sep 17 00:00:00 2001 From: LDOUBLEV Date: Mon, 29 Aug 2022 13:56:20 +0800 Subject: [PATCH] fix doc --- .../det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml | 39 ++++++------------- ppocr/losses/basic_loss.py | 14 +++---- 2 files changed, 19 insertions(+), 34 deletions(-) diff --git a/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml b/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml index 3e77577c..f4824b68 100644 --- a/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml +++ b/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml @@ -24,6 +24,7 @@ Architecture: model_type: det Models: Student: + pretrained: model_type: det algorithm: DB Transform: null @@ -40,6 +41,7 @@ Architecture: name: DBHead k: 50 Student2: + pretrained: model_type: det algorithm: DB Transform: null @@ -56,6 +58,7 @@ Architecture: name: DBHead k: 50 Teacher: + pretrained: freeze_params: true return_all_feats: false model_type: det @@ -91,14 +94,11 @@ Loss: - ["Student", "Student2"] maps_name: "thrink_maps" weight: 1.0 - # act: None model_name_pairs: ["Student", "Student2"] key: maps - DistillationDBLoss: weight: 1.0 model_name_list: ["Student", "Student2"] - # key: maps - # name: DBLoss balance_loss: true main_loss_type: DiceLoss alpha: 5 @@ -204,31 +204,16 @@ Eval: label_file_list: - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt transforms: - - DecodeImage: + - DecodeImage: # load image img_mode: BGR - channel_first: false - - DetLabelEncode: null - - DetResizeForTest: null + channel_first: False + - DetLabelEncode: # Class handling label + - DetResizeForTest: - NormalizeImage: scale: 1./255. - mean: - - 0.485 - - 0.456 - - 0.406 - std: - - 0.229 - - 0.224 - - 0.225 - order: hwc - - ToCHWImage: null + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - ToCHWImage: - KeepKeys: - keep_keys: - - image - - shape - - polys - - ignore_tags - loader: - shuffle: false - drop_last: false - batch_size_per_card: 1 - num_workers: 2 + keep_keys: ['image', 'shape', 'polys', 'ignore_tags'] diff --git a/ppocr/losses/basic_loss.py b/ppocr/losses/basic_loss.py index a0ab10fb..028deae6 100644 --- a/ppocr/losses/basic_loss.py +++ b/ppocr/losses/basic_loss.py @@ -60,19 +60,19 @@ class KLJSLoss(object): ], "mode can only be one of ['kl', 'KL', 'js', 'JS']" self.mode = mode - def __call__(self, p1, p2, reduction="mean"): + def __call__(self, p1, p2, reduction="mean", eps=1e-5): if self.mode.lower() == 'kl': loss = paddle.multiply(p2, - paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5)) - loss += paddle.multiply( - p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5)) + paddle.log((p2 + eps) / (p1 + eps) + eps)) + loss += paddle.multiply(p1, + paddle.log((p1 + eps) / (p2 + eps) + eps)) loss *= 0.5 elif self.mode.lower() == "js": loss = paddle.multiply( - p2, paddle.log((2 * p2 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5)) + p2, paddle.log((2 * p2 + eps) / (p1 + p2 + eps) + eps)) loss += paddle.multiply( - p1, paddle.log((2 * p1 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5)) + p1, paddle.log((2 * p1 + eps) / (p1 + p2 + eps) + eps)) loss *= 0.5 else: raise ValueError( @@ -125,7 +125,7 @@ class DMLLoss(nn.Layer): loss = ( self._kldiv(log_out1, out2) + self._kldiv(log_out2, out1)) / 2.0 else: - # for detection distillation log is not needed + # log is not needed for detection loss = self.jskl_loss(out1, out2) return loss -- GitLab