diff --git a/configs/det/ch_ppocr_v2.1/ch_det_lite_train_cml_v2.1.yml b/configs/det/ch_ppocr_v2.1/ch_det_lite_train_cml_v2.1.yml
new file mode 100644
index 0000000000000000000000000000000000000000..dcf0e1f25f8076f8c29fe50413e567301ba644ce
--- /dev/null
+++ b/configs/det/ch_ppocr_v2.1/ch_det_lite_train_cml_v2.1.yml
@@ -0,0 +1,202 @@
+Global:
+  use_gpu: true
+  epoch_num: 1200
+  log_smooth_window: 20
+  print_batch_step: 2
+  save_model_dir: ./output/ch_db_mv3/
+  save_epoch_step: 1200
+  # evaluation is run every 2000 iterations after the 3000th iteration
+  eval_batch_step: [3000, 2000]
+  cal_metric_during_train: False
+  pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
+  checkpoints:
+  save_inference_dir:
+  use_visualdl: False
+  infer_img: doc/imgs_en/img_10.jpg
+  save_res_path: ./output/det_db/predicts_db.txt
+
+Architecture:
+  name: DistillationModel
+  algorithm: Distillation
+  Models:
+    Student:
+      pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
+      freeze_params: false
+      return_all_feats: false
+      model_type: det
+      algorithm: DB
+      Backbone:
+        name: MobileNetV3
+        scale: 0.5
+        model_name: large
+        disable_se: True
+      Neck:
+        name: DBFPN
+        out_channels: 96
+      Head:
+        name: DBHead
+        k: 50
+    Student2:
+      pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
+      freeze_params: false
+      return_all_feats: false
+      model_type: det
+      algorithm: DB
+      Transform:
+      Backbone:
+        name: MobileNetV3
+        scale: 0.5
+        model_name: large
+        disable_se: True
+      Neck:
+        name: DBFPN
+        out_channels: 96
+      Head:
+        name: DBHead
+        k: 50
+    Teacher:
+      pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
+      freeze_params: true
+      return_all_feats: false
+      model_type: det
+      algorithm: DB
+      Transform:
+      Backbone:
+        name: ResNet
+        layers: 18
+      Neck:
+        name: DBFPN
+        out_channels: 256
+      Head:
+        name: DBHead
+        k: 50
+
+Loss:
+  name: CombinedLoss
+  loss_config_list:
+  - DistillationDilaDBLoss:
+      weight: 1.0
+      model_name_pairs:
+      - ["Student", "Teacher"]
+      - ["Student2", "Teacher"]
+      key: maps
+      balance_loss: true
+      main_loss_type: DiceLoss
+      alpha: 5
+      beta: 10
+      ohem_ratio: 3
+  - DistillationDMLLoss:
+      model_name_pairs:
+      - ["Student", "Student2"]
+      maps_name: "shrink_maps"
+      weight: 1.0
+      # act: None
+      key: maps
+  - DistillationDBLoss:
+      weight: 1.0
+      model_name_list: ["Student", "Student2"]
+      # key: maps
+      # name: DBLoss
+      balance_loss: true
+      main_loss_type: DiceLoss
+      alpha: 5
+      beta: 10
+      ohem_ratio: 3
+
+
+Optimizer:
+  name: Adam
+  beta1: 0.9
+  beta2: 0.999
+  lr:
+    name: Cosine
+    learning_rate: 0.001
+    warmup_epoch: 2
+  regularizer:
+    name: 'L2'
+    factor: 0
+
+PostProcess:
+  name: DistillationDBPostProcess
+  model_name: ["Student", "Student2", "Teacher"]
+  # key: maps
+  thresh: 0.3
+  box_thresh: 0.6
+  max_candidates: 1000
+  unclip_ratio: 1.5
+
+Metric:
+  name: DistillationMetric
+  base_metric_name: DetMetric
+  main_indicator: hmean
+  key: "Student"
+
+Train:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/icdar2015/text_localization/
+    label_file_list:
+      - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
+    ratio_list: [1.0]
+    transforms:
+      - DecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - DetLabelEncode: # Class handling label
+      - IaaAugment:
+          augmenter_args:
+            - { 'type': Fliplr, 'args': { 'p': 0.5 } }
+            - { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
+            - { 'type': Resize, 'args': { 'size': [0.5, 3] } }
+      - EastRandomCropData:
+          size: [960, 960]
+          max_tries: 50
+          keep_ratio: true
+      - MakeBorderMap:
+          shrink_ratio: 0.4
+          thresh_min: 0.3
+          thresh_max: 0.7
+      - MakeShrinkMap:
+          shrink_ratio: 0.4
+          min_text_size: 8
+      - NormalizeImage:
+          scale: 1./255.
+          mean: [0.485, 0.456, 0.406]
+          std: [0.229, 0.224, 0.225]
+          order: 'hwc'
+      - ToCHWImage:
+      - KeepKeys:
+          keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
+  loader:
+    shuffle: True
+    drop_last: False
+    batch_size_per_card: 8
+    num_workers: 4
+
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/icdar2015/text_localization/
+    label_file_list:
+      - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
+    transforms:
+      - DecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - DetLabelEncode: # Class handling label
+      - DetResizeForTest:
+#         image_shape: [736, 1280]
+      - NormalizeImage:
+          scale: 1./255.
+          mean: [0.485, 0.456, 0.406]
+          std: [0.229, 0.224, 0.225]
+          order: 'hwc'
+      - ToCHWImage:
+      - KeepKeys:
+          keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
+  loader:
+    shuffle: False
+    drop_last: False
+    batch_size_per_card: 1 # must be 1
+    num_workers: 2
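Note: at the tooling level nothing in this config is CML-specific; assuming the two pretrained models referenced above have been downloaded under ./pretrain_models/, training starts with the standard entry point, e.g. `python3 tools/train.py -c configs/det/ch_ppocr_v2.1/ch_det_lite_train_cml_v2.1.yml` (the same applies to the distill and dml configs below).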
diff --git a/configs/det/ch_ppocr_v2.1/ch_det_lite_train_distill_v2.1.yml b/configs/det/ch_ppocr_v2.1/ch_det_lite_train_distill_v2.1.yml
new file mode 100644
index 0000000000000000000000000000000000000000..1159d71bf94c330e26c3009b38c5c2b4a9c96f52
--- /dev/null
+++ b/configs/det/ch_ppocr_v2.1/ch_det_lite_train_distill_v2.1.yml
@@ -0,0 +1,174 @@
+Global:
+  use_gpu: true
+  epoch_num: 1200
+  log_smooth_window: 20
+  print_batch_step: 2
+  save_model_dir: ./output/ch_db_mv3/
+  save_epoch_step: 1200
+  # evaluation is run every 2000 iterations after the 3000th iteration
+  eval_batch_step: [3000, 2000]
+  cal_metric_during_train: False
+  pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
+  checkpoints:
+  save_inference_dir:
+  use_visualdl: False
+  infer_img: doc/imgs_en/img_10.jpg
+  save_res_path: ./output/det_db/predicts_db.txt
+
+Architecture:
+  name: DistillationModel
+  algorithm: Distillation
+  Models:
+    Student:
+      pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
+      freeze_params: false
+      return_all_feats: false
+      model_type: det
+      algorithm: DB
+      Backbone:
+        name: MobileNetV3
+        scale: 0.5
+        model_name: large
+        disable_se: True
+      Neck:
+        name: DBFPN
+        out_channels: 96
+      Head:
+        name: DBHead
+        k: 50
+    Teacher:
+      pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
+      freeze_params: true
+      return_all_feats: false
+      model_type: det
+      algorithm: DB
+      Transform:
+      Backbone:
+        name: ResNet
+        layers: 18
+      Neck:
+        name: DBFPN
+        out_channels: 256
+      Head:
+        name: DBHead
+        k: 50
+
+Loss:
+  name: CombinedLoss
+  loss_config_list:
+  - DistillationDilaDBLoss:
+      weight: 1.0
+      model_name_pairs:
+      - ["Student", "Teacher"]
+      key: maps
+      balance_loss: true
+      main_loss_type: DiceLoss
+      alpha: 5
+      beta: 10
+      ohem_ratio: 3
+  - DistillationDBLoss:
+      weight: 1.0
+      model_name_list: ["Student", "Teacher"]
+      # key: maps
+      name: DBLoss
+      balance_loss: true
+      main_loss_type: DiceLoss
+      alpha: 5
+      beta: 10
+      ohem_ratio: 3
+
+Optimizer:
+  name: Adam
+  beta1: 0.9
+  beta2: 0.999
+  lr:
+    name: Cosine
+    learning_rate: 0.001
+    warmup_epoch: 2
+  regularizer:
+    name: 'L2'
+    factor: 0
+
+PostProcess:
+  name: DistillationDBPostProcess
+  # only "Student" and "Teacher" exist in this architecture
+  model_name: ["Student"]
+  key: head_out
+  thresh: 0.3
+  box_thresh: 0.6
+  max_candidates: 1000
+  unclip_ratio: 1.5
+
+Metric:
+  name: DistillationMetric
+  base_metric_name: DetMetric
+  main_indicator: hmean
+  key: "Student"
+
+Train:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/icdar2015/text_localization/
+    label_file_list:
+      - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
+    ratio_list: [1.0]
+    transforms:
+      - DecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - DetLabelEncode: # Class handling label
+      - IaaAugment:
+          augmenter_args:
+            - { 'type': Fliplr, 'args': { 'p': 0.5 } }
+            - { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
+            - { 'type': Resize, 'args': { 'size': [0.5, 3] } }
+      - EastRandomCropData:
+          size: [960, 960]
+          max_tries: 50
+          keep_ratio: true
+      - MakeBorderMap:
+          shrink_ratio: 0.4
+          thresh_min: 0.3
+          thresh_max: 0.7
+      - MakeShrinkMap:
+          shrink_ratio: 0.4
+          min_text_size: 8
+      - NormalizeImage:
+          scale: 1./255.
+          mean: [0.485, 0.456, 0.406]
+          std: [0.229, 0.224, 0.225]
+          order: 'hwc'
+      - ToCHWImage:
+      - KeepKeys:
+          keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
+  loader:
+    shuffle: True
+    drop_last: False
+    batch_size_per_card: 8
+    num_workers: 4
+
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/icdar2015/text_localization/
+    label_file_list:
+      - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
+    transforms:
+      - DecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - DetLabelEncode: # Class handling label
+      - DetResizeForTest:
+#         image_shape: [736, 1280]
+      - NormalizeImage:
+          scale: 1./255.
+          mean: [0.485, 0.456, 0.406]
+          std: [0.229, 0.224, 0.225]
+          order: 'hwc'
+      - ToCHWImage:
+      - KeepKeys:
+          keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
+  loader:
+    shuffle: False
+    drop_last: False
+    batch_size_per_card: 1 # must be 1
+    num_workers: 2
"Student" + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/icdar2015/text_localization/ + label_file_list: + - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt + ratio_list: [1.0] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - DetLabelEncode: # Class handling label + - IaaAugment: + augmenter_args: + - { 'type': Fliplr, 'args': { 'p': 0.5 } } + - { 'type': Affine, 'args': { 'rotate': [-10, 10] } } + - { 'type': Resize, 'args': { 'size': [0.5, 3] } } + - EastRandomCropData: + size: [960, 960] + max_tries: 50 + keep_ratio: true + - MakeBorderMap: + shrink_ratio: 0.4 + thresh_min: 0.3 + thresh_max: 0.7 + - MakeShrinkMap: + shrink_ratio: 0.4 + min_text_size: 8 + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list + loader: + shuffle: True + drop_last: False + batch_size_per_card: 8 + num_workers: 4 + +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data/icdar2015/text_localization/ + label_file_list: + - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - DetLabelEncode: # Class handling label + - DetResizeForTest: +# image_shape: [736, 1280] + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'shape', 'polys', 'ignore_tags'] + loader: + shuffle: False + drop_last: False + batch_size_per_card: 1 # must be 1 + num_workers: 2 diff --git a/configs/det/ch_ppocr_v2.1/ch_det_lite_train_dml_v2.1.yml b/configs/det/ch_ppocr_v2.1/ch_det_lite_train_dml_v2.1.yml new file mode 100644 index 0000000000000000000000000000000000000000..7fe2d2e1a065b54d0e2479475f5f67ac5e38a166 --- /dev/null +++ b/configs/det/ch_ppocr_v2.1/ch_det_lite_train_dml_v2.1.yml @@ -0,0 +1,176 @@ +Global: + use_gpu: true + epoch_num: 1200 + log_smooth_window: 20 + print_batch_step: 2 + save_model_dir: ./output/ch_db_mv3/ + save_epoch_step: 1200 + # evaluation is run every 5000 iterations after the 4000th iteration + eval_batch_step: [3000, 2000] + cal_metric_during_train: False + pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_en/img_10.jpg + save_res_path: ./output/det_db/predicts_db.txt + +Architecture: + name: DistillationModel + algorithm: Distillation + Models: + Student: + pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained + freeze_params: false + return_all_feats: false + model_type: det + algorithm: DB + Backbone: + name: MobileNetV3 + scale: 0.5 + model_name: large + disable_se: True + Neck: + name: DBFPN + out_channels: 96 + Head: + name: DBHead + k: 50 + Student2: + pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained + freeze_params: false + return_all_feats: false + model_type: det + algorithm: DB + Transform: + Backbone: + name: MobileNetV3 + scale: 0.5 + model_name: large + disable_se: True + Neck: + name: DBFPN + out_channels: 96 + Head: + name: DBHead + k: 50 + + +Loss: + name: CombinedLoss + loss_config_list: + - DistillationDMLLoss: + model_name_pairs: + - ["Student", "Student2"] + maps_name: "thrink_maps" + weight: 1.0 + act: "softmax" + model_name_pairs: ["Student", "Student2"] + key: maps + - 
+      weight: 1.0
+      model_name_list: ["Student", "Student2"]
+      # key: maps
+      name: DBLoss
+      balance_loss: true
+      main_loss_type: DiceLoss
+      alpha: 5
+      beta: 10
+      ohem_ratio: 3
+
+
+Optimizer:
+  name: Adam
+  beta1: 0.9
+  beta2: 0.999
+  lr:
+    name: Cosine
+    learning_rate: 0.001
+    warmup_epoch: 2
+  regularizer:
+    name: 'L2'
+    factor: 0
+
+PostProcess:
+  name: DistillationDBPostProcess
+  model_name: ["Student", "Student2"]
+  key: head_out
+  thresh: 0.3
+  box_thresh: 0.6
+  max_candidates: 1000
+  unclip_ratio: 1.5
+
+Metric:
+  name: DistillationMetric
+  base_metric_name: DetMetric
+  main_indicator: hmean
+  key: "Student"
+
+Train:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/icdar2015/text_localization/
+    label_file_list:
+      - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
+    ratio_list: [1.0]
+    transforms:
+      - DecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - DetLabelEncode: # Class handling label
+      - IaaAugment:
+          augmenter_args:
+            - { 'type': Fliplr, 'args': { 'p': 0.5 } }
+            - { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
+            - { 'type': Resize, 'args': { 'size': [0.5, 3] } }
+      - EastRandomCropData:
+          size: [960, 960]
+          max_tries: 50
+          keep_ratio: true
+      - MakeBorderMap:
+          shrink_ratio: 0.4
+          thresh_min: 0.3
+          thresh_max: 0.7
+      - MakeShrinkMap:
+          shrink_ratio: 0.4
+          min_text_size: 8
+      - NormalizeImage:
+          scale: 1./255.
+          mean: [0.485, 0.456, 0.406]
+          std: [0.229, 0.224, 0.225]
+          order: 'hwc'
+      - ToCHWImage:
+      - KeepKeys:
+          keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
+  loader:
+    shuffle: True
+    drop_last: False
+    batch_size_per_card: 8
+    num_workers: 4
+
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/icdar2015/text_localization/
+    label_file_list:
+      - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
+    transforms:
+      - DecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - DetLabelEncode: # Class handling label
+      - DetResizeForTest:
+#         image_shape: [736, 1280]
+      - NormalizeImage:
+          scale: 1./255.
+          mean: [0.485, 0.456, 0.406]
+          std: [0.229, 0.224, 0.225]
+          order: 'hwc'
+      - ToCHWImage:
+      - KeepKeys:
+          keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
+  loader:
+    shuffle: False
+    drop_last: False
+    batch_size_per_card: 1 # must be 1
+    num_workers: 2
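For reference, each entry in the Loss sections above is a single-key YAML mapping, so the config-to-class wiring can be inspected directly. A minimal sketch (assumes PyYAML and the repo root as working directory; as of this PR, CombinedLoss pops each entry's `weight` and instantiates the named class from the remaining parameters):

    import yaml

    with open("configs/det/ch_ppocr_v2.1/ch_det_lite_train_dml_v2.1.yml") as f:
        cfg = yaml.safe_load(f)

    # each entry parses as {LossClassName: {param: value, ...}}
    for entry in cfg["Loss"]["loss_config_list"]:
        name, params = list(entry.items())[0]
        print(name, "weight =", params.get("weight", 1.0))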
diff --git a/ppocr/losses/basic_loss.py b/ppocr/losses/basic_loss.py
index fa3ceda1b747aad3c4b275611b1257bf6950f013..8306523ac1a933f0c664fc0b4cf077659cccdee3 100644
--- a/ppocr/losses/basic_loss.py
+++ b/ppocr/losses/basic_loss.py
@@ -54,6 +54,27 @@ class CELoss(nn.Layer):
         return loss
 
 
+class KLJSLoss(object):
+    def __init__(self, mode='kl'):
+        assert mode in ['kl', 'js', 'KL', 'JS'], \
+            "mode can only be one of ['kl', 'js', 'KL', 'JS']"
+        self.mode = mode
+
+    def __call__(self, p1, p2, reduction="mean"):
+        # pointwise KL(p2 || p1); the epsilon terms guard against log(0)
+        loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
+
+        if self.mode.lower() == "js":
+            loss += paddle.multiply(
+                p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
+            loss *= 0.5
+        if reduction == "mean":
+            loss = paddle.mean(loss, axis=[1, 2])
+        elif reduction == "none" or reduction is None:
+            return loss
+        else:
+            loss = paddle.sum(loss, axis=[1, 2])
+
+        return loss
+
+
 class DMLLoss(nn.Layer):
     """
     DMLLoss
@@ -69,17 +90,21 @@ class DMLLoss(nn.Layer):
             self.act = nn.Sigmoid()
         else:
             self.act = None
+
+        self.jskl_loss = KLJSLoss(mode="js")
 
     def forward(self, out1, out2):
         if self.act is not None:
             out1 = self.act(out1)
             out2 = self.act(out2)
-
-        log_out1 = paddle.log(out1)
-        log_out2 = paddle.log(out2)
-        loss = (F.kl_div(
-            log_out1, out2, reduction='batchmean') + F.kl_div(
-                log_out2, out1, reduction='batchmean')) / 2.0
+        if len(out1.shape) < 2:
+            log_out1 = paddle.log(out1)
+            log_out2 = paddle.log(out2)
+            loss = (F.kl_div(
+                log_out1, out2, reduction='batchmean') + F.kl_div(
+                    log_out2, out1, reduction='batchmean')) / 2.0
+        else:
+            # map-like tensors go through the symmetric JS divergence
+            loss = self.jskl_loss(out1, out2)
         return loss
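A quick numeric sanity check of the JS branch above, mirrored in plain NumPy (the function name `kljs` is illustrative, not part of this PR):

    import numpy as np

    def kljs(p1, p2, mode="js"):
        loss = p2 * np.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5)
        if mode == "js":
            loss += p1 * np.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5)
            loss *= 0.5
        return loss.mean(axis=(1, 2))

    p = np.random.rand(2, 4, 4).astype("float32")
    q = np.random.rand(2, 4, 4).astype("float32")
    # JS is symmetric in its arguments, plain KL is not
    assert np.allclose(kljs(p, q), kljs(q, p))
    print(kljs(p, q))  # one value per sample

Swapping the arguments leaves the JS value unchanged, which is the property the mutual-learning pairing relies on.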
diff --git a/ppocr/losses/combined_loss.py b/ppocr/losses/combined_loss.py
index 54da70174cba7bf5ca35e8fbf5aa137a437ae29c..0d6fe968d0d7733200a4cfd21d779196cccaba03 100644
--- a/ppocr/losses/combined_loss.py
+++ b/ppocr/losses/combined_loss.py
@@ -17,7 +17,7 @@ import paddle.nn as nn
 
 from .distillation_loss import DistillationCTCLoss
 from .distillation_loss import DistillationDMLLoss
-from .distillation_loss import DistillationDistanceLoss
+from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss
 
 
 class CombinedLoss(nn.Layer):
@@ -44,15 +44,16 @@ class CombinedLoss(nn.Layer):
 
     def forward(self, input, batch, **kargs):
         loss_dict = {}
+        loss_all = 0.
         for idx, loss_func in enumerate(self.loss_func):
             loss = loss_func(input, batch, **kargs)
             if isinstance(loss, paddle.Tensor):
                 loss = {"loss_{}_{}".format(str(loss), idx): loss}
             weight = self.loss_weight[idx]
-            loss = {
-                "{}_{}".format(key, idx): loss[key] * weight
-                for key in loss
-            }
-            loss_dict.update(loss)
-        loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
+            for key in loss.keys():
+                # only each component's aggregate "loss" entry is weighted
+                # and summed; named sub-losses are kept for logging
+                if key == "loss":
+                    loss_all += loss[key] * weight
+                else:
+                    loss_dict["{}_{}".format(key, idx)] = loss[key]
+        loss_dict["loss"] = loss_all
         return loss_dict
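The reworked forward keeps each component's named terms unweighted in the returned dict and folds only its aggregate "loss" entry into the weighted total. A toy illustration with hypothetical component outputs (not PaddleOCR API):

    import paddle

    # two loss components, as if returned by loss_func(input, batch)
    outs = [{"loss": paddle.to_tensor(2.0),
             "dml_Student_Student2": paddle.to_tensor(2.0)},
            {"loss": paddle.to_tensor(1.0)}]
    weights = [0.5, 1.0]

    loss_dict, loss_all = {}, 0.
    for idx, loss in enumerate(outs):
        for key in loss:
            if key == "loss":
                loss_all += loss[key] * weights[idx]
            else:
                loss_dict["{}_{}".format(key, idx)] = loss[key]
    loss_dict["loss"] = loss_all
    print(float(loss_dict["loss"]))  # 2.0 * 0.5 + 1.0 * 1.0 = 2.0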
loss_dict["{}_{}_{}".format(self.name, self.maps_name[_c], + idx)] = loss + + loss_dict = _sum_loss(loss_dict) + return loss_dict @@ -73,6 +143,98 @@ class DistillationCTCLoss(CTCLoss): return loss_dict +class DistillationDBLoss(DBLoss): + def __init__(self, + model_name_list=[], + balance_loss=True, + main_loss_type='DiceLoss', + alpha=5, + beta=10, + ohem_ratio=3, + eps=1e-6, + name="db", + **kwargs): + super().__init__() + self.model_name_list = model_name_list + self.name = name + self.key = None + + def forward(self, predicts, batch): + loss_dict = {} + for idx, model_name in enumerate(self.model_name_list): + out = predicts[model_name] + if self.key is not None: + out = out[self.key] + loss = super().forward(out, batch) + + if isinstance(loss, dict): + for key in loss.keys(): + if key == "loss": + continue + name = "{}_{}_{}".format(self.name, model_name, key) + loss_dict[name] = loss[key] + else: + loss_dict["{}_{}".format(self.name, model_name)] = loss + + loss_dict = _sum_loss(loss_dict) + return loss_dict + + +class DistillationDilaDBLoss(DBLoss): + def __init__(self, + model_name_pairs=[], + key=None, + balance_loss=True, + main_loss_type='DiceLoss', + alpha=5, + beta=10, + ohem_ratio=3, + eps=1e-6, + name="dila_dbloss"): + super().__init__() + self.model_name_pairs = model_name_pairs + self.name = name + self.key = key + + def forward(self, predicts, batch): + loss_dict = dict() + for idx, pair in enumerate(self.model_name_pairs): + stu_outs = predicts[pair[0]] + tch_outs = predicts[pair[1]] + if self.key is not None: + stu_preds = stu_outs[self.key] + tch_preds = tch_outs[self.key] + + stu_shrink_maps = stu_preds[:, 0, :, :] + stu_binary_maps = stu_preds[:, 2, :, :] + + # dilation to teacher prediction + dilation_w = np.array([[1, 1], [1, 1]]) + th_shrink_maps = tch_preds[:, 0, :, :] + th_shrink_maps = th_shrink_maps.numpy() > 0.3 # thresh = 0.3 + dilate_maps = np.zeros_like(th_shrink_maps).astype(np.float32) + for i in range(th_shrink_maps.shape[0]): + dilate_maps[i] = cv2.dilate( + th_shrink_maps[i, :, :].astype(np.uint8), dilation_w) + th_shrink_maps = paddle.to_tensor(dilate_maps) + + label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = batch[ + 1:] + + # calculate the shrink map loss + bce_loss = self.alpha * self.bce_loss( + stu_shrink_maps, th_shrink_maps, label_shrink_mask) + loss_binary_maps = self.dice_loss(stu_binary_maps, th_shrink_maps, + label_shrink_mask) + + # k = f"{self.name}_{pair[0]}_{pair[1]}" + k = "{}_{}_{}".format(self.name, pair[0], pair[1]) + loss_dict[k] = bce_loss + loss_binary_maps + + loss_dict = _sum_loss(loss_dict) + return loss_dict + + class DistillationDistanceLoss(DistanceLoss): """ """ diff --git a/ppocr/metrics/det_metric.py b/ppocr/metrics/det_metric.py index 0f9e94df42bb8f31ebc79693a01968d441b16faa..d3d353042575671826da3fc56bf02ccf40dfa5d4 100644 --- a/ppocr/metrics/det_metric.py +++ b/ppocr/metrics/det_metric.py @@ -55,6 +55,7 @@ class DetMetric(object): result = self.evaluator.evaluate_image(gt_info_list, det_info_list) self.results.append(result) + def get_metric(self): """ return metrics { diff --git a/ppocr/metrics/distillation_metric.py b/ppocr/metrics/distillation_metric.py index a7d3d095a7d384bf8cdc69b97f8109c359ac2b5b..c440cebdd0f96493fc33000a0d304cbe5e3f0624 100644 --- a/ppocr/metrics/distillation_metric.py +++ b/ppocr/metrics/distillation_metric.py @@ -24,8 +24,8 @@ from .cls_metric import ClsMetric class DistillationMetric(object): def __init__(self, key=None, - base_metric_name="RecMetric", - 
diff --git a/ppocr/metrics/distillation_metric.py b/ppocr/metrics/distillation_metric.py
index a7d3d095a7d384bf8cdc69b97f8109c359ac2b5b..c440cebdd0f96493fc33000a0d304cbe5e3f0624 100644
--- a/ppocr/metrics/distillation_metric.py
+++ b/ppocr/metrics/distillation_metric.py
@@ -24,8 +24,8 @@ from .cls_metric import ClsMetric
 class DistillationMetric(object):
     def __init__(self,
                  key=None,
-                 base_metric_name="RecMetric",
-                 main_indicator='acc',
+                 base_metric_name=None,
+                 main_indicator=None,
                  **kwargs):
         self.main_indicator = main_indicator
         self.key = key
@@ -42,16 +42,13 @@ class DistillationMetric(object):
                 main_indicator=self.main_indicator,
                 **self.kwargs)
             self.metrics[key].reset()
 
-    def __call__(self, preds, *args, **kwargs):
+    def __call__(self, preds, batch, **kwargs):
         assert isinstance(preds, dict)
         if self.metrics is None:
             self._init_metrcis(preds)
         output = dict()
         for key in preds:
-            metric = self.metrics[key].__call__(preds[key], *args, **kwargs)
-            for sub_key in metric:
-                output["{}_{}".format(key, sub_key)] = metric[sub_key]
-        return output
+            self.metrics[key].__call__(preds[key], batch, **kwargs)
 
     def get_metric(self):
         """
diff --git a/ppocr/modeling/architectures/base_model.py b/ppocr/modeling/architectures/base_model.py
index 03fbcee8465df9c8bb7845ea62fc0ac04917caa0..dbd18070b36f7e99c62de94048ab53d1bedcebe0 100644
--- a/ppocr/modeling/architectures/base_model.py
+++ b/ppocr/modeling/architectures/base_model.py
@@ -79,7 +79,10 @@ class BaseModel(nn.Layer):
             x = self.neck(x)
             y["neck_out"] = x
         x = self.head(x, targets=data)
-        y["head_out"] = x
+        if isinstance(x, dict):
+            y.update(x)
+        else:
+            y["head_out"] = x
         if self.return_all_feats:
             return y
         else:
diff --git a/ppocr/modeling/architectures/distillation_model.py b/ppocr/modeling/architectures/distillation_model.py
index 2e512331afcfc20e422dbef4ba1a4acd581df9e7..2b1d3aae3b7303a61b20db15df5ce4bd9bb7b235 100644
--- a/ppocr/modeling/architectures/distillation_model.py
+++ b/ppocr/modeling/architectures/distillation_model.py
@@ -21,7 +21,7 @@ from ppocr.modeling.backbones import build_backbone
 from ppocr.modeling.necks import build_neck
 from ppocr.modeling.heads import build_head
 from .base_model import BaseModel
-from ppocr.utils.save_load import init_model
+from ppocr.utils.save_load import init_model, load_pretrained_params
 
 __all__ = ['DistillationModel']
 
@@ -46,7 +46,7 @@ class DistillationModel(nn.Layer):
             pretrained = model_config.pop("pretrained")
             model = BaseModel(model_config)
             if pretrained is not None:
-                init_model(model, path=pretrained)
+                model = load_pretrained_params(model, pretrained)
             if freeze_params:
                 for param in model.parameters():
                     param.trainable = False
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index 2f5bdc3b13135ed69e8af2e28ee0cd8042bf87e6..654ddf39d23590fbaf7f7b9b57f38cc86a1b6669 100644
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -21,7 +21,7 @@ import copy
 
 __all__ = ['build_post_process']
 
-from .db_postprocess import DBPostProcess
+from .db_postprocess import DBPostProcess, DistillationDBPostProcess
 from .east_postprocess import EASTPostProcess
 from .sast_postprocess import SASTPostProcess
 from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \
@@ -34,7 +34,8 @@ def build_post_process(config, global_config=None):
     support_dict = [
         'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
         'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
-        'DistillationCTCLabelDecode', 'TableLabelDecode'
+        'DistillationCTCLabelDecode', 'TableLabelDecode',
+        'DistillationDBPostProcess'
     ]
 
     config = copy.deepcopy(config)
diff --git a/ppocr/postprocess/db_postprocess.py b/ppocr/postprocess/db_postprocess.py
index 769ddbe23253ce58e2bccd46ef5074cc2a7d27da..d9c9869dfcd35cb9b491db826f3bff5f766723f4 100755
--- a/ppocr/postprocess/db_postprocess.py
+++ b/ppocr/postprocess/db_postprocess.py
@@ -187,3 +187,29 @@ class DBPostProcess(object):
             boxes_batch.append({'points': boxes})
 
         return boxes_batch
+
+
+class DistillationDBPostProcess(object):
+    def __init__(self,
+                 model_name=["Student"],
+                 key=None,
+                 thresh=0.3,
+                 box_thresh=0.6,
+                 max_candidates=1000,
+                 unclip_ratio=1.5,
+                 use_dilation=False,
+                 score_mode="fast",
+                 **kwargs):
+        self.model_name = model_name
+        self.key = key
+        self.post_process = DBPostProcess(
+            thresh=thresh,
+            box_thresh=box_thresh,
+            max_candidates=max_candidates,
+            unclip_ratio=unclip_ratio,
+            use_dilation=use_dilation,
+            score_mode=score_mode)
+
+    def __call__(self, predicts, shape_list):
+        # run the standard DB post-process once per requested sub-model
+        results = {}
+        for k in self.model_name:
+            results[k] = self.post_process(predicts[k], shape_list=shape_list)
+        return results
diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py
index 1d760e983a635dcc6b48b839ee99434c67b4378d..3bb022ed98b140995b79ceea93d7f494d3f5930d 100644
--- a/ppocr/utils/save_load.py
+++ b/ppocr/utils/save_load.py
@@ -116,6 +116,27 @@ def load_dygraph_params(config, model, logger, optimizer):
         logger.info(f"loaded pretrained_model successful from {pm}")
         return {}
 
+
+def load_pretrained_params(model, path):
+    # always return the model so callers can safely reassign it
+    if path is None:
+        return model
+    if not os.path.exists(path) and not os.path.exists(path + ".pdparams"):
+        print(f"The pretrained_model {path} does not exist!")
+        return model
+
+    path = path if path.endswith('.pdparams') else path + '.pdparams'
+    params = paddle.load(path)
+    state_dict = model.state_dict()
+    new_state_dict = {}
+    # match parameters positionally and keep only shape-compatible ones
+    for k1, k2 in zip(state_dict.keys(), params.keys()):
+        if list(state_dict[k1].shape) == list(params[k2].shape):
+            new_state_dict[k1] = params[k2]
+        else:
+            print(
+                f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
+            )
+    model.set_state_dict(new_state_dict)
+    print(f"loaded pretrained params from {path}")
+    return model
+
 
 def save_model(model,
                optimizer,
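Because load_pretrained_params matches parameters positionally by shape rather than by name, it tolerates renamed keys between checkpoint and model. A usage sketch (assumes `config` holds one of the parsed yml files above; the path comes from those configs):

    from ppocr.modeling.architectures import build_model
    from ppocr.utils.save_load import load_pretrained_params

    model = build_model(config['Architecture'])
    # the .pdparams suffix is optional; it is appended when missing
    model = load_pretrained_params(
        model, "./pretrain_models/MobileNetV3_large_x0_5_pretrained")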
diff --git a/tools/eval.py b/tools/eval.py
index c1315805b5ff9bf29dee87a21688a145b4662b9a..0120baab0f34d5fadbbf4df20d92d6b62dd176a2 100755
--- a/tools/eval.py
+++ b/tools/eval.py
@@ -27,7 +27,7 @@ from ppocr.data import build_dataloader
 from ppocr.modeling.architectures import build_model
 from ppocr.postprocess import build_post_process
 from ppocr.metrics import build_metric
-from ppocr.utils.save_load import init_model
+from ppocr.utils.save_load import init_model, load_pretrained_params
 from ppocr.utils.utility import print_dict
 import tools.program as program
 
@@ -55,7 +55,10 @@ def main():
     model = build_model(config['Architecture'])
     use_srn = config['Architecture']['algorithm'] == "SRN"
-    model_type = config['Architecture']['model_type']
+    if "model_type" in config['Architecture'].keys():
+        model_type = config['Architecture']['model_type']
+    else:
+        model_type = None
 
     best_model_dict = init_model(config, model)
     if len(best_model_dict):
@@ -68,7 +71,7 @@ def main():
 
     # start eval
     metric = program.eval(model, valid_dataloader, post_process_class,
-                          eval_class, model_type, use_srn)
+                          eval_class, model_type, use_srn)
     logger.info('metric eval ***************')
     for k, v in metric.items():
         logger.info('{}:{}'.format(k, v))
diff --git a/tools/program.py b/tools/program.py
index 2d99f2968a3f0c8acc359ed0fbb199650bd7010c..595fe4cb96c0379b1a33504e0ebdd85e70086340 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -186,7 +186,10 @@ def train(config,
     model.train()
 
     use_srn = config['Architecture']['algorithm'] == "SRN"
-    model_type = config['Architecture']['model_type']
+    try:
+        model_type = config['Architecture']['model_type']
+    except KeyError:
+        model_type = None
 
     if 'start_epoch' in best_model_dict:
         start_epoch = best_model_dict['start_epoch']
diff --git a/tools/train.py b/tools/train.py
index 20f5a670d5c8e666678259e0042b3b790e528590..05d295aa99718c25b94a123c23d08c2904fe8c6a 100755
--- a/tools/train.py
+++ b/tools/train.py
@@ -98,7 +98,6 @@ def main(config, device, logger, vdl_writer):
     eval_class = build_metric(config['Metric'])
     # load pretrain model
     pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer)
-    logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
 
     if valid_dataloader is not None:
         logger.info('valid dataloader has {} iters'.format(