diff --git a/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml b/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml index acf438950a43af3356c7ab0aadf956fdf226814e..6167b6e13f9b75b87890ba008de5dd216b18917e 100644 --- a/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml +++ b/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml @@ -94,14 +94,11 @@ Loss: - ["Student", "Student2"] maps_name: "thrink_maps" weight: 1.0 - # act: None model_name_pairs: ["Student", "Student2"] key: maps - DistillationDBLoss: weight: 1.0 model_name_list: ["Student", "Student2"] - # key: maps - # name: DBLoss balance_loss: true main_loss_type: DiceLoss alpha: 5 @@ -191,7 +188,6 @@ Eval: channel_first: False - DetLabelEncode: # Class handling label - DetResizeForTest: -# image_shape: [736, 1280] - NormalizeImage: scale: 1./255. mean: [0.485, 0.456, 0.406] diff --git a/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml b/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml index ef58befd694e26704c734d7fd072ebc3370c8554..88514e76a501fc9fac887cb170eb870523b31b8e 100644 --- a/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml +++ b/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml @@ -24,6 +24,7 @@ Architecture: model_type: det Models: Student: + pretrained: model_type: det algorithm: DB Transform: null @@ -40,6 +41,7 @@ Architecture: name: DBHead k: 50 Student2: + pretrained: model_type: det algorithm: DB Transform: null @@ -56,6 +58,7 @@ Architecture: name: DBHead k: 50 Teacher: + pretrained: freeze_params: true return_all_feats: false model_type: det @@ -91,14 +94,11 @@ Loss: - ["Student", "Student2"] maps_name: "thrink_maps" weight: 1.0 - # act: None model_name_pairs: ["Student", "Student2"] key: maps - DistillationDBLoss: weight: 1.0 model_name_list: ["Student", "Student2"] - # key: maps - # name: DBLoss balance_loss: true main_loss_type: DiceLoss alpha: 5 @@ -204,31 +204,21 @@ Eval: label_file_list: - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt transforms: - - DecodeImage: - img_mode: BGR - channel_first: false - - DetLabelEncode: null - - DetResizeForTest: null - - NormalizeImage: - scale: 1./255. - mean: - - 0.485 - - 0.456 - - 0.406 - std: - - 0.229 - - 0.224 - - 0.225 - order: hwc - - ToCHWImage: null - - KeepKeys: - keep_keys: - - image - - shape - - polys - - ignore_tags + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - DetLabelEncode: # Class handling label + - DetResizeForTest: + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'shape', 'polys', 'ignore_tags'] loader: - shuffle: false - drop_last: false - batch_size_per_card: 1 + shuffle: False + drop_last: False + batch_size_per_card: 1 # must be 1 num_workers: 2 diff --git a/ppocr/losses/basic_loss.py b/ppocr/losses/basic_loss.py index da9faa08bc5ca35c5d65f7a7bfbbdd67192f052b..a6f0472ecd0cf3f443aeb474ca6dd5487111f8f0 100644 --- a/ppocr/losses/basic_loss.py +++ b/ppocr/losses/basic_loss.py @@ -60,19 +60,19 @@ class KLJSLoss(object): ], "mode can only be one of ['kl', 'KL', 'js', 'JS']" self.mode = mode - def __call__(self, p1, p2, reduction="mean"): + def __call__(self, p1, p2, reduction="mean", eps=1e-5): if self.mode.lower() == 'kl': loss = paddle.multiply(p2, - paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5)) + paddle.log((p2 + eps) / (p1 + eps) + eps)) loss += paddle.multiply( - p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5)) + p1, paddle.log((p1 + eps) / (p2 + eps) + eps)) loss *= 0.5 elif self.mode.lower() == "js": loss = paddle.multiply( - p2, paddle.log((2 * p2 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5)) + p2, paddle.log((2 * p2 + eps) / (p1 + p2 + eps) + eps)) loss += paddle.multiply( - p1, paddle.log((2 * p1 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5)) + p1, paddle.log((2 * p1 + eps) / (p1 + p2 + eps) + eps)) loss *= 0.5 else: raise ValueError( @@ -125,7 +125,7 @@ class DMLLoss(nn.Layer): loss = ( self._kldiv(log_out1, out2) + self._kldiv(log_out2, out1)) / 2.0 else: - # for detection distillation log is not needed + # distillation log is not needed for detection loss = self.jskl_loss(out1, out2) return loss