未验证 提交 ca8c8200 编写于 作者: H huangjun12 提交者: GitHub

add PP-OCRv4 det code (#9766)

* add ppocrv4 det student and teacher model

* update head and config, refine details

* refine config and head details

* refine config and head details

* refine details

* refine details

* remove application

* refine fpn

* fix bug

* update code

* fix bug

* align lcnet to rec

* align hgnet to rec

* refine make shrink

* remove theseus layer
上级 7710ee04
Global:
debug: false
use_gpu: true
epoch_num: 500
log_smooth_window: 20
print_batch_step: 100
save_model_dir: ./output/ch_PP-OCRv3_mv3_cbnlocal_shrink/
save_epoch_step: 10
eval_batch_step:
- 0
- 1500
cal_metric_during_train: false
checkpoints:
pretrained_model:
save_inference_dir: null
use_visualdl: false
infer_img: doc/imgs_en/img_10.jpg
save_res_path: ./checkpoints/det_db/predicts_db.txt
distributed: true
Architecture:
model_type: det
algorithm: DB
Transform: null
Backbone:
name: LCNetv3
scale: 0.75
det: True
Neck:
name: RSEFPN
out_channels: 96
shortcut: True
Head:
name: CBNHeadLocal
k: 50
mode: "small"
Loss:
name: DBLoss
balance_loss: true
main_loss_type: DiceLoss
alpha: 5
beta: 10
ohem_ratio: 3
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Cosine
learning_rate: 0.001 #(8*8c)
warmup_epoch: 2
regularizer:
name: L2
factor: 5.0e-05
PostProcess:
name: DBPostProcess
thresh: 0.3
box_thresh: 0.6
max_candidates: 1000
unclip_ratio: 1.5
Metric:
name: DetMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
ratio_list: [1.0]
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- DetLabelEncode: null
- CopyPaste: null
- IaaAugment:
augmenter_args:
- type: Fliplr
args:
p: 0.5
- type: Affine
args:
rotate:
- -10
- 10
- type: Resize
args:
size:
- 0.5
- 3
- EastRandomCropData:
size:
- 640
- 640
max_tries: 50
keep_ratio: true
- MakeBorderMap:
shrink_ratio: 0.4
thresh_min: 0.3
thresh_max: 0.7
total_epoch: 500
- MakeShrinkMap:
shrink_ratio: 0.4
min_text_size: 8
total_epoch: 500
- NormalizeImage:
scale: 1./255.
mean:
- 0.485
- 0.456
- 0.406
std:
- 0.229
- 0.224
- 0.225
order: hwc
- ToCHWImage: null
- KeepKeys:
keep_keys:
- image
- threshold_map
- threshold_mask
- shrink_map
- shrink_mask
loader:
shuffle: true
drop_last: false
batch_size_per_card: 8
num_workers: 8
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- DetLabelEncode: null
- DetResizeForTest:
- NormalizeImage:
scale: 1./255.
mean:
- 0.485
- 0.456
- 0.406
std:
- 0.229
- 0.224
- 0.225
order: hwc
- ToCHWImage: null
- KeepKeys:
keep_keys:
- image
- shape
- polys
- ignore_tags
loader:
shuffle: false
drop_last: false
batch_size_per_card: 1
num_workers: 2
profiler_options: null
Global:
debug: false
use_gpu: true
epoch_num: 500
log_smooth_window: 20
print_batch_step: 100
save_model_dir: ./output/ch_PP-OCRv3_mv3_cbnlocal_shrink/
save_epoch_step: 10
eval_batch_step:
- 0
- 1500
cal_metric_during_train: false
checkpoints:
pretrained_model:
save_inference_dir: null
use_visualdl: false
infer_img: doc/imgs_en/img_10.jpg
save_res_path: ./checkpoints/det_db/predicts_db.txt
distributed: true
Architecture:
model_type: det
algorithm: DB
Transform: null
Backbone:
name: PPHGNet_small
det: True
Neck:
name: LKPAN
out_channels: 256
intracl: true
Head:
name: CBNHeadLocal
k: 50
mode: "large"
Loss:
name: DBLoss
balance_loss: true
main_loss_type: DiceLoss
alpha: 5
beta: 10
ohem_ratio: 3
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Cosine
learning_rate: 0.001 #(8*8c)
warmup_epoch: 2
regularizer:
name: L2
factor: 1e-6
PostProcess:
name: DBPostProcess
thresh: 0.3
box_thresh: 0.6
max_candidates: 1000
unclip_ratio: 1.5
Metric:
name: DetMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
ratio_list: [1.0]
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- DetLabelEncode: null
- CopyPaste: null
- IaaAugment:
augmenter_args:
- type: Fliplr
args:
p: 0.5
- type: Affine
args:
rotate:
- -10
- 10
- type: Resize
args:
size:
- 0.5
- 3
- EastRandomCropData:
size:
- 640
- 640
max_tries: 50
keep_ratio: true
- MakeBorderMap:
shrink_ratio: 0.4
thresh_min: 0.3
thresh_max: 0.7
total_epoch: 500
- MakeShrinkMap:
shrink_ratio: 0.4
min_text_size: 8
total_epoch: 500
- NormalizeImage:
scale: 1./255.
mean:
- 0.485
- 0.456
- 0.406
std:
- 0.229
- 0.224
- 0.225
order: hwc
- ToCHWImage: null
- KeepKeys:
keep_keys:
- image
- threshold_map
- threshold_mask
- shrink_map
- shrink_mask
loader:
shuffle: true
drop_last: false
batch_size_per_card: 8
num_workers: 8
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- DetLabelEncode: null
- DetResizeForTest:
- NormalizeImage:
scale: 1./255.
mean:
- 0.485
- 0.456
- 0.406
std:
- 0.229
- 0.224
- 0.225
order: hwc
- ToCHWImage: null
- KeepKeys:
keep_keys:
- image
- shape
- polys
- ignore_tags
loader:
shuffle: false
drop_last: false
batch_size_per_card: 1
num_workers: 2
profiler_options: null
......@@ -44,6 +44,10 @@ class MakeBorderMap(object):
self.shrink_ratio = shrink_ratio
self.thresh_min = thresh_min
self.thresh_max = thresh_max
if 'total_epoch' in kwargs and 'epoch' in kwargs and kwargs[
'epoch'] != "None":
self.shrink_ratio = self.shrink_ratio + 0.2 * kwargs[
'epoch'] / float(kwargs['total_epoch'])
def __call__(self, data):
......
......@@ -38,6 +38,10 @@ class MakeShrinkMap(object):
def __init__(self, min_text_size=8, shrink_ratio=0.4, **kwargs):
self.min_text_size = min_text_size
self.shrink_ratio = shrink_ratio
if 'total_epoch' in kwargs and 'epoch' in kwargs and kwargs[
'epoch'] != "None":
self.shrink_ratio = self.shrink_ratio + 0.2 * kwargs[
'epoch'] / float(kwargs['total_epoch'])
def __call__(self, data):
image = data['image']
......
......@@ -48,11 +48,25 @@ class SimpleDataSet(Dataset):
self.data_idx_order_list = list(range(len(self.data_lines)))
if self.mode == "train" and self.do_shuffle:
self.shuffle_data_random()
self.set_epoch_as_seed(self.seed)
self.ops = create_operators(dataset_config['transforms'], global_config)
self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx",
2)
self.need_reset = True in [x < 1 for x in ratio_list]
def set_epoch_as_seed(self, seed):
if self.mode is 'train':
try:
dataset_config['transforms'][5]['MakeBorderMap'][
'epoch'] = seed if seed is not None else 0
dataset_config['transforms'][6]['MakeShrinkMap'][
'epoch'] = seed if seed is not None else 0
except Exception as E:
logger.info(E)
return
def get_image_info_list(self, file_list, ratio_list):
if isinstance(file_list, str):
file_list = [file_list]
......
......@@ -20,6 +20,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
......@@ -66,11 +67,21 @@ class DBLoss(nn.Layer):
label_shrink_mask)
loss_shrink_maps = self.alpha * loss_shrink_maps
loss_threshold_maps = self.beta * loss_threshold_maps
# CBN loss
if 'distance_maps' in predicts.keys():
distance_maps = predicts['distance_maps']
cbn_maps = predicts['cbn_maps']
cbn_loss = self.bce_loss(cbn_maps[:, 0, :, :], label_shrink_map,
label_shrink_mask)
else:
dis_loss = paddle.to_tensor([0.])
cbn_loss = paddle.to_tensor([0.])
loss_all = loss_shrink_maps + loss_threshold_maps \
+ loss_binary_maps
losses = {'loss': loss_all, \
losses = {'loss': loss_all+ cbn_loss, \
"loss_shrink_maps": loss_shrink_maps, \
"loss_threshold_maps": loss_threshold_maps, \
"loss_binary_maps": loss_binary_maps}
"loss_binary_maps": loss_binary_maps, \
"loss_cbn": cbn_loss}
return losses
......@@ -22,8 +22,11 @@ def build_backbone(config, model_type):
from .det_resnet_vd import ResNet_vd
from .det_resnet_vd_sast import ResNet_SAST
from .det_pp_lcnet import PPLCNet
from .rec_lcnetv3 import LCNetv3
from .rec_hgnet import PPHGNet_small
support_dict = [
"MobileNetV3", "ResNet", "ResNet_vd", "ResNet_SAST", "PPLCNet"
"MobileNetV3", "ResNet", "ResNet_vd", "ResNet_SAST", "PPLCNet",
"LCNetv3", "PPHGNet_small"
]
if model_type == "table":
from .table_master_resnet import TableResNetExtra
......
......@@ -188,8 +188,19 @@ class PPHGNet(nn.Layer):
model: nn.Layer. Specific PPHGNet model depends on args.
"""
def __init__(self, stem_channels, stage_config, layer_num, in_channels=3):
def __init__(
self,
stem_channels,
stage_config,
layer_num,
in_channels=3,
det=False,
out_indices=None, ):
super().__init__()
self.det = det
self.out_indices = out_indices if out_indices is not None else [
0, 1, 2, 3
]
# stem
stem_channels.insert(0, in_channels)
......@@ -202,16 +213,23 @@ class PPHGNet(nn.Layer):
len(stem_channels) - 1)
])
if self.det:
self.pool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
# stages
self.stages = nn.LayerList()
for k in stage_config:
self.out_channels = []
for block_id, k in enumerate(stage_config):
in_channels, mid_channels, out_channels, block_num, downsample, stride = stage_config[
k]
self.stages.append(
HG_Stage(in_channels, mid_channels, out_channels, block_num,
layer_num, downsample, stride))
if block_id in self.out_indices:
self.out_channels.append(out_channels)
if not self.det:
self.out_channels = stage_config["stage4"][2]
self.out_channels = stage_config["stage4"][2]
self._init_weights()
def _init_weights(self):
......@@ -226,8 +244,17 @@ class PPHGNet(nn.Layer):
def forward(self, x):
x = self.stem(x)
for stage in self.stages:
if self.det:
x = self.pool(x)
out = []
for i, stage in enumerate(self.stages):
x = stage(x)
if self.det and i in self.out_indices:
out.append(x)
if self.det:
return out
if self.training:
x = F.adaptive_avg_pool2d(x, [1, 40])
else:
......@@ -261,7 +288,7 @@ def PPHGNet_tiny(pretrained=False, use_ssld=False, **kwargs):
return model
def PPHGNet_small(pretrained=False, use_ssld=False, **kwargs):
def PPHGNet_small(pretrained=False, use_ssld=False, det=False, **kwargs):
"""
PPHGNet_small
Args:
......@@ -271,7 +298,15 @@ def PPHGNet_small(pretrained=False, use_ssld=False, **kwargs):
Returns:
model: nn.Layer. Specific `PPHGNet_small` model depends on args.
"""
stage_config = {
stage_config_det = {
# in_channels, mid_channels, out_channels, blocks, downsample
"stage1": [128, 128, 256, 1, False, 2],
"stage2": [256, 160, 512, 1, True, 2],
"stage3": [512, 192, 768, 2, True, 2],
"stage4": [768, 224, 1024, 1, True, 2],
}
stage_config_rec = {
# in_channels, mid_channels, out_channels, blocks, downsample
"stage1": [128, 128, 256, 1, True, [2, 1]],
"stage2": [256, 160, 512, 1, True, [1, 2]],
......@@ -281,8 +316,9 @@ def PPHGNet_small(pretrained=False, use_ssld=False, **kwargs):
model = PPHGNet(
stem_channels=[64, 64, 128],
stage_config=stage_config,
stage_config=stage_config_det if det else stage_config_rec,
layer_num=6,
det=det,
**kwargs)
return model
......
......@@ -24,7 +24,20 @@ from paddle.nn.initializer import Constant, KaimingNormal
from paddle.nn import AdaptiveAvgPool2D, BatchNorm2D, Conv2D, Dropout, Hardsigmoid, Hardswish, Identity, Linear, ReLU
from paddle.regularizer import L2Decay
NET_CONFIG = {
NET_CONFIG_det = {
"blocks2":
#k, in_c, out_c, s, use_se
[[3, 16, 32, 1, False]],
"blocks3": [[3, 32, 64, 2, False], [3, 64, 64, 1, False]],
"blocks4": [[3, 64, 128, 2, False], [3, 128, 128, 1, False]],
"blocks5":
[[3, 128, 256, 2, False], [5, 256, 256, 1, False], [5, 256, 256, 1, False],
[5, 256, 256, 1, False], [5, 256, 256, 1, False]],
"blocks6": [[5, 256, 512, 2, True], [5, 512, 512, 1, True],
[5, 512, 512, 1, False], [5, 512, 512, 1, False]]
}
NET_CONFIG_rec = {
"blocks2":
#k, in_c, out_c, s, use_se
[[3, 16, 32, 1, False]],
......@@ -335,11 +348,14 @@ class PPLCNetV3(nn.Layer):
conv_kxk_num=4,
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
lab_lr=0.1,
det=False,
**kwargs):
super().__init__()
self.scale = scale
self.lr_mult_list = lr_mult_list
self.net_config = NET_CONFIG
self.det = det
self.net_config = NET_CONFIG_det if self.det else NET_CONFIG_rec
assert isinstance(self.lr_mult_list, (
list, tuple
......@@ -365,8 +381,9 @@ class PPLCNetV3(nn.Layer):
use_se=se,
conv_kxk_num=conv_kxk_num,
lr_mult=self.lr_mult_list[1],
lab_lr=lab_lr) for i, (k, in_c, out_c, s, se) in enumerate(
self.net_config["blocks2"])
lab_lr=lab_lr)
for i, (k, in_c, out_c, s, se) in enumerate(self.net_config[
"blocks2"])
])
self.blocks3 = nn.Sequential(* [
......@@ -378,8 +395,9 @@ class PPLCNetV3(nn.Layer):
use_se=se,
conv_kxk_num=conv_kxk_num,
lr_mult=self.lr_mult_list[2],
lab_lr=lab_lr) for i, (k, in_c, out_c, s, se) in enumerate(
self.net_config["blocks3"])
lab_lr=lab_lr)
for i, (k, in_c, out_c, s, se) in enumerate(self.net_config[
"blocks3"])
])
self.blocks4 = nn.Sequential(* [
......@@ -391,8 +409,9 @@ class PPLCNetV3(nn.Layer):
use_se=se,
conv_kxk_num=conv_kxk_num,
lr_mult=self.lr_mult_list[3],
lab_lr=lab_lr) for i, (k, in_c, out_c, s, se) in enumerate(
self.net_config["blocks4"])
lab_lr=lab_lr)
for i, (k, in_c, out_c, s, se) in enumerate(self.net_config[
"blocks4"])
])
self.blocks5 = nn.Sequential(* [
......@@ -404,8 +423,9 @@ class PPLCNetV3(nn.Layer):
use_se=se,
conv_kxk_num=conv_kxk_num,
lr_mult=self.lr_mult_list[4],
lab_lr=lab_lr) for i, (k, in_c, out_c, s, se) in enumerate(
self.net_config["blocks5"])
lab_lr=lab_lr)
for i, (k, in_c, out_c, s, se) in enumerate(self.net_config[
"blocks5"])
])
self.blocks6 = nn.Sequential(* [
......@@ -417,19 +437,52 @@ class PPLCNetV3(nn.Layer):
use_se=se,
conv_kxk_num=conv_kxk_num,
lr_mult=self.lr_mult_list[5],
lab_lr=lab_lr) for i, (k, in_c, out_c, s, se) in enumerate(
self.net_config["blocks6"])
lab_lr=lab_lr)
for i, (k, in_c, out_c, s, se) in enumerate(self.net_config[
"blocks6"])
])
self.out_channels = make_divisible(512 * scale)
if self.det:
mv_c = [16, 24, 56, 480]
self.out_channels = [
make_divisible(self.net_config["blocks3"][-1][2] * scale),
make_divisible(self.net_config["blocks4"][-1][2] * scale),
make_divisible(self.net_config["blocks5"][-1][2] * scale),
make_divisible(self.net_config["blocks6"][-1][2] * scale),
]
self.layer_list = nn.LayerList([
nn.Conv2D(self.out_channels[0], int(mv_c[0] * scale), 1, 1, 0),
nn.Conv2D(self.out_channels[1], int(mv_c[1] * scale), 1, 1, 0),
nn.Conv2D(self.out_channels[2], int(mv_c[2] * scale), 1, 1, 0),
nn.Conv2D(self.out_channels[3], int(mv_c[3] * scale), 1, 1, 0)
])
self.out_channels = [
int(mv_c[0] * scale), int(mv_c[1] * scale),
int(mv_c[2] * scale), int(mv_c[3] * scale)
]
def forward(self, x):
out_list = []
x = self.conv1(x)
x = self.blocks2(x)
x = self.blocks3(x)
out_list.append(x)
x = self.blocks4(x)
out_list.append(x)
x = self.blocks5(x)
out_list.append(x)
x = self.blocks6(x)
out_list.append(x)
if self.det:
out_list[0] = self.layer_list[0](out_list[0])
out_list[1] = self.layer_list[1](out_list[1])
out_list[2] = self.layer_list[2](out_list[2])
out_list[3] = self.layer_list[3](out_list[3])
return out_list
if self.training:
x = F.adaptive_avg_pool2d(x, [1, 40])
......@@ -438,6 +491,6 @@ class PPLCNetV3(nn.Layer):
return x
def LCNetv3(pretrained=False, use_ssld=False, **kwargs):
model = PPLCNetV3(scale=0.95, conv_kxk_num=4, **kwargs)
def LCNetv3(scale=0.95, **kwargs):
model = PPLCNetV3(scale=scale, conv_kxk_num=4, **kwargs)
return model
......@@ -17,14 +17,13 @@ __all__ = ['build_head']
def build_head(config):
# det head
from .det_db_head import DBHead
from .det_db_head import DBHead, CBNHeadLocal
from .det_east_head import EASTHead
from .det_sast_head import SASTHead
from .det_pse_head import PSEHead
from .det_fce_head import FCEHead
from .e2e_pg_head import PGHead
from .det_ct_head import CT_Head
# rec head
from .rec_ctc_head import CTCHead
from .rec_att_head import AttentionHead
......@@ -57,7 +56,7 @@ def build_head(config):
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head', 'RFLHead',
'DRRGHead', 'CANHead', 'SATRNHead'
'DRRGHead', 'CANHead', 'SATRNHead', 'CBNHeadLocal'
]
if config['name'] == 'DRRGHead':
......
......@@ -21,6 +21,7 @@ import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle import ParamAttr
from ppocr.modeling.backbones.det_mobilenet_v3 import ConvBNLayer
def get_bias_attr(k):
......@@ -48,6 +49,7 @@ class Head(nn.Layer):
bias_attr=ParamAttr(
initializer=paddle.nn.initializer.Constant(value=1e-4)),
act='relu')
self.conv2 = nn.Conv2DTranspose(
in_channels=in_channels // 4,
out_channels=in_channels // 4,
......@@ -72,13 +74,17 @@ class Head(nn.Layer):
initializer=paddle.nn.initializer.KaimingUniform()),
bias_attr=get_bias_attr(in_channels // 4), )
def forward(self, x):
def forward(self, x, return_f=False):
x = self.conv1(x)
x = self.conv_bn1(x)
x = self.conv2(x)
x = self.conv_bn2(x)
if return_f is True:
f = x
x = self.conv3(x)
x = F.sigmoid(x)
if return_f is True:
return x, f
return x
......@@ -108,3 +114,41 @@ class DBHead(nn.Layer):
binary_maps = self.step_function(shrink_maps, threshold_maps)
y = paddle.concat([shrink_maps, threshold_maps, binary_maps], axis=1)
return {'maps': y}
class LocalModule(nn.Layer):
def __init__(self, in_c, mid_c, use_distance=True):
super(self.__class__, self).__init__()
self.last_3 = ConvBNLayer(in_c + 1, mid_c, 3, 1, 1, act='relu')
self.last_1 = nn.Conv2D(mid_c, 1, 1, 1, 0)
def forward(self, x, init_map, distance_map):
outf = paddle.concat([init_map, x], axis=1)
# last Conv
out = self.last_1(self.last_3(outf))
return out
class CBNHeadLocal(DBHead):
def __init__(self, in_channels, k=50, mode='small', **kwargs):
super(CBNHeadLocal, self).__init__(in_channels, k, **kwargs)
self.mode = mode
self.up_conv = nn.Upsample(scale_factor=2, mode="nearest", align_mode=1)
if self.mode == 'large':
self.cbn_layer = LocalModule(in_channels // 4, in_channels // 4)
elif self.mode == 'small':
self.cbn_layer = LocalModule(in_channels // 4, in_channels // 8)
def forward(self, x, targets=None):
shrink_maps, f = self.binarize(x, return_f=True)
base_maps = shrink_maps
cbn_maps = self.cbn_layer(self.up_conv(f), shrink_maps, None)
cbn_maps = F.sigmoid(cbn_maps)
if not self.training:
return {'maps': 0.5 * (base_maps + cbn_maps), 'cbn_maps': cbn_maps}
threshold_maps = self.thresh(x)
binary_maps = self.step_function(shrink_maps, threshold_maps)
y = paddle.concat([cbn_maps, threshold_maps, binary_maps], axis=1)
return {'maps': y, 'distance_maps': cbn_maps, 'cbn_maps': binary_maps}
......@@ -22,6 +22,7 @@ import paddle.nn.functional as F
from paddle import ParamAttr
import os
import sys
from ppocr.modeling.necks.intracl import IntraCLBlock
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
......@@ -228,6 +229,13 @@ class RSEFPN(nn.Layer):
self.out_channels = out_channels
self.ins_conv = nn.LayerList()
self.inp_conv = nn.LayerList()
self.intracl = False
if 'intracl' in kwargs.keys() and kwargs['intracl'] is True:
self.intracl = kwargs['intracl']
self.incl1 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
self.incl2 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
self.incl3 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
self.incl4 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
for i in range(len(in_channels)):
self.ins_conv.append(
......@@ -263,6 +271,12 @@ class RSEFPN(nn.Layer):
p3 = self.inp_conv[1](out3)
p2 = self.inp_conv[0](out2)
if self.intracl is True:
p5 = self.incl4(p5)
p4 = self.incl3(p4)
p3 = self.incl2(p3)
p2 = self.incl1(p2)
p5 = F.upsample(p5, scale_factor=8, mode="nearest", align_mode=1)
p4 = F.upsample(p4, scale_factor=4, mode="nearest", align_mode=1)
p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1)
......@@ -329,6 +343,14 @@ class LKPAN(nn.Layer):
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False))
self.intracl = False
if 'intracl' in kwargs.keys() and kwargs['intracl'] is True:
self.intracl = kwargs['intracl']
self.incl1 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
self.incl2 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
self.incl3 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
self.incl4 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
def forward(self, x):
c2, c3, c4, c5 = x
......@@ -358,6 +380,12 @@ class LKPAN(nn.Layer):
p4 = self.pan_lat_conv[2](pan4)
p5 = self.pan_lat_conv[3](pan5)
if self.intracl is True:
p5 = self.incl4(p5)
p4 = self.incl3(p4)
p3 = self.incl2(p3)
p2 = self.incl1(p2)
p5 = F.upsample(p5, scale_factor=8, mode="nearest", align_mode=1)
p4 = F.upsample(p4, scale_factor=4, mode="nearest", align_mode=1)
p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1)
......@@ -424,4 +452,4 @@ class ASFBlock(nn.Layer):
out_list = []
for i in range(self.out_features_num):
out_list.append(attention_scores[:, i:i + 1] * features_list[i])
return paddle.concat(out_list, axis=1)
return paddle.concat(out_list, axis=1)
\ No newline at end of file
import paddle
from paddle import nn
# refer from: https://github.com/ViTAE-Transformer/I3CL/blob/736c80237f66d352d488e83b05f3e33c55201317/mmdet/models/detectors/intra_cl_module.py
class IntraCLBlock(nn.Layer):
def __init__(self, in_channels=96, reduce_factor=4):
super(IntraCLBlock, self).__init__()
self.channels = in_channels
self.rf = reduce_factor
weight_attr = paddle.nn.initializer.KaimingUniform()
self.conv1x1_reduce_channel = nn.Conv2D(
self.channels,
self.channels // self.rf,
kernel_size=1,
stride=1,
padding=0)
self.conv1x1_return_channel = nn.Conv2D(
self.channels // self.rf,
self.channels,
kernel_size=1,
stride=1,
padding=0)
self.v_layer_7x1 = nn.Conv2D(
self.channels // self.rf,
self.channels // self.rf,
kernel_size=(7, 1),
stride=(1, 1),
padding=(3, 0))
self.v_layer_5x1 = nn.Conv2D(
self.channels // self.rf,
self.channels // self.rf,
kernel_size=(5, 1),
stride=(1, 1),
padding=(2, 0))
self.v_layer_3x1 = nn.Conv2D(
self.channels // self.rf,
self.channels // self.rf,
kernel_size=(3, 1),
stride=(1, 1),
padding=(1, 0))
self.q_layer_1x7 = nn.Conv2D(
self.channels // self.rf,
self.channels // self.rf,
kernel_size=(1, 7),
stride=(1, 1),
padding=(0, 3))
self.q_layer_1x5 = nn.Conv2D(
self.channels // self.rf,
self.channels // self.rf,
kernel_size=(1, 5),
stride=(1, 1),
padding=(0, 2))
self.q_layer_1x3 = nn.Conv2D(
self.channels // self.rf,
self.channels // self.rf,
kernel_size=(1, 3),
stride=(1, 1),
padding=(0, 1))
# base
self.c_layer_7x7 = nn.Conv2D(
self.channels // self.rf,
self.channels // self.rf,
kernel_size=(7, 7),
stride=(1, 1),
padding=(3, 3))
self.c_layer_5x5 = nn.Conv2D(
self.channels // self.rf,
self.channels // self.rf,
kernel_size=(5, 5),
stride=(1, 1),
padding=(2, 2))
self.c_layer_3x3 = nn.Conv2D(
self.channels // self.rf,
self.channels // self.rf,
kernel_size=(3, 3),
stride=(1, 1),
padding=(1, 1))
self.bn = nn.BatchNorm2D(self.channels)
self.relu = nn.ReLU()
def forward(self, x):
x_new = self.conv1x1_reduce_channel(x)
x_7_c = self.c_layer_7x7(x_new)
x_7_v = self.v_layer_7x1(x_new)
x_7_q = self.q_layer_1x7(x_new)
x_7 = x_7_c + x_7_v + x_7_q
x_5_c = self.c_layer_5x5(x_7)
x_5_v = self.v_layer_5x1(x_7)
x_5_q = self.q_layer_1x5(x_7)
x_5 = x_5_c + x_5_v + x_5_q
x_3_c = self.c_layer_3x3(x_5)
x_3_v = self.v_layer_3x1(x_5)
x_3_q = self.q_layer_1x3(x_5)
x_3 = x_3_c + x_3_v + x_3_q
x_relation = self.conv1x1_return_channel(x_3)
x_relation = self.bn(x_relation)
x_relation = self.relu(x_relation)
return x + x_relation
def build_intraclblock_list(num_block):
IntraCLBlock_list = nn.LayerList()
for i in range(num_block):
IntraCLBlock_list.append(IntraCLBlock())
return IntraCLBlock_list
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册