diff --git a/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml b/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml index acf438950a43af3356c7ab0aadf956fdf226814e..0c6ab2a0d1d9733d647dc40a7b182fe201866a78 100644 --- a/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml +++ b/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml @@ -191,7 +191,6 @@ Eval: channel_first: False - DetLabelEncode: # Class handling label - DetResizeForTest: -# image_shape: [736, 1280] - NormalizeImage: scale: 1./255. mean: [0.485, 0.456, 0.406] diff --git a/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml b/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml index ef58befd694e26704c734d7fd072ebc3370c8554..000d95e892cb8e6dcceeb7c22264c28934d1000c 100644 --- a/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml +++ b/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml @@ -24,6 +24,7 @@ Architecture: model_type: det Models: Student: + pretrained: model_type: det algorithm: DB Transform: null @@ -40,6 +41,7 @@ Architecture: name: DBHead k: 50 Student2: + pretrained: model_type: det algorithm: DB Transform: null @@ -91,14 +93,11 @@ Loss: - ["Student", "Student2"] maps_name: "thrink_maps" weight: 1.0 - # act: None model_name_pairs: ["Student", "Student2"] key: maps - DistillationDBLoss: weight: 1.0 model_name_list: ["Student", "Student2"] - # key: maps - # name: DBLoss balance_loss: true main_loss_type: DiceLoss alpha: 5 @@ -197,6 +196,7 @@ Train: drop_last: false batch_size_per_card: 8 num_workers: 4 + Eval: dataset: name: SimpleDataSet @@ -204,31 +204,21 @@ Eval: label_file_list: - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt transforms: - - DecodeImage: - img_mode: BGR - channel_first: false - - DetLabelEncode: null - - DetResizeForTest: null - - NormalizeImage: - scale: 1./255. - mean: - - 0.485 - - 0.456 - - 0.406 - std: - - 0.229 - - 0.224 - - 0.225 - order: hwc - - ToCHWImage: null - - KeepKeys: - keep_keys: - - image - - shape - - polys - - ignore_tags + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - DetLabelEncode: # Class handling label + - DetResizeForTest: + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'shape', 'polys', 'ignore_tags'] loader: - shuffle: false - drop_last: false - batch_size_per_card: 1 - num_workers: 2 + shuffle: False + drop_last: False + batch_size_per_card: 1 # must be 1 + num_workers: 2 \ No newline at end of file diff --git a/ppocr/losses/basic_loss.py b/ppocr/losses/basic_loss.py index da9faa08bc5ca35c5d65f7a7bfbbdd67192f052b..58410b4db2157074c2cb0f7db590c84021e10ace 100644 --- a/ppocr/losses/basic_loss.py +++ b/ppocr/losses/basic_loss.py @@ -60,19 +60,19 @@ class KLJSLoss(object): ], "mode can only be one of ['kl', 'KL', 'js', 'JS']" self.mode = mode - def __call__(self, p1, p2, reduction="mean"): + def __call__(self, p1, p2, reduction="mean", eps=1e-5): if self.mode.lower() == 'kl': loss = paddle.multiply(p2, - paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5)) - loss += paddle.multiply( - p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5)) + paddle.log((p2 + eps) / (p1 + eps) + eps)) + loss += paddle.multiply(p1, + paddle.log((p1 + eps) / (p2 + eps) + eps)) loss *= 0.5 elif self.mode.lower() == "js": loss = paddle.multiply( - p2, paddle.log((2 * p2 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5)) + p2, paddle.log((2 * p2 + eps) / (p1 + p2 + eps) + eps)) loss += paddle.multiply( - p1, paddle.log((2 * p1 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5)) + p1, paddle.log((2 * p1 + eps) / (p1 + p2 + eps) + eps)) loss *= 0.5 else: raise ValueError(