diff --git a/configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml b/configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml new file mode 100644 index 0000000000000000000000000000000000000000..ab622d122cf8bca035b289dfc3d701aa948537de --- /dev/null +++ b/configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml @@ -0,0 +1,172 @@ +Global: + debug: false + use_gpu: true + epoch_num: 500 + log_smooth_window: 20 + print_batch_step: 100 + save_model_dir: ./output/ch_PP-OCRv3_mv3_cbnlocal_shrink/ + save_epoch_step: 10 + eval_batch_step: + - 0 + - 1500 + cal_metric_during_train: false + checkpoints: + pretrained_model: + save_inference_dir: null + use_visualdl: false + infer_img: doc/imgs_en/img_10.jpg + save_res_path: ./checkpoints/det_db/predicts_db.txt + distributed: true + +Architecture: + model_type: det + algorithm: DB + Transform: null + Backbone: + name: LCNetv3 + scale: 0.75 + det: True + Neck: + name: RSEFPN + out_channels: 96 + shortcut: True + Head: + name: CBNHeadLocal + k: 50 + mode: "small" + +Loss: + name: DBLoss + balance_loss: true + main_loss_type: DiceLoss + alpha: 5 + beta: 10 + ohem_ratio: 3 + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + name: Cosine + learning_rate: 0.001 #(8*8c) + warmup_epoch: 2 + regularizer: + name: L2 + factor: 5.0e-05 + +PostProcess: + name: DBPostProcess + thresh: 0.3 + box_thresh: 0.6 + max_candidates: 1000 + unclip_ratio: 1.5 + +Metric: + name: DetMetric + main_indicator: hmean + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/icdar2015/text_localization/ + label_file_list: + - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt + ratio_list: [1.0] + transforms: + - DecodeImage: + img_mode: BGR + channel_first: false + - DetLabelEncode: null + - CopyPaste: null + - IaaAugment: + augmenter_args: + - type: Fliplr + args: + p: 0.5 + - type: Affine + args: + rotate: + - -10 + - 10 + - type: Resize + args: + size: + - 0.5 + - 3 + - EastRandomCropData: + size: + - 640 + - 640 + max_tries: 50 + keep_ratio: true + - MakeBorderMap: + shrink_ratio: 0.4 + thresh_min: 0.3 + thresh_max: 0.7 + total_epoch: 500 + - MakeShrinkMap: + shrink_ratio: 0.4 + min_text_size: 8 + total_epoch: 500 + - NormalizeImage: + scale: 1./255. + mean: + - 0.485 + - 0.456 + - 0.406 + std: + - 0.229 + - 0.224 + - 0.225 + order: hwc + - ToCHWImage: null + - KeepKeys: + keep_keys: + - image + - threshold_map + - threshold_mask + - shrink_map + - shrink_mask + loader: + shuffle: true + drop_last: false + batch_size_per_card: 8 + num_workers: 8 + +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data/icdar2015/text_localization/ + label_file_list: + - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt + transforms: + - DecodeImage: + img_mode: BGR + channel_first: false + - DetLabelEncode: null + - DetResizeForTest: + - NormalizeImage: + scale: 1./255. + mean: + - 0.485 + - 0.456 + - 0.406 + std: + - 0.229 + - 0.224 + - 0.225 + order: hwc + - ToCHWImage: null + - KeepKeys: + keep_keys: + - image + - shape + - polys + - ignore_tags + loader: + shuffle: false + drop_last: false + batch_size_per_card: 1 + num_workers: 2 +profiler_options: null diff --git a/configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_teacher.yml b/configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_teacher.yml new file mode 100644 index 0000000000000000000000000000000000000000..74581a8bc6259a4b0f21d75fa0dc28a1ac323308 --- /dev/null +++ b/configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_teacher.yml @@ -0,0 +1,172 @@ +Global: + debug: false + use_gpu: true + epoch_num: 500 + log_smooth_window: 20 + print_batch_step: 100 + save_model_dir: ./output/ch_PP-OCRv3_mv3_cbnlocal_shrink/ + save_epoch_step: 10 + eval_batch_step: + - 0 + - 1500 + cal_metric_during_train: false + checkpoints: + pretrained_model: + save_inference_dir: null + use_visualdl: false + infer_img: doc/imgs_en/img_10.jpg + save_res_path: ./checkpoints/det_db/predicts_db.txt + distributed: true + +Architecture: + model_type: det + algorithm: DB + Transform: null + Backbone: + name: PPHGNet_small + det: True + Neck: + name: LKPAN + out_channels: 256 + intracl: true + Head: + name: CBNHeadLocal + k: 50 + mode: "large" + + +Loss: + name: DBLoss + balance_loss: true + main_loss_type: DiceLoss + alpha: 5 + beta: 10 + ohem_ratio: 3 + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + name: Cosine + learning_rate: 0.001 #(8*8c) + warmup_epoch: 2 + regularizer: + name: L2 + factor: 1e-6 + +PostProcess: + name: DBPostProcess + thresh: 0.3 + box_thresh: 0.6 + max_candidates: 1000 + unclip_ratio: 1.5 + +Metric: + name: DetMetric + main_indicator: hmean + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/icdar2015/text_localization/ + label_file_list: + - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt + ratio_list: [1.0] + transforms: + - DecodeImage: + img_mode: BGR + channel_first: false + - DetLabelEncode: null + - CopyPaste: null + - IaaAugment: + augmenter_args: + - type: Fliplr + args: + p: 0.5 + - type: Affine + args: + rotate: + - -10 + - 10 + - type: Resize + args: + size: + - 0.5 + - 3 + - EastRandomCropData: + size: + - 640 + - 640 + max_tries: 50 + keep_ratio: true + - MakeBorderMap: + shrink_ratio: 0.4 + thresh_min: 0.3 + thresh_max: 0.7 + total_epoch: 500 + - MakeShrinkMap: + shrink_ratio: 0.4 + min_text_size: 8 + total_epoch: 500 + - NormalizeImage: + scale: 1./255. + mean: + - 0.485 + - 0.456 + - 0.406 + std: + - 0.229 + - 0.224 + - 0.225 + order: hwc + - ToCHWImage: null + - KeepKeys: + keep_keys: + - image + - threshold_map + - threshold_mask + - shrink_map + - shrink_mask + loader: + shuffle: true + drop_last: false + batch_size_per_card: 8 + num_workers: 8 + +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data/icdar2015/text_localization/ + label_file_list: + - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt + transforms: + - DecodeImage: + img_mode: BGR + channel_first: false + - DetLabelEncode: null + - DetResizeForTest: + - NormalizeImage: + scale: 1./255. + mean: + - 0.485 + - 0.456 + - 0.406 + std: + - 0.229 + - 0.224 + - 0.225 + order: hwc + - ToCHWImage: null + - KeepKeys: + keep_keys: + - image + - shape + - polys + - ignore_tags + loader: + shuffle: false + drop_last: false + batch_size_per_card: 1 + num_workers: 2 +profiler_options: null diff --git a/ppocr/data/imaug/make_border_map.py b/ppocr/data/imaug/make_border_map.py index abab38368db2de84e54b060598fc509a65219296..03b7817cfbe2068184981b18a7aa539c8d350e3b 100644 --- a/ppocr/data/imaug/make_border_map.py +++ b/ppocr/data/imaug/make_border_map.py @@ -44,6 +44,10 @@ class MakeBorderMap(object): self.shrink_ratio = shrink_ratio self.thresh_min = thresh_min self.thresh_max = thresh_max + if 'total_epoch' in kwargs and 'epoch' in kwargs and kwargs[ + 'epoch'] != "None": + self.shrink_ratio = self.shrink_ratio + 0.2 * kwargs[ + 'epoch'] / float(kwargs['total_epoch']) def __call__(self, data): diff --git a/ppocr/data/imaug/make_shrink_map.py b/ppocr/data/imaug/make_shrink_map.py index 6c65c20e5621f91a5b1fba549b059c92923fca6f..d0317b61fe05ce75c479a2485cef540742f489e0 100644 --- a/ppocr/data/imaug/make_shrink_map.py +++ b/ppocr/data/imaug/make_shrink_map.py @@ -38,6 +38,10 @@ class MakeShrinkMap(object): def __init__(self, min_text_size=8, shrink_ratio=0.4, **kwargs): self.min_text_size = min_text_size self.shrink_ratio = shrink_ratio + if 'total_epoch' in kwargs and 'epoch' in kwargs and kwargs[ + 'epoch'] != "None": + self.shrink_ratio = self.shrink_ratio + 0.2 * kwargs[ + 'epoch'] / float(kwargs['total_epoch']) def __call__(self, data): image = data['image'] diff --git a/ppocr/data/simple_dataset.py b/ppocr/data/simple_dataset.py index 402f1e38fed9e32722e2dd160f10f779028807a3..5ce873f93a2ff531f1be2fd3797e0d4e920e04be 100644 --- a/ppocr/data/simple_dataset.py +++ b/ppocr/data/simple_dataset.py @@ -48,11 +48,25 @@ class SimpleDataSet(Dataset): self.data_idx_order_list = list(range(len(self.data_lines))) if self.mode == "train" and self.do_shuffle: self.shuffle_data_random() + + self.set_epoch_as_seed(self.seed) + self.ops = create_operators(dataset_config['transforms'], global_config) self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx", 2) self.need_reset = True in [x < 1 for x in ratio_list] + def set_epoch_as_seed(self, seed): + if self.mode is 'train': + try: + dataset_config['transforms'][5]['MakeBorderMap'][ + 'epoch'] = seed if seed is not None else 0 + dataset_config['transforms'][6]['MakeShrinkMap'][ + 'epoch'] = seed if seed is not None else 0 + except Exception as E: + logger.info(E) + return + def get_image_info_list(self, file_list, ratio_list): if isinstance(file_list, str): file_list = [file_list] diff --git a/ppocr/losses/det_db_loss.py b/ppocr/losses/det_db_loss.py index 708ffbdb47f349304e2bfd781a836e79348475f4..ce31ef124591ce3e5351460eb94ca50490bcf0e5 100755 --- a/ppocr/losses/det_db_loss.py +++ b/ppocr/losses/det_db_loss.py @@ -20,6 +20,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import paddle from paddle import nn from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss @@ -66,11 +67,21 @@ class DBLoss(nn.Layer): label_shrink_mask) loss_shrink_maps = self.alpha * loss_shrink_maps loss_threshold_maps = self.beta * loss_threshold_maps + # CBN loss + if 'distance_maps' in predicts.keys(): + distance_maps = predicts['distance_maps'] + cbn_maps = predicts['cbn_maps'] + cbn_loss = self.bce_loss(cbn_maps[:, 0, :, :], label_shrink_map, + label_shrink_mask) + else: + dis_loss = paddle.to_tensor([0.]) + cbn_loss = paddle.to_tensor([0.]) loss_all = loss_shrink_maps + loss_threshold_maps \ + loss_binary_maps - losses = {'loss': loss_all, \ + losses = {'loss': loss_all+ cbn_loss, \ "loss_shrink_maps": loss_shrink_maps, \ "loss_threshold_maps": loss_threshold_maps, \ - "loss_binary_maps": loss_binary_maps} + "loss_binary_maps": loss_binary_maps, \ + "loss_cbn": cbn_loss} return losses diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py index 2e8f4c05c8dd08fe21eaed22f55c4c81c6131b5f..6dca38f1d969eaff637e39fe6f361bc6f0e0921d 100755 --- a/ppocr/modeling/backbones/__init__.py +++ b/ppocr/modeling/backbones/__init__.py @@ -22,8 +22,11 @@ def build_backbone(config, model_type): from .det_resnet_vd import ResNet_vd from .det_resnet_vd_sast import ResNet_SAST from .det_pp_lcnet import PPLCNet + from .rec_lcnetv3 import LCNetv3 + from .rec_hgnet import PPHGNet_small support_dict = [ - "MobileNetV3", "ResNet", "ResNet_vd", "ResNet_SAST", "PPLCNet" + "MobileNetV3", "ResNet", "ResNet_vd", "ResNet_SAST", "PPLCNet", + "LCNetv3", "PPHGNet_small" ] if model_type == "table": from .table_master_resnet import TableResNetExtra diff --git a/ppocr/modeling/backbones/rec_hgnet.py b/ppocr/modeling/backbones/rec_hgnet.py index c4f2d7bdee1cd5b9fc37e98f20f8b7637901ef87..d990453308a47f3e68f2d899c01edf3ecbdae8db 100644 --- a/ppocr/modeling/backbones/rec_hgnet.py +++ b/ppocr/modeling/backbones/rec_hgnet.py @@ -188,8 +188,19 @@ class PPHGNet(nn.Layer): model: nn.Layer. Specific PPHGNet model depends on args. """ - def __init__(self, stem_channels, stage_config, layer_num, in_channels=3): + def __init__( + self, + stem_channels, + stage_config, + layer_num, + in_channels=3, + det=False, + out_indices=None, ): super().__init__() + self.det = det + self.out_indices = out_indices if out_indices is not None else [ + 0, 1, 2, 3 + ] # stem stem_channels.insert(0, in_channels) @@ -202,16 +213,23 @@ class PPHGNet(nn.Layer): len(stem_channels) - 1) ]) + if self.det: + self.pool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1) # stages self.stages = nn.LayerList() - for k in stage_config: + self.out_channels = [] + for block_id, k in enumerate(stage_config): in_channels, mid_channels, out_channels, block_num, downsample, stride = stage_config[ k] self.stages.append( HG_Stage(in_channels, mid_channels, out_channels, block_num, layer_num, downsample, stride)) + if block_id in self.out_indices: + self.out_channels.append(out_channels) + + if not self.det: + self.out_channels = stage_config["stage4"][2] - self.out_channels = stage_config["stage4"][2] self._init_weights() def _init_weights(self): @@ -226,8 +244,17 @@ class PPHGNet(nn.Layer): def forward(self, x): x = self.stem(x) - for stage in self.stages: + if self.det: + x = self.pool(x) + + out = [] + for i, stage in enumerate(self.stages): x = stage(x) + if self.det and i in self.out_indices: + out.append(x) + if self.det: + return out + if self.training: x = F.adaptive_avg_pool2d(x, [1, 40]) else: @@ -261,7 +288,7 @@ def PPHGNet_tiny(pretrained=False, use_ssld=False, **kwargs): return model -def PPHGNet_small(pretrained=False, use_ssld=False, **kwargs): +def PPHGNet_small(pretrained=False, use_ssld=False, det=False, **kwargs): """ PPHGNet_small Args: @@ -271,7 +298,15 @@ def PPHGNet_small(pretrained=False, use_ssld=False, **kwargs): Returns: model: nn.Layer. Specific `PPHGNet_small` model depends on args. """ - stage_config = { + stage_config_det = { + # in_channels, mid_channels, out_channels, blocks, downsample + "stage1": [128, 128, 256, 1, False, 2], + "stage2": [256, 160, 512, 1, True, 2], + "stage3": [512, 192, 768, 2, True, 2], + "stage4": [768, 224, 1024, 1, True, 2], + } + + stage_config_rec = { # in_channels, mid_channels, out_channels, blocks, downsample "stage1": [128, 128, 256, 1, True, [2, 1]], "stage2": [256, 160, 512, 1, True, [1, 2]], @@ -281,8 +316,9 @@ def PPHGNet_small(pretrained=False, use_ssld=False, **kwargs): model = PPHGNet( stem_channels=[64, 64, 128], - stage_config=stage_config, + stage_config=stage_config_det if det else stage_config_rec, layer_num=6, + det=det, **kwargs) return model diff --git a/ppocr/modeling/backbones/rec_lcnetv3.py b/ppocr/modeling/backbones/rec_lcnetv3.py index 06232691c5905bce68b52c5f5a99de67353f2bc7..8f3b0c560c5859532b5cc829583810058e11bd42 100644 --- a/ppocr/modeling/backbones/rec_lcnetv3.py +++ b/ppocr/modeling/backbones/rec_lcnetv3.py @@ -24,7 +24,20 @@ from paddle.nn.initializer import Constant, KaimingNormal from paddle.nn import AdaptiveAvgPool2D, BatchNorm2D, Conv2D, Dropout, Hardsigmoid, Hardswish, Identity, Linear, ReLU from paddle.regularizer import L2Decay -NET_CONFIG = { +NET_CONFIG_det = { + "blocks2": + #k, in_c, out_c, s, use_se + [[3, 16, 32, 1, False]], + "blocks3": [[3, 32, 64, 2, False], [3, 64, 64, 1, False]], + "blocks4": [[3, 64, 128, 2, False], [3, 128, 128, 1, False]], + "blocks5": + [[3, 128, 256, 2, False], [5, 256, 256, 1, False], [5, 256, 256, 1, False], + [5, 256, 256, 1, False], [5, 256, 256, 1, False]], + "blocks6": [[5, 256, 512, 2, True], [5, 512, 512, 1, True], + [5, 512, 512, 1, False], [5, 512, 512, 1, False]] +} + +NET_CONFIG_rec = { "blocks2": #k, in_c, out_c, s, use_se [[3, 16, 32, 1, False]], @@ -335,11 +348,14 @@ class PPLCNetV3(nn.Layer): conv_kxk_num=4, lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0], lab_lr=0.1, + det=False, **kwargs): super().__init__() self.scale = scale self.lr_mult_list = lr_mult_list - self.net_config = NET_CONFIG + self.det = det + + self.net_config = NET_CONFIG_det if self.det else NET_CONFIG_rec assert isinstance(self.lr_mult_list, ( list, tuple @@ -365,8 +381,9 @@ class PPLCNetV3(nn.Layer): use_se=se, conv_kxk_num=conv_kxk_num, lr_mult=self.lr_mult_list[1], - lab_lr=lab_lr) for i, (k, in_c, out_c, s, se) in enumerate( - self.net_config["blocks2"]) + lab_lr=lab_lr) + for i, (k, in_c, out_c, s, se) in enumerate(self.net_config[ + "blocks2"]) ]) self.blocks3 = nn.Sequential(* [ @@ -378,8 +395,9 @@ class PPLCNetV3(nn.Layer): use_se=se, conv_kxk_num=conv_kxk_num, lr_mult=self.lr_mult_list[2], - lab_lr=lab_lr) for i, (k, in_c, out_c, s, se) in enumerate( - self.net_config["blocks3"]) + lab_lr=lab_lr) + for i, (k, in_c, out_c, s, se) in enumerate(self.net_config[ + "blocks3"]) ]) self.blocks4 = nn.Sequential(* [ @@ -391,8 +409,9 @@ class PPLCNetV3(nn.Layer): use_se=se, conv_kxk_num=conv_kxk_num, lr_mult=self.lr_mult_list[3], - lab_lr=lab_lr) for i, (k, in_c, out_c, s, se) in enumerate( - self.net_config["blocks4"]) + lab_lr=lab_lr) + for i, (k, in_c, out_c, s, se) in enumerate(self.net_config[ + "blocks4"]) ]) self.blocks5 = nn.Sequential(* [ @@ -404,8 +423,9 @@ class PPLCNetV3(nn.Layer): use_se=se, conv_kxk_num=conv_kxk_num, lr_mult=self.lr_mult_list[4], - lab_lr=lab_lr) for i, (k, in_c, out_c, s, se) in enumerate( - self.net_config["blocks5"]) + lab_lr=lab_lr) + for i, (k, in_c, out_c, s, se) in enumerate(self.net_config[ + "blocks5"]) ]) self.blocks6 = nn.Sequential(* [ @@ -417,19 +437,52 @@ class PPLCNetV3(nn.Layer): use_se=se, conv_kxk_num=conv_kxk_num, lr_mult=self.lr_mult_list[5], - lab_lr=lab_lr) for i, (k, in_c, out_c, s, se) in enumerate( - self.net_config["blocks6"]) + lab_lr=lab_lr) + for i, (k, in_c, out_c, s, se) in enumerate(self.net_config[ + "blocks6"]) ]) self.out_channels = make_divisible(512 * scale) + if self.det: + mv_c = [16, 24, 56, 480] + self.out_channels = [ + make_divisible(self.net_config["blocks3"][-1][2] * scale), + make_divisible(self.net_config["blocks4"][-1][2] * scale), + make_divisible(self.net_config["blocks5"][-1][2] * scale), + make_divisible(self.net_config["blocks6"][-1][2] * scale), + ] + + self.layer_list = nn.LayerList([ + nn.Conv2D(self.out_channels[0], int(mv_c[0] * scale), 1, 1, 0), + nn.Conv2D(self.out_channels[1], int(mv_c[1] * scale), 1, 1, 0), + nn.Conv2D(self.out_channels[2], int(mv_c[2] * scale), 1, 1, 0), + nn.Conv2D(self.out_channels[3], int(mv_c[3] * scale), 1, 1, 0) + ]) + self.out_channels = [ + int(mv_c[0] * scale), int(mv_c[1] * scale), + int(mv_c[2] * scale), int(mv_c[3] * scale) + ] + def forward(self, x): + out_list = [] x = self.conv1(x) x = self.blocks2(x) x = self.blocks3(x) + out_list.append(x) x = self.blocks4(x) + out_list.append(x) x = self.blocks5(x) + out_list.append(x) x = self.blocks6(x) + out_list.append(x) + + if self.det: + out_list[0] = self.layer_list[0](out_list[0]) + out_list[1] = self.layer_list[1](out_list[1]) + out_list[2] = self.layer_list[2](out_list[2]) + out_list[3] = self.layer_list[3](out_list[3]) + return out_list if self.training: x = F.adaptive_avg_pool2d(x, [1, 40]) @@ -438,6 +491,6 @@ class PPLCNetV3(nn.Layer): return x -def LCNetv3(pretrained=False, use_ssld=False, **kwargs): - model = PPLCNetV3(scale=0.95, conv_kxk_num=4, **kwargs) +def LCNetv3(scale=0.95, **kwargs): + model = PPLCNetV3(scale=scale, conv_kxk_num=4, **kwargs) return model diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index 794bc3a357c7f6efb314164be111fcc42ffab77e..96ef1a31c9845f6064e2cf2d1be371c8fa16ec4a 100755 --- a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -17,14 +17,13 @@ __all__ = ['build_head'] def build_head(config): # det head - from .det_db_head import DBHead + from .det_db_head import DBHead, CBNHeadLocal from .det_east_head import EASTHead from .det_sast_head import SASTHead from .det_pse_head import PSEHead from .det_fce_head import FCEHead from .e2e_pg_head import PGHead from .det_ct_head import CT_Head - # rec head from .rec_ctc_head import CTCHead from .rec_att_head import AttentionHead @@ -57,7 +56,7 @@ def build_head(config): 'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead', 'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead', 'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head', 'RFLHead', - 'DRRGHead', 'CANHead', 'SATRNHead' + 'DRRGHead', 'CANHead', 'SATRNHead', 'CBNHeadLocal' ] if config['name'] == 'DRRGHead': diff --git a/ppocr/modeling/heads/det_db_head.py b/ppocr/modeling/heads/det_db_head.py index 77cb6f1db2dda92dbe74803af68c6ec87ed2d583..968884f74762e35fd777f608a7c4e03d877f82bb 100644 --- a/ppocr/modeling/heads/det_db_head.py +++ b/ppocr/modeling/heads/det_db_head.py @@ -21,6 +21,7 @@ import paddle from paddle import nn import paddle.nn.functional as F from paddle import ParamAttr +from ppocr.modeling.backbones.det_mobilenet_v3 import ConvBNLayer def get_bias_attr(k): @@ -48,6 +49,7 @@ class Head(nn.Layer): bias_attr=ParamAttr( initializer=paddle.nn.initializer.Constant(value=1e-4)), act='relu') + self.conv2 = nn.Conv2DTranspose( in_channels=in_channels // 4, out_channels=in_channels // 4, @@ -72,13 +74,17 @@ class Head(nn.Layer): initializer=paddle.nn.initializer.KaimingUniform()), bias_attr=get_bias_attr(in_channels // 4), ) - def forward(self, x): + def forward(self, x, return_f=False): x = self.conv1(x) x = self.conv_bn1(x) x = self.conv2(x) x = self.conv_bn2(x) + if return_f is True: + f = x x = self.conv3(x) x = F.sigmoid(x) + if return_f is True: + return x, f return x @@ -108,3 +114,41 @@ class DBHead(nn.Layer): binary_maps = self.step_function(shrink_maps, threshold_maps) y = paddle.concat([shrink_maps, threshold_maps, binary_maps], axis=1) return {'maps': y} + + +class LocalModule(nn.Layer): + def __init__(self, in_c, mid_c, use_distance=True): + super(self.__class__, self).__init__() + self.last_3 = ConvBNLayer(in_c + 1, mid_c, 3, 1, 1, act='relu') + self.last_1 = nn.Conv2D(mid_c, 1, 1, 1, 0) + + def forward(self, x, init_map, distance_map): + outf = paddle.concat([init_map, x], axis=1) + # last Conv + out = self.last_1(self.last_3(outf)) + return out + + +class CBNHeadLocal(DBHead): + def __init__(self, in_channels, k=50, mode='small', **kwargs): + super(CBNHeadLocal, self).__init__(in_channels, k, **kwargs) + self.mode = mode + + self.up_conv = nn.Upsample(scale_factor=2, mode="nearest", align_mode=1) + if self.mode == 'large': + self.cbn_layer = LocalModule(in_channels // 4, in_channels // 4) + elif self.mode == 'small': + self.cbn_layer = LocalModule(in_channels // 4, in_channels // 8) + + def forward(self, x, targets=None): + shrink_maps, f = self.binarize(x, return_f=True) + base_maps = shrink_maps + cbn_maps = self.cbn_layer(self.up_conv(f), shrink_maps, None) + cbn_maps = F.sigmoid(cbn_maps) + if not self.training: + return {'maps': 0.5 * (base_maps + cbn_maps), 'cbn_maps': cbn_maps} + + threshold_maps = self.thresh(x) + binary_maps = self.step_function(shrink_maps, threshold_maps) + y = paddle.concat([cbn_maps, threshold_maps, binary_maps], axis=1) + return {'maps': y, 'distance_maps': cbn_maps, 'cbn_maps': binary_maps} diff --git a/ppocr/modeling/necks/db_fpn.py b/ppocr/modeling/necks/db_fpn.py index 8c3f52a331db5daafab2a38c0a441edd44eb141d..0f5b826bfb023895d6216605e2b2faf82023fa80 100644 --- a/ppocr/modeling/necks/db_fpn.py +++ b/ppocr/modeling/necks/db_fpn.py @@ -22,6 +22,7 @@ import paddle.nn.functional as F from paddle import ParamAttr import os import sys +from ppocr.modeling.necks.intracl import IntraCLBlock __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) @@ -228,6 +229,13 @@ class RSEFPN(nn.Layer): self.out_channels = out_channels self.ins_conv = nn.LayerList() self.inp_conv = nn.LayerList() + self.intracl = False + if 'intracl' in kwargs.keys() and kwargs['intracl'] is True: + self.intracl = kwargs['intracl'] + self.incl1 = IntraCLBlock(self.out_channels // 4, reduce_factor=2) + self.incl2 = IntraCLBlock(self.out_channels // 4, reduce_factor=2) + self.incl3 = IntraCLBlock(self.out_channels // 4, reduce_factor=2) + self.incl4 = IntraCLBlock(self.out_channels // 4, reduce_factor=2) for i in range(len(in_channels)): self.ins_conv.append( @@ -263,6 +271,12 @@ class RSEFPN(nn.Layer): p3 = self.inp_conv[1](out3) p2 = self.inp_conv[0](out2) + if self.intracl is True: + p5 = self.incl4(p5) + p4 = self.incl3(p4) + p3 = self.incl2(p3) + p2 = self.incl1(p2) + p5 = F.upsample(p5, scale_factor=8, mode="nearest", align_mode=1) p4 = F.upsample(p4, scale_factor=4, mode="nearest", align_mode=1) p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1) @@ -329,6 +343,14 @@ class LKPAN(nn.Layer): weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False)) + self.intracl = False + if 'intracl' in kwargs.keys() and kwargs['intracl'] is True: + self.intracl = kwargs['intracl'] + self.incl1 = IntraCLBlock(self.out_channels // 4, reduce_factor=2) + self.incl2 = IntraCLBlock(self.out_channels // 4, reduce_factor=2) + self.incl3 = IntraCLBlock(self.out_channels // 4, reduce_factor=2) + self.incl4 = IntraCLBlock(self.out_channels // 4, reduce_factor=2) + def forward(self, x): c2, c3, c4, c5 = x @@ -358,6 +380,12 @@ class LKPAN(nn.Layer): p4 = self.pan_lat_conv[2](pan4) p5 = self.pan_lat_conv[3](pan5) + if self.intracl is True: + p5 = self.incl4(p5) + p4 = self.incl3(p4) + p3 = self.incl2(p3) + p2 = self.incl1(p2) + p5 = F.upsample(p5, scale_factor=8, mode="nearest", align_mode=1) p4 = F.upsample(p4, scale_factor=4, mode="nearest", align_mode=1) p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1) @@ -424,4 +452,4 @@ class ASFBlock(nn.Layer): out_list = [] for i in range(self.out_features_num): out_list.append(attention_scores[:, i:i + 1] * features_list[i]) - return paddle.concat(out_list, axis=1) + return paddle.concat(out_list, axis=1) \ No newline at end of file diff --git a/ppocr/modeling/necks/intracl.py b/ppocr/modeling/necks/intracl.py new file mode 100644 index 0000000000000000000000000000000000000000..205b52e35f04e59d35ae6a89bfe1b920a3890d5f --- /dev/null +++ b/ppocr/modeling/necks/intracl.py @@ -0,0 +1,118 @@ +import paddle +from paddle import nn + +# refer from: https://github.com/ViTAE-Transformer/I3CL/blob/736c80237f66d352d488e83b05f3e33c55201317/mmdet/models/detectors/intra_cl_module.py + + +class IntraCLBlock(nn.Layer): + def __init__(self, in_channels=96, reduce_factor=4): + super(IntraCLBlock, self).__init__() + self.channels = in_channels + self.rf = reduce_factor + weight_attr = paddle.nn.initializer.KaimingUniform() + self.conv1x1_reduce_channel = nn.Conv2D( + self.channels, + self.channels // self.rf, + kernel_size=1, + stride=1, + padding=0) + self.conv1x1_return_channel = nn.Conv2D( + self.channels // self.rf, + self.channels, + kernel_size=1, + stride=1, + padding=0) + + self.v_layer_7x1 = nn.Conv2D( + self.channels // self.rf, + self.channels // self.rf, + kernel_size=(7, 1), + stride=(1, 1), + padding=(3, 0)) + self.v_layer_5x1 = nn.Conv2D( + self.channels // self.rf, + self.channels // self.rf, + kernel_size=(5, 1), + stride=(1, 1), + padding=(2, 0)) + self.v_layer_3x1 = nn.Conv2D( + self.channels // self.rf, + self.channels // self.rf, + kernel_size=(3, 1), + stride=(1, 1), + padding=(1, 0)) + + self.q_layer_1x7 = nn.Conv2D( + self.channels // self.rf, + self.channels // self.rf, + kernel_size=(1, 7), + stride=(1, 1), + padding=(0, 3)) + self.q_layer_1x5 = nn.Conv2D( + self.channels // self.rf, + self.channels // self.rf, + kernel_size=(1, 5), + stride=(1, 1), + padding=(0, 2)) + self.q_layer_1x3 = nn.Conv2D( + self.channels // self.rf, + self.channels // self.rf, + kernel_size=(1, 3), + stride=(1, 1), + padding=(0, 1)) + + # base + self.c_layer_7x7 = nn.Conv2D( + self.channels // self.rf, + self.channels // self.rf, + kernel_size=(7, 7), + stride=(1, 1), + padding=(3, 3)) + self.c_layer_5x5 = nn.Conv2D( + self.channels // self.rf, + self.channels // self.rf, + kernel_size=(5, 5), + stride=(1, 1), + padding=(2, 2)) + self.c_layer_3x3 = nn.Conv2D( + self.channels // self.rf, + self.channels // self.rf, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1)) + + self.bn = nn.BatchNorm2D(self.channels) + self.relu = nn.ReLU() + + def forward(self, x): + x_new = self.conv1x1_reduce_channel(x) + + x_7_c = self.c_layer_7x7(x_new) + x_7_v = self.v_layer_7x1(x_new) + x_7_q = self.q_layer_1x7(x_new) + x_7 = x_7_c + x_7_v + x_7_q + + x_5_c = self.c_layer_5x5(x_7) + x_5_v = self.v_layer_5x1(x_7) + x_5_q = self.q_layer_1x5(x_7) + x_5 = x_5_c + x_5_v + x_5_q + + x_3_c = self.c_layer_3x3(x_5) + x_3_v = self.v_layer_3x1(x_5) + x_3_q = self.q_layer_1x3(x_5) + x_3 = x_3_c + x_3_v + x_3_q + + x_relation = self.conv1x1_return_channel(x_3) + + x_relation = self.bn(x_relation) + x_relation = self.relu(x_relation) + + return x + x_relation + + +def build_intraclblock_list(num_block): + IntraCLBlock_list = nn.LayerList() + for i in range(num_block): + IntraCLBlock_list.append(IntraCLBlock()) + + return IntraCLBlock_list \ No newline at end of file