未验证 提交 ca8c8200 编写于 作者: H huangjun12 提交者: GitHub

add PP-OCRv4 det code (#9766)

* add ppocrv4 det student and teacher model

* update head and config, refine details

* refine config and head details

* refine config and head details

* refine details

* refine details

* remove application

* refine fpn

* fix bug

* update code

* fix bug

* align lcnet to rec

* align hgnet to rec

* refine make shrink

* remove theseus layer
上级 7710ee04
# DB text-detection training config:
# LCNetv3 backbone (det mode) + RSEFPN neck + CBNHeadLocal head ("small" mode).
Global:
  debug: false
  use_gpu: true
  epoch_num: 500
  log_smooth_window: 20
  print_batch_step: 100
  save_model_dir: ./output/ch_PP-OCRv3_mv3_cbnlocal_shrink/
  save_epoch_step: 10
  eval_batch_step:
  - 0
  - 1500
  cal_metric_during_train: false
  checkpoints:
  pretrained_model:
  save_inference_dir: null
  use_visualdl: false
  infer_img: doc/imgs_en/img_10.jpg
  save_res_path: ./checkpoints/det_db/predicts_db.txt
  distributed: true

Architecture:
  model_type: det
  algorithm: DB
  Transform: null
  Backbone:
    name: LCNetv3
    scale: 0.75
    det: True
  Neck:
    name: RSEFPN
    out_channels: 96
    shortcut: True
  Head:
    name: CBNHeadLocal
    k: 50
    mode: "small"

Loss:
  name: DBLoss
  balance_loss: true
  main_loss_type: DiceLoss
  alpha: 5
  beta: 10
  ohem_ratio: 3

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Cosine
    learning_rate: 0.001 #(8*8c)
    warmup_epoch: 2
  regularizer:
    name: L2
    factor: 5.0e-05

PostProcess:
  name: DBPostProcess
  thresh: 0.3
  box_thresh: 0.6
  max_candidates: 1000
  unclip_ratio: 1.5

Metric:
  name: DetMetric
  main_indicator: hmean

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/icdar2015/text_localization/
    label_file_list:
    - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
    ratio_list: [1.0]
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - DetLabelEncode: null
    - CopyPaste: null
    - IaaAugment:
        augmenter_args:
        - type: Fliplr
          args:
            p: 0.5
        - type: Affine
          args:
            rotate:
            - -10
            - 10
        - type: Resize
          args:
            size:
            - 0.5
            - 3
    - EastRandomCropData:
        size:
        - 640
        - 640
        max_tries: 50
        keep_ratio: true
    # total_epoch is read by MakeBorderMap / MakeShrinkMap together with the
    # current epoch to enlarge shrink_ratio as training progresses.
    - MakeBorderMap:
        shrink_ratio: 0.4
        thresh_min: 0.3
        thresh_max: 0.7
        total_epoch: 500
    - MakeShrinkMap:
        shrink_ratio: 0.4
        min_text_size: 8
        total_epoch: 500
    - NormalizeImage:
        scale: 1./255.
        mean:
        - 0.485
        - 0.456
        - 0.406
        std:
        - 0.229
        - 0.224
        - 0.225
        order: hwc
    - ToCHWImage: null
    - KeepKeys:
        keep_keys:
        - image
        - threshold_map
        - threshold_mask
        - shrink_map
        - shrink_mask
  loader:
    shuffle: true
    drop_last: false
    batch_size_per_card: 8
    num_workers: 8

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/icdar2015/text_localization/
    label_file_list:
    - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - DetLabelEncode: null
    - DetResizeForTest:
    - NormalizeImage:
        scale: 1./255.
        mean:
        - 0.485
        - 0.456
        - 0.406
        std:
        - 0.229
        - 0.224
        - 0.225
        order: hwc
    - ToCHWImage: null
    - KeepKeys:
        keep_keys:
        - image
        - shape
        - polys
        - ignore_tags
  loader:
    shuffle: false
    drop_last: false
    batch_size_per_card: 1
    num_workers: 2

profiler_options: null
# DB text-detection training config:
# PPHGNet_small backbone (det mode) + LKPAN neck with IntraCL blocks
# + CBNHeadLocal head ("large" mode).
Global:
  debug: false
  use_gpu: true
  epoch_num: 500
  log_smooth_window: 20
  print_batch_step: 100
  save_model_dir: ./output/ch_PP-OCRv3_mv3_cbnlocal_shrink/
  save_epoch_step: 10
  eval_batch_step:
  - 0
  - 1500
  cal_metric_during_train: false
  checkpoints:
  pretrained_model:
  save_inference_dir: null
  use_visualdl: false
  infer_img: doc/imgs_en/img_10.jpg
  save_res_path: ./checkpoints/det_db/predicts_db.txt
  distributed: true

Architecture:
  model_type: det
  algorithm: DB
  Transform: null
  Backbone:
    name: PPHGNet_small
    det: True
  Neck:
    name: LKPAN
    out_channels: 256
    intracl: true
  Head:
    name: CBNHeadLocal
    k: 50
    mode: "large"

Loss:
  name: DBLoss
  balance_loss: true
  main_loss_type: DiceLoss
  alpha: 5
  beta: 10
  ohem_ratio: 3

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Cosine
    learning_rate: 0.001 #(8*8c)
    warmup_epoch: 2
  regularizer:
    name: L2
    # Written with a decimal point: YAML 1.1 loaders (PyYAML) resolve a bare
    # "1e-6" as a string, not a float.
    factor: 1.0e-06

PostProcess:
  name: DBPostProcess
  thresh: 0.3
  box_thresh: 0.6
  max_candidates: 1000
  unclip_ratio: 1.5

Metric:
  name: DetMetric
  main_indicator: hmean

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/icdar2015/text_localization/
    label_file_list:
    - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
    ratio_list: [1.0]
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - DetLabelEncode: null
    - CopyPaste: null
    - IaaAugment:
        augmenter_args:
        - type: Fliplr
          args:
            p: 0.5
        - type: Affine
          args:
            rotate:
            - -10
            - 10
        - type: Resize
          args:
            size:
            - 0.5
            - 3
    - EastRandomCropData:
        size:
        - 640
        - 640
        max_tries: 50
        keep_ratio: true
    # total_epoch is read by MakeBorderMap / MakeShrinkMap together with the
    # current epoch to enlarge shrink_ratio as training progresses.
    - MakeBorderMap:
        shrink_ratio: 0.4
        thresh_min: 0.3
        thresh_max: 0.7
        total_epoch: 500
    - MakeShrinkMap:
        shrink_ratio: 0.4
        min_text_size: 8
        total_epoch: 500
    - NormalizeImage:
        scale: 1./255.
        mean:
        - 0.485
        - 0.456
        - 0.406
        std:
        - 0.229
        - 0.224
        - 0.225
        order: hwc
    - ToCHWImage: null
    - KeepKeys:
        keep_keys:
        - image
        - threshold_map
        - threshold_mask
        - shrink_map
        - shrink_mask
  loader:
    shuffle: true
    drop_last: false
    batch_size_per_card: 8
    num_workers: 8

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/icdar2015/text_localization/
    label_file_list:
    - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - DetLabelEncode: null
    - DetResizeForTest:
    - NormalizeImage:
        scale: 1./255.
        mean:
        - 0.485
        - 0.456
        - 0.406
        std:
        - 0.229
        - 0.224
        - 0.225
        order: hwc
    - ToCHWImage: null
    - KeepKeys:
        keep_keys:
        - image
        - shape
        - polys
        - ignore_tags
  loader:
    shuffle: false
    drop_last: false
    batch_size_per_card: 1
    num_workers: 2

profiler_options: null
...@@ -44,6 +44,10 @@ class MakeBorderMap(object): ...@@ -44,6 +44,10 @@ class MakeBorderMap(object):
self.shrink_ratio = shrink_ratio self.shrink_ratio = shrink_ratio
self.thresh_min = thresh_min self.thresh_min = thresh_min
self.thresh_max = thresh_max self.thresh_max = thresh_max
if 'total_epoch' in kwargs and 'epoch' in kwargs and kwargs[
'epoch'] != "None":
self.shrink_ratio = self.shrink_ratio + 0.2 * kwargs[
'epoch'] / float(kwargs['total_epoch'])
def __call__(self, data): def __call__(self, data):
......
...@@ -38,6 +38,10 @@ class MakeShrinkMap(object): ...@@ -38,6 +38,10 @@ class MakeShrinkMap(object):
def __init__(self, min_text_size=8, shrink_ratio=0.4, **kwargs): def __init__(self, min_text_size=8, shrink_ratio=0.4, **kwargs):
self.min_text_size = min_text_size self.min_text_size = min_text_size
self.shrink_ratio = shrink_ratio self.shrink_ratio = shrink_ratio
if 'total_epoch' in kwargs and 'epoch' in kwargs and kwargs[
'epoch'] != "None":
self.shrink_ratio = self.shrink_ratio + 0.2 * kwargs[
'epoch'] / float(kwargs['total_epoch'])
def __call__(self, data): def __call__(self, data):
image = data['image'] image = data['image']
......
...@@ -48,11 +48,25 @@ class SimpleDataSet(Dataset): ...@@ -48,11 +48,25 @@ class SimpleDataSet(Dataset):
self.data_idx_order_list = list(range(len(self.data_lines))) self.data_idx_order_list = list(range(len(self.data_lines)))
if self.mode == "train" and self.do_shuffle: if self.mode == "train" and self.do_shuffle:
self.shuffle_data_random() self.shuffle_data_random()
self.set_epoch_as_seed(self.seed)
self.ops = create_operators(dataset_config['transforms'], global_config) self.ops = create_operators(dataset_config['transforms'], global_config)
self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx", self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx",
2) 2)
self.need_reset = True in [x < 1 for x in ratio_list] self.need_reset = True in [x < 1 for x in ratio_list]
def set_epoch_as_seed(self, seed, dataset_config=None):
    """Write the current epoch into the MakeBorderMap / MakeShrinkMap configs.

    Both ops read an ``epoch`` entry from their keyword arguments, so the
    value has to be stored in the transform config before the operators are
    created. Only the training dataset is touched.

    Args:
        seed: Epoch number to record; ``None`` is stored as ``0``.
        dataset_config: Dataset config dict whose ``'transforms'`` list is
            updated in place. Optional for backward compatibility: the
            previous body referenced an undefined ``dataset_config`` name,
            so callers must pass it explicitly for the epoch to be applied.
    """
    # Compare by value: ``is 'train'`` only worked through string interning.
    if self.mode != 'train':
        return
    if dataset_config is None:
        logger.info(
            "set_epoch_as_seed: dataset_config was not provided, "
            "epoch is not propagated to MakeBorderMap/MakeShrinkMap")
        return

    epoch = seed if seed is not None else 0
    # Locate the ops by name instead of relying on fixed list positions.
    for transform in dataset_config.get('transforms') or []:
        if not isinstance(transform, dict):
            continue
        for op_name in ('MakeBorderMap', 'MakeShrinkMap'):
            op_args = transform.get(op_name)
            if isinstance(op_args, dict):
                op_args['epoch'] = epoch
def get_image_info_list(self, file_list, ratio_list): def get_image_info_list(self, file_list, ratio_list):
if isinstance(file_list, str): if isinstance(file_list, str):
file_list = [file_list] file_list = [file_list]
......
...@@ -20,6 +20,7 @@ from __future__ import absolute_import ...@@ -20,6 +20,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import paddle
from paddle import nn from paddle import nn
from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
...@@ -66,11 +67,21 @@ class DBLoss(nn.Layer): ...@@ -66,11 +67,21 @@ class DBLoss(nn.Layer):
label_shrink_mask) label_shrink_mask)
loss_shrink_maps = self.alpha * loss_shrink_maps loss_shrink_maps = self.alpha * loss_shrink_maps
loss_threshold_maps = self.beta * loss_threshold_maps loss_threshold_maps = self.beta * loss_threshold_maps
# CBN loss
if 'distance_maps' in predicts.keys():
distance_maps = predicts['distance_maps']
cbn_maps = predicts['cbn_maps']
cbn_loss = self.bce_loss(cbn_maps[:, 0, :, :], label_shrink_map,
label_shrink_mask)
else:
dis_loss = paddle.to_tensor([0.])
cbn_loss = paddle.to_tensor([0.])
loss_all = loss_shrink_maps + loss_threshold_maps \ loss_all = loss_shrink_maps + loss_threshold_maps \
+ loss_binary_maps + loss_binary_maps
losses = {'loss': loss_all, \ losses = {'loss': loss_all+ cbn_loss, \
"loss_shrink_maps": loss_shrink_maps, \ "loss_shrink_maps": loss_shrink_maps, \
"loss_threshold_maps": loss_threshold_maps, \ "loss_threshold_maps": loss_threshold_maps, \
"loss_binary_maps": loss_binary_maps} "loss_binary_maps": loss_binary_maps, \
"loss_cbn": cbn_loss}
return losses return losses
...@@ -22,8 +22,11 @@ def build_backbone(config, model_type): ...@@ -22,8 +22,11 @@ def build_backbone(config, model_type):
from .det_resnet_vd import ResNet_vd from .det_resnet_vd import ResNet_vd
from .det_resnet_vd_sast import ResNet_SAST from .det_resnet_vd_sast import ResNet_SAST
from .det_pp_lcnet import PPLCNet from .det_pp_lcnet import PPLCNet
from .rec_lcnetv3 import LCNetv3
from .rec_hgnet import PPHGNet_small
support_dict = [ support_dict = [
"MobileNetV3", "ResNet", "ResNet_vd", "ResNet_SAST", "PPLCNet" "MobileNetV3", "ResNet", "ResNet_vd", "ResNet_SAST", "PPLCNet",
"LCNetv3", "PPHGNet_small"
] ]
if model_type == "table": if model_type == "table":
from .table_master_resnet import TableResNetExtra from .table_master_resnet import TableResNetExtra
......
...@@ -188,8 +188,19 @@ class PPHGNet(nn.Layer): ...@@ -188,8 +188,19 @@ class PPHGNet(nn.Layer):
model: nn.Layer. Specific PPHGNet model depends on args. model: nn.Layer. Specific PPHGNet model depends on args.
""" """
def __init__(self, stem_channels, stage_config, layer_num, in_channels=3): def __init__(
self,
stem_channels,
stage_config,
layer_num,
in_channels=3,
det=False,
out_indices=None, ):
super().__init__() super().__init__()
self.det = det
self.out_indices = out_indices if out_indices is not None else [
0, 1, 2, 3
]
# stem # stem
stem_channels.insert(0, in_channels) stem_channels.insert(0, in_channels)
...@@ -202,16 +213,23 @@ class PPHGNet(nn.Layer): ...@@ -202,16 +213,23 @@ class PPHGNet(nn.Layer):
len(stem_channels) - 1) len(stem_channels) - 1)
]) ])
if self.det:
self.pool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
# stages # stages
self.stages = nn.LayerList() self.stages = nn.LayerList()
for k in stage_config: self.out_channels = []
for block_id, k in enumerate(stage_config):
in_channels, mid_channels, out_channels, block_num, downsample, stride = stage_config[ in_channels, mid_channels, out_channels, block_num, downsample, stride = stage_config[
k] k]
self.stages.append( self.stages.append(
HG_Stage(in_channels, mid_channels, out_channels, block_num, HG_Stage(in_channels, mid_channels, out_channels, block_num,
layer_num, downsample, stride)) layer_num, downsample, stride))
if block_id in self.out_indices:
self.out_channels.append(out_channels)
if not self.det:
self.out_channels = stage_config["stage4"][2]
self.out_channels = stage_config["stage4"][2]
self._init_weights() self._init_weights()
def _init_weights(self): def _init_weights(self):
...@@ -226,8 +244,17 @@ class PPHGNet(nn.Layer): ...@@ -226,8 +244,17 @@ class PPHGNet(nn.Layer):
def forward(self, x): def forward(self, x):
x = self.stem(x) x = self.stem(x)
for stage in self.stages: if self.det:
x = self.pool(x)
out = []
for i, stage in enumerate(self.stages):
x = stage(x) x = stage(x)
if self.det and i in self.out_indices:
out.append(x)
if self.det:
return out
if self.training: if self.training:
x = F.adaptive_avg_pool2d(x, [1, 40]) x = F.adaptive_avg_pool2d(x, [1, 40])
else: else:
...@@ -261,7 +288,7 @@ def PPHGNet_tiny(pretrained=False, use_ssld=False, **kwargs): ...@@ -261,7 +288,7 @@ def PPHGNet_tiny(pretrained=False, use_ssld=False, **kwargs):
return model return model
def PPHGNet_small(pretrained=False, use_ssld=False, **kwargs): def PPHGNet_small(pretrained=False, use_ssld=False, det=False, **kwargs):
""" """
PPHGNet_small PPHGNet_small
Args: Args:
...@@ -271,7 +298,15 @@ def PPHGNet_small(pretrained=False, use_ssld=False, **kwargs): ...@@ -271,7 +298,15 @@ def PPHGNet_small(pretrained=False, use_ssld=False, **kwargs):
Returns: Returns:
model: nn.Layer. Specific `PPHGNet_small` model depends on args. model: nn.Layer. Specific `PPHGNet_small` model depends on args.
""" """
stage_config = { stage_config_det = {
# in_channels, mid_channels, out_channels, blocks, downsample
"stage1": [128, 128, 256, 1, False, 2],
"stage2": [256, 160, 512, 1, True, 2],
"stage3": [512, 192, 768, 2, True, 2],
"stage4": [768, 224, 1024, 1, True, 2],
}
stage_config_rec = {
# in_channels, mid_channels, out_channels, blocks, downsample # in_channels, mid_channels, out_channels, blocks, downsample
"stage1": [128, 128, 256, 1, True, [2, 1]], "stage1": [128, 128, 256, 1, True, [2, 1]],
"stage2": [256, 160, 512, 1, True, [1, 2]], "stage2": [256, 160, 512, 1, True, [1, 2]],
...@@ -281,8 +316,9 @@ def PPHGNet_small(pretrained=False, use_ssld=False, **kwargs): ...@@ -281,8 +316,9 @@ def PPHGNet_small(pretrained=False, use_ssld=False, **kwargs):
model = PPHGNet( model = PPHGNet(
stem_channels=[64, 64, 128], stem_channels=[64, 64, 128],
stage_config=stage_config, stage_config=stage_config_det if det else stage_config_rec,
layer_num=6, layer_num=6,
det=det,
**kwargs) **kwargs)
return model return model
......
...@@ -24,7 +24,20 @@ from paddle.nn.initializer import Constant, KaimingNormal ...@@ -24,7 +24,20 @@ from paddle.nn.initializer import Constant, KaimingNormal
from paddle.nn import AdaptiveAvgPool2D, BatchNorm2D, Conv2D, Dropout, Hardsigmoid, Hardswish, Identity, Linear, ReLU from paddle.nn import AdaptiveAvgPool2D, BatchNorm2D, Conv2D, Dropout, Hardsigmoid, Hardswish, Identity, Linear, ReLU
from paddle.regularizer import L2Decay from paddle.regularizer import L2Decay
NET_CONFIG = { NET_CONFIG_det = {
"blocks2":
#k, in_c, out_c, s, use_se
[[3, 16, 32, 1, False]],
"blocks3": [[3, 32, 64, 2, False], [3, 64, 64, 1, False]],
"blocks4": [[3, 64, 128, 2, False], [3, 128, 128, 1, False]],
"blocks5":
[[3, 128, 256, 2, False], [5, 256, 256, 1, False], [5, 256, 256, 1, False],
[5, 256, 256, 1, False], [5, 256, 256, 1, False]],
"blocks6": [[5, 256, 512, 2, True], [5, 512, 512, 1, True],
[5, 512, 512, 1, False], [5, 512, 512, 1, False]]
}
NET_CONFIG_rec = {
"blocks2": "blocks2":
#k, in_c, out_c, s, use_se #k, in_c, out_c, s, use_se
[[3, 16, 32, 1, False]], [[3, 16, 32, 1, False]],
...@@ -335,11 +348,14 @@ class PPLCNetV3(nn.Layer): ...@@ -335,11 +348,14 @@ class PPLCNetV3(nn.Layer):
conv_kxk_num=4, conv_kxk_num=4,
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0], lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
lab_lr=0.1, lab_lr=0.1,
det=False,
**kwargs): **kwargs):
super().__init__() super().__init__()
self.scale = scale self.scale = scale
self.lr_mult_list = lr_mult_list self.lr_mult_list = lr_mult_list
self.net_config = NET_CONFIG self.det = det
self.net_config = NET_CONFIG_det if self.det else NET_CONFIG_rec
assert isinstance(self.lr_mult_list, ( assert isinstance(self.lr_mult_list, (
list, tuple list, tuple
...@@ -365,8 +381,9 @@ class PPLCNetV3(nn.Layer): ...@@ -365,8 +381,9 @@ class PPLCNetV3(nn.Layer):
use_se=se, use_se=se,
conv_kxk_num=conv_kxk_num, conv_kxk_num=conv_kxk_num,
lr_mult=self.lr_mult_list[1], lr_mult=self.lr_mult_list[1],
lab_lr=lab_lr) for i, (k, in_c, out_c, s, se) in enumerate( lab_lr=lab_lr)
self.net_config["blocks2"]) for i, (k, in_c, out_c, s, se) in enumerate(self.net_config[
"blocks2"])
]) ])
self.blocks3 = nn.Sequential(* [ self.blocks3 = nn.Sequential(* [
...@@ -378,8 +395,9 @@ class PPLCNetV3(nn.Layer): ...@@ -378,8 +395,9 @@ class PPLCNetV3(nn.Layer):
use_se=se, use_se=se,
conv_kxk_num=conv_kxk_num, conv_kxk_num=conv_kxk_num,
lr_mult=self.lr_mult_list[2], lr_mult=self.lr_mult_list[2],
lab_lr=lab_lr) for i, (k, in_c, out_c, s, se) in enumerate( lab_lr=lab_lr)
self.net_config["blocks3"]) for i, (k, in_c, out_c, s, se) in enumerate(self.net_config[
"blocks3"])
]) ])
self.blocks4 = nn.Sequential(* [ self.blocks4 = nn.Sequential(* [
...@@ -391,8 +409,9 @@ class PPLCNetV3(nn.Layer): ...@@ -391,8 +409,9 @@ class PPLCNetV3(nn.Layer):
use_se=se, use_se=se,
conv_kxk_num=conv_kxk_num, conv_kxk_num=conv_kxk_num,
lr_mult=self.lr_mult_list[3], lr_mult=self.lr_mult_list[3],
lab_lr=lab_lr) for i, (k, in_c, out_c, s, se) in enumerate( lab_lr=lab_lr)
self.net_config["blocks4"]) for i, (k, in_c, out_c, s, se) in enumerate(self.net_config[
"blocks4"])
]) ])
self.blocks5 = nn.Sequential(* [ self.blocks5 = nn.Sequential(* [
...@@ -404,8 +423,9 @@ class PPLCNetV3(nn.Layer): ...@@ -404,8 +423,9 @@ class PPLCNetV3(nn.Layer):
use_se=se, use_se=se,
conv_kxk_num=conv_kxk_num, conv_kxk_num=conv_kxk_num,
lr_mult=self.lr_mult_list[4], lr_mult=self.lr_mult_list[4],
lab_lr=lab_lr) for i, (k, in_c, out_c, s, se) in enumerate( lab_lr=lab_lr)
self.net_config["blocks5"]) for i, (k, in_c, out_c, s, se) in enumerate(self.net_config[
"blocks5"])
]) ])
self.blocks6 = nn.Sequential(* [ self.blocks6 = nn.Sequential(* [
...@@ -417,19 +437,52 @@ class PPLCNetV3(nn.Layer): ...@@ -417,19 +437,52 @@ class PPLCNetV3(nn.Layer):
use_se=se, use_se=se,
conv_kxk_num=conv_kxk_num, conv_kxk_num=conv_kxk_num,
lr_mult=self.lr_mult_list[5], lr_mult=self.lr_mult_list[5],
lab_lr=lab_lr) for i, (k, in_c, out_c, s, se) in enumerate( lab_lr=lab_lr)
self.net_config["blocks6"]) for i, (k, in_c, out_c, s, se) in enumerate(self.net_config[
"blocks6"])
]) ])
self.out_channels = make_divisible(512 * scale) self.out_channels = make_divisible(512 * scale)
if self.det:
mv_c = [16, 24, 56, 480]
self.out_channels = [
make_divisible(self.net_config["blocks3"][-1][2] * scale),
make_divisible(self.net_config["blocks4"][-1][2] * scale),
make_divisible(self.net_config["blocks5"][-1][2] * scale),
make_divisible(self.net_config["blocks6"][-1][2] * scale),
]
self.layer_list = nn.LayerList([
nn.Conv2D(self.out_channels[0], int(mv_c[0] * scale), 1, 1, 0),
nn.Conv2D(self.out_channels[1], int(mv_c[1] * scale), 1, 1, 0),
nn.Conv2D(self.out_channels[2], int(mv_c[2] * scale), 1, 1, 0),
nn.Conv2D(self.out_channels[3], int(mv_c[3] * scale), 1, 1, 0)
])
self.out_channels = [
int(mv_c[0] * scale), int(mv_c[1] * scale),
int(mv_c[2] * scale), int(mv_c[3] * scale)
]
def forward(self, x): def forward(self, x):
out_list = []
x = self.conv1(x) x = self.conv1(x)
x = self.blocks2(x) x = self.blocks2(x)
x = self.blocks3(x) x = self.blocks3(x)
out_list.append(x)
x = self.blocks4(x) x = self.blocks4(x)
out_list.append(x)
x = self.blocks5(x) x = self.blocks5(x)
out_list.append(x)
x = self.blocks6(x) x = self.blocks6(x)
out_list.append(x)
if self.det:
out_list[0] = self.layer_list[0](out_list[0])
out_list[1] = self.layer_list[1](out_list[1])
out_list[2] = self.layer_list[2](out_list[2])
out_list[3] = self.layer_list[3](out_list[3])
return out_list
if self.training: if self.training:
x = F.adaptive_avg_pool2d(x, [1, 40]) x = F.adaptive_avg_pool2d(x, [1, 40])
...@@ -438,6 +491,6 @@ class PPLCNetV3(nn.Layer): ...@@ -438,6 +491,6 @@ class PPLCNetV3(nn.Layer):
return x return x
def LCNetv3(pretrained=False, use_ssld=False, **kwargs): def LCNetv3(scale=0.95, **kwargs):
model = PPLCNetV3(scale=0.95, conv_kxk_num=4, **kwargs) model = PPLCNetV3(scale=scale, conv_kxk_num=4, **kwargs)
return model return model
...@@ -17,14 +17,13 @@ __all__ = ['build_head'] ...@@ -17,14 +17,13 @@ __all__ = ['build_head']
def build_head(config): def build_head(config):
# det head # det head
from .det_db_head import DBHead from .det_db_head import DBHead, CBNHeadLocal
from .det_east_head import EASTHead from .det_east_head import EASTHead
from .det_sast_head import SASTHead from .det_sast_head import SASTHead
from .det_pse_head import PSEHead from .det_pse_head import PSEHead
from .det_fce_head import FCEHead from .det_fce_head import FCEHead
from .e2e_pg_head import PGHead from .e2e_pg_head import PGHead
from .det_ct_head import CT_Head from .det_ct_head import CT_Head
# rec head # rec head
from .rec_ctc_head import CTCHead from .rec_ctc_head import CTCHead
from .rec_att_head import AttentionHead from .rec_att_head import AttentionHead
...@@ -57,7 +56,7 @@ def build_head(config): ...@@ -57,7 +56,7 @@ def build_head(config):
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead', 'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead', 'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head', 'RFLHead', 'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head', 'RFLHead',
'DRRGHead', 'CANHead', 'SATRNHead' 'DRRGHead', 'CANHead', 'SATRNHead', 'CBNHeadLocal'
] ]
if config['name'] == 'DRRGHead': if config['name'] == 'DRRGHead':
......
...@@ -21,6 +21,7 @@ import paddle ...@@ -21,6 +21,7 @@ import paddle
from paddle import nn from paddle import nn
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle import ParamAttr from paddle import ParamAttr
from ppocr.modeling.backbones.det_mobilenet_v3 import ConvBNLayer
def get_bias_attr(k): def get_bias_attr(k):
...@@ -48,6 +49,7 @@ class Head(nn.Layer): ...@@ -48,6 +49,7 @@ class Head(nn.Layer):
bias_attr=ParamAttr( bias_attr=ParamAttr(
initializer=paddle.nn.initializer.Constant(value=1e-4)), initializer=paddle.nn.initializer.Constant(value=1e-4)),
act='relu') act='relu')
self.conv2 = nn.Conv2DTranspose( self.conv2 = nn.Conv2DTranspose(
in_channels=in_channels // 4, in_channels=in_channels // 4,
out_channels=in_channels // 4, out_channels=in_channels // 4,
...@@ -72,13 +74,17 @@ class Head(nn.Layer): ...@@ -72,13 +74,17 @@ class Head(nn.Layer):
initializer=paddle.nn.initializer.KaimingUniform()), initializer=paddle.nn.initializer.KaimingUniform()),
bias_attr=get_bias_attr(in_channels // 4), ) bias_attr=get_bias_attr(in_channels // 4), )
def forward(self, x): def forward(self, x, return_f=False):
x = self.conv1(x) x = self.conv1(x)
x = self.conv_bn1(x) x = self.conv_bn1(x)
x = self.conv2(x) x = self.conv2(x)
x = self.conv_bn2(x) x = self.conv_bn2(x)
if return_f is True:
f = x
x = self.conv3(x) x = self.conv3(x)
x = F.sigmoid(x) x = F.sigmoid(x)
if return_f is True:
return x, f
return x return x
...@@ -108,3 +114,41 @@ class DBHead(nn.Layer): ...@@ -108,3 +114,41 @@ class DBHead(nn.Layer):
binary_maps = self.step_function(shrink_maps, threshold_maps) binary_maps = self.step_function(shrink_maps, threshold_maps)
y = paddle.concat([shrink_maps, threshold_maps, binary_maps], axis=1) y = paddle.concat([shrink_maps, threshold_maps, binary_maps], axis=1)
return {'maps': y} return {'maps': y}
class LocalModule(nn.Layer):
    """Small refinement head: concat a map with features, then 3x3 ConvBN + 1x1 conv.

    Args:
        in_c: Channel count of the feature input ``x``.
        mid_c: Channel count between the 3x3 ConvBN layer and the 1x1 conv.
        use_distance: Accepted for interface compatibility; not used.
    """

    def __init__(self, in_c, mid_c, use_distance=True):
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as this class is subclassed.
        super(LocalModule, self).__init__()
        # in_c + 1: channels of x plus those of init_map, which is
        # concatenated in forward (presumably a single-channel map).
        self.last_3 = ConvBNLayer(in_c + 1, mid_c, 3, 1, 1, act='relu')
        self.last_1 = nn.Conv2D(mid_c, 1, 1, 1, 0)

    def forward(self, x, init_map, distance_map):
        """Return the one-channel output (no activation applied here).

        Args:
            x: Feature tensor.
            init_map: Map concatenated in front of ``x`` along axis 1.
            distance_map: Accepted for interface compatibility; not used.
        """
        outf = paddle.concat([init_map, x], axis=1)
        # last Conv
        out = self.last_1(self.last_3(outf))
        return out
class CBNHeadLocal(DBHead):
    """DBHead variant that refines the shrink map with a LocalModule.

    Args:
        in_channels: Channel count of the input feature, forwarded to DBHead.
        k: Forwarded to DBHead.
        mode: ``'large'`` or ``'small'``; selects the LocalModule middle
            width (``in_channels // 4`` or ``in_channels // 8``).
        **kwargs: Forwarded to DBHead.

    Raises:
        ValueError: If ``mode`` is neither ``'large'`` nor ``'small'``.
    """

    def __init__(self, in_channels, k=50, mode='small', **kwargs):
        super(CBNHeadLocal, self).__init__(in_channels, k, **kwargs)
        self.mode = mode

        self.up_conv = nn.Upsample(scale_factor=2, mode="nearest", align_mode=1)

        if self.mode == 'large':
            self.cbn_layer = LocalModule(in_channels // 4, in_channels // 4)
        elif self.mode == 'small':
            self.cbn_layer = LocalModule(in_channels // 4, in_channels // 8)
        else:
            # Fail at construction instead of leaving cbn_layer undefined
            # and crashing with AttributeError on the first forward pass.
            raise ValueError(
                "CBNHeadLocal mode must be 'large' or 'small', got {!r}".
                format(mode))

    def forward(self, x, targets=None):
        """Compute the detection maps.

        In eval mode returns ``'maps'`` (mean of the base shrink map and the
        refined map) and ``'cbn_maps'``. In training mode returns ``'maps'``
        (refined, threshold and binary maps concatenated on axis 1) plus
        ``'distance_maps'`` and ``'cbn_maps'``. ``targets`` is not used.
        """
        shrink_maps, f = self.binarize(x, return_f=True)
        base_maps = shrink_maps
        # Refine the shrink map from the upsampled intermediate feature.
        cbn_maps = self.cbn_layer(self.up_conv(f), shrink_maps, None)
        cbn_maps = F.sigmoid(cbn_maps)
        if not self.training:
            return {'maps': 0.5 * (base_maps + cbn_maps), 'cbn_maps': cbn_maps}

        threshold_maps = self.thresh(x)
        binary_maps = self.step_function(shrink_maps, threshold_maps)
        y = paddle.concat([cbn_maps, threshold_maps, binary_maps], axis=1)
        # NOTE(review): in training, 'distance_maps' carries the refined map
        # and 'cbn_maps' carries the binary map; confirm this key naming is
        # what the loss expects.
        return {'maps': y, 'distance_maps': cbn_maps, 'cbn_maps': binary_maps}
...@@ -22,6 +22,7 @@ import paddle.nn.functional as F ...@@ -22,6 +22,7 @@ import paddle.nn.functional as F
from paddle import ParamAttr from paddle import ParamAttr
import os import os
import sys import sys
from ppocr.modeling.necks.intracl import IntraCLBlock
__dir__ = os.path.dirname(os.path.abspath(__file__)) __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
...@@ -228,6 +229,13 @@ class RSEFPN(nn.Layer): ...@@ -228,6 +229,13 @@ class RSEFPN(nn.Layer):
self.out_channels = out_channels self.out_channels = out_channels
self.ins_conv = nn.LayerList() self.ins_conv = nn.LayerList()
self.inp_conv = nn.LayerList() self.inp_conv = nn.LayerList()
self.intracl = False
if 'intracl' in kwargs.keys() and kwargs['intracl'] is True:
self.intracl = kwargs['intracl']
self.incl1 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
self.incl2 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
self.incl3 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
self.incl4 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
for i in range(len(in_channels)): for i in range(len(in_channels)):
self.ins_conv.append( self.ins_conv.append(
...@@ -263,6 +271,12 @@ class RSEFPN(nn.Layer): ...@@ -263,6 +271,12 @@ class RSEFPN(nn.Layer):
p3 = self.inp_conv[1](out3) p3 = self.inp_conv[1](out3)
p2 = self.inp_conv[0](out2) p2 = self.inp_conv[0](out2)
if self.intracl is True:
p5 = self.incl4(p5)
p4 = self.incl3(p4)
p3 = self.incl2(p3)
p2 = self.incl1(p2)
p5 = F.upsample(p5, scale_factor=8, mode="nearest", align_mode=1) p5 = F.upsample(p5, scale_factor=8, mode="nearest", align_mode=1)
p4 = F.upsample(p4, scale_factor=4, mode="nearest", align_mode=1) p4 = F.upsample(p4, scale_factor=4, mode="nearest", align_mode=1)
p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1) p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1)
...@@ -329,6 +343,14 @@ class LKPAN(nn.Layer): ...@@ -329,6 +343,14 @@ class LKPAN(nn.Layer):
weight_attr=ParamAttr(initializer=weight_attr), weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)) bias_attr=False))
self.intracl = False
if 'intracl' in kwargs.keys() and kwargs['intracl'] is True:
self.intracl = kwargs['intracl']
self.incl1 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
self.incl2 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
self.incl3 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
self.incl4 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
def forward(self, x): def forward(self, x):
c2, c3, c4, c5 = x c2, c3, c4, c5 = x
...@@ -358,6 +380,12 @@ class LKPAN(nn.Layer): ...@@ -358,6 +380,12 @@ class LKPAN(nn.Layer):
p4 = self.pan_lat_conv[2](pan4) p4 = self.pan_lat_conv[2](pan4)
p5 = self.pan_lat_conv[3](pan5) p5 = self.pan_lat_conv[3](pan5)
if self.intracl is True:
p5 = self.incl4(p5)
p4 = self.incl3(p4)
p3 = self.incl2(p3)
p2 = self.incl1(p2)
p5 = F.upsample(p5, scale_factor=8, mode="nearest", align_mode=1) p5 = F.upsample(p5, scale_factor=8, mode="nearest", align_mode=1)
p4 = F.upsample(p4, scale_factor=4, mode="nearest", align_mode=1) p4 = F.upsample(p4, scale_factor=4, mode="nearest", align_mode=1)
p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1) p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1)
...@@ -424,4 +452,4 @@ class ASFBlock(nn.Layer): ...@@ -424,4 +452,4 @@ class ASFBlock(nn.Layer):
out_list = [] out_list = []
for i in range(self.out_features_num): for i in range(self.out_features_num):
out_list.append(attention_scores[:, i:i + 1] * features_list[i]) out_list.append(attention_scores[:, i:i + 1] * features_list[i])
return paddle.concat(out_list, axis=1) return paddle.concat(out_list, axis=1)
\ No newline at end of file
import paddle
from paddle import nn
# refer from: https://github.com/ViTAE-Transformer/I3CL/blob/736c80237f66d352d488e83b05f3e33c55201317/mmdet/models/detectors/intra_cl_module.py
class IntraCLBlock(nn.Layer):
    """Residual block with three cascaded multi-branch convolution stages.

    The input is reduced to ``in_channels // reduce_factor`` channels, passed
    through a 7-, 5- and 3-sized stage (each summing a square, a vertical and
    a horizontal convolution), restored to ``in_channels`` channels, passed
    through BatchNorm + ReLU and added back to the input.

    Args:
        in_channels: Channel count of the input (and of the output).
        reduce_factor: Divisor applied to ``in_channels`` inside the block.
    """

    def __init__(self, in_channels=96, reduce_factor=4):
        super(IntraCLBlock, self).__init__()
        self.channels = in_channels
        self.rf = reduce_factor
        reduced = self.channels // self.rf

        def same_conv(kernel_size):
            # Stride-1 conv on the reduced channels; k // 2 padding keeps the
            # spatial size unchanged for the odd kernel sizes used here.
            return nn.Conv2D(
                reduced,
                reduced,
                kernel_size=kernel_size,
                stride=(1, 1),
                padding=tuple(k // 2 for k in kernel_size))

        # Layers are created in the original order so that sublayer
        # registration (and seeded weight initialisation) is unchanged.
        self.conv1x1_reduce_channel = nn.Conv2D(
            self.channels, reduced, kernel_size=1, stride=1, padding=0)
        self.conv1x1_return_channel = nn.Conv2D(
            reduced, self.channels, kernel_size=1, stride=1, padding=0)

        # vertical (k x 1) branches
        self.v_layer_7x1 = same_conv((7, 1))
        self.v_layer_5x1 = same_conv((5, 1))
        self.v_layer_3x1 = same_conv((3, 1))

        # horizontal (1 x k) branches
        self.q_layer_1x7 = same_conv((1, 7))
        self.q_layer_1x5 = same_conv((1, 5))
        self.q_layer_1x3 = same_conv((1, 3))

        # base (k x k) branches
        self.c_layer_7x7 = same_conv((7, 7))
        self.c_layer_5x5 = same_conv((5, 5))
        self.c_layer_3x3 = same_conv((3, 3))

        self.bn = nn.BatchNorm2D(self.channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``x`` plus the learned relation feature."""
        x_new = self.conv1x1_reduce_channel(x)

        # 7-sized stage
        x_7_c = self.c_layer_7x7(x_new)
        x_7_v = self.v_layer_7x1(x_new)
        x_7_q = self.q_layer_1x7(x_new)
        x_7 = x_7_c + x_7_v + x_7_q

        # 5-sized stage, fed by the 7-sized stage
        x_5_c = self.c_layer_5x5(x_7)
        x_5_v = self.v_layer_5x1(x_7)
        x_5_q = self.q_layer_1x5(x_7)
        x_5 = x_5_c + x_5_v + x_5_q

        # 3-sized stage, fed by the 5-sized stage
        x_3_c = self.c_layer_3x3(x_5)
        x_3_v = self.v_layer_3x1(x_5)
        x_3_q = self.q_layer_1x3(x_5)
        x_3 = x_3_c + x_3_v + x_3_q

        x_relation = self.conv1x1_return_channel(x_3)
        x_relation = self.bn(x_relation)
        x_relation = self.relu(x_relation)

        # residual connection
        return x + x_relation
def build_intraclblock_list(num_block):
    """Return an ``nn.LayerList`` of ``num_block`` default-configured IntraCLBlock layers."""
    blocks = [IntraCLBlock() for _ in range(num_block)]
    return nn.LayerList(blocks)
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册