From 48898ac357f4b242b98276c5bf48984eb60b5833 Mon Sep 17 00:00:00 2001 From: LDOUBLEV Date: Tue, 6 Jul 2021 08:13:13 +0000 Subject: [PATCH] add config --- .../ch_det_lite_train_distill_v2.1.yml | 194 ++++++++++++++++++ ppocr/losses/distillation_loss.py | 90 ++++++++ ppocr/modeling/architectures/base_model.py | 2 +- 3 files changed, 285 insertions(+), 1 deletion(-) create mode 100644 configs/det/ch_ppocr_v2.1/ch_det_lite_train_distill_v2.1.yml diff --git a/configs/det/ch_ppocr_v2.1/ch_det_lite_train_distill_v2.1.yml b/configs/det/ch_ppocr_v2.1/ch_det_lite_train_distill_v2.1.yml new file mode 100644 index 00000000..54ef12a1 --- /dev/null +++ b/configs/det/ch_ppocr_v2.1/ch_det_lite_train_distill_v2.1.yml @@ -0,0 +1,194 @@ +Global: + use_gpu: true + epoch_num: 1200 + log_smooth_window: 20 + print_batch_step: 2 + save_model_dir: ./output/ch_db_mv3/ + save_epoch_step: 1200 + # evaluation is run every 5000 iterations after the 4000th iteration + eval_batch_step: [3000, 2000] + cal_metric_during_train: False + pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_en/img_10.jpg + save_res_path: ./output/det_db/predicts_db.txt + +Architecture: + name: DistillationModel + algorithm: Distillation + Models: + Student: + pretrained: + freeze_params: false + return_all_feats: false + model_type: det + algorithm: DB + Backbone: + name: MobileNetV3 + scale: 0.5 + model_name: large + disable_se: True + Neck: + name: DBFPN + out_channels: 96 + Head: + name: DBHead + k: 50 + Student2: + pretrained: + freeze_params: false + return_all_feats: false + model_type: det + algorithm: DB + Transform: + Backbone: + name: MobileNetV3 + scale: 0.5 + model_name: large + disable_se: True + Neck: + name: DBFPN + out_channels: 96 + Head: + name: DBHead + k: 50 + Teacher: + model_type: det + algorithm: DB + Transform: + Backbone: + name: ResNet + layers: 18 + Neck: + name: DBFPN + out_channels: 256 + Head: + name: DBHead + k: 50 + +Loss: + name: CombinedLoss + loss_config_list: + - DistillationDilaDBLoss: + weight: 1.0 + model_name_list: ["Student", "Student2", "Teacher"] + key: maps + balance_loss: true + main_loss_type: DiceLoss + alpha: 5 + beta: 10 + ohem_ratio: 3 + - DistillationDMLLoss: + maps_name: ["thrink_maps"] + weight: 1.0 + act: "softmax" + model_name_pairs: ["Student", "Student2"] + key: maps + - DistillationDBLoss: + model_name_list: ["Student", "Teacher"] + key: maps + name: DBLoss + balance_loss: true + main_loss_type: DiceLoss + alpha: 5 + beta: 10 + ohem_ratio: 3 + + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + name: Cosine + learning_rate: 0.001 + warmup_epoch: 2 + regularizer: + name: 'L2' + factor: 0 + +PostProcess: + name: DistillationCTDBPostProcessCLabelDecode + model_name: ["Student", "Student2"] + key: head_out + thresh: 0.3 + box_thresh: 0.6 + max_candidates: 1000 + unclip_ratio: 1.5 + +Metric: + name: DistillationMetric + base_metric_name: DetMetric + main_indicator: hmean + key: "Student" + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/icdar2015/text_localization/ + label_file_list: + - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt + ratio_list: [1.0] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - DetLabelEncode: # Class handling label + - IaaAugment: + augmenter_args: + - { 'type': Fliplr, 'args': { 'p': 0.5 } } + - { 'type': Affine, 'args': { 'rotate': [-10, 10] } } + - { 'type': Resize, 'args': { 'size': [0.5, 3] } } + - EastRandomCropData: + size: [960, 960] + max_tries: 50 + keep_ratio: true + - MakeBorderMap: + shrink_ratio: 0.4 + thresh_min: 0.3 + thresh_max: 0.7 + - MakeShrinkMap: + shrink_ratio: 0.4 + min_text_size: 8 + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list + loader: + shuffle: True + drop_last: False + batch_size_per_card: 8 + num_workers: 4 + +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data/icdar2015/text_localization/ + label_file_list: + - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - DetLabelEncode: # Class handling label + - DetResizeForTest: +# image_shape: [736, 1280] + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'shape', 'polys', 'ignore_tags'] + loader: + shuffle: False + drop_last: False + batch_size_per_card: 1 # must be 1 + num_workers: 2 diff --git a/ppocr/losses/distillation_loss.py b/ppocr/losses/distillation_loss.py index b19f3f89..421bbaba 100644 --- a/ppocr/losses/distillation_loss.py +++ b/ppocr/losses/distillation_loss.py @@ -132,6 +132,96 @@ class DistillationCTCLoss(CTCLoss): return loss_dict +class DistillationDBLoss(DBLoss): + def __init__(self, + model_name_list=[], + balance_loss=True, + main_loss_type='DiceLoss', + alpha=5, + beta=10, + ohem_ratio=3, + eps=1e-6, + name="db_loss", + **kwargs): + super().__init__() + self.model_name_list = model_name_list + self.name = name + self.key = None + + def forward(self, preicts, batch): + loss_dict = {} + for idx, model_name in enumerate(self.model_name_list): + out = predicts[model_name] + if self.key is not None: + out = out[self.key] + loss = super().forward(out, batch) + + if isinstance(loss, dict): + for key in loss.keys(): + if key == "loss": + continue + name = "{}_{}_{}".format(self.name, model_name, key) + loss_dict[name] = loss[key] + else: + loss_dict["{}_{}".format(self.name, model_name)] = loss + + loss_dict = _sum_loss(loss_dict) + return loss_dict + + +class DistillationDilaDBLoss(DBLoss): + def __init__(self, + model_name_pairs=[], + balance_loss=True, + main_loss_type='DiceLoss', + alpha=5, + beta=10, + ohem_ratio=3, + eps=1e-6, + name="dila_dbloss"): + super().__init__() + self.model_name_pairs = model_name_pairs + self.name = name + + def forward(self, predicts, batch): + loss_dict = dict() + for idx, pair in enumerate(self.model_name_pairs): + stu_outs = predicts[pair[0]] + tch_outs = predicts[pair[1]] + if self.key is not None: + stu_preds = stu_outs[self.key] + tch_preds = tch_outs[self.key] + + stu_shrink_maps = stu_preds[:, 0, :, :] + stu_binary_maps = stu_preds[:, 2, :, :] + + # dilation to teacher prediction + dilation_w = np.array([[1, 1], [1, 1]]) + th_shrink_maps = tch_preds[:, 0, :, :] + th_shrink_maps = th_shrink_maps.numpy() > 0.3 # thresh = 0.3 + dilate_maps = np.zeros_like(th_shrink_maps).astype(np.float32) + for i in range(th_shrink_maps.shape[0]): + dilate_maps[i] = cv2.dilate( + th_shrink_maps[i, :, :].astype(np.uint8), dilation_w) + th_shrink_maps = paddle.to_tensor(dilate_maps) + + label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = batch[ + 1:] + + # calculate the shrink map loss + bce_loss = self.alpha * self.bce_loss( + stu_shrink_maps, th_shrink_maps, label_shrink_mask) + loss_binary_maps = self.dice_loss(stu_binary_maps, th_shrink_maps, + label_shrink_mask) + + # k = f"{self.name}_{pair[0]}_{pair[1]}" + k = "{}_{}_{}".format(self.name, pair[0], pair[1]) + loss_dict[k] = bce_loss + loss_binary_maps + + loss_dict = _sum_loss(loss_dict) + return loss + + class DistillationDistanceLoss(DistanceLoss): """ """ diff --git a/ppocr/modeling/architectures/base_model.py b/ppocr/modeling/architectures/base_model.py index ff3da01a..dbd18070 100644 --- a/ppocr/modeling/architectures/base_model.py +++ b/ppocr/modeling/architectures/base_model.py @@ -79,7 +79,7 @@ class BaseModel(nn.Layer): x = self.neck(x) y["neck_out"] = x x = self.head(x, targets=data) - if type(x) is dict: + if isinstance(x, dict): y.update(x) else: y["head_out"] = x -- GitLab