未验证 提交 79640f5d 编写于 作者: D Double_V 提交者: GitHub

Merge pull request #6043 from LDOUBLEV/dygraph

add CAFPN and FEPAN
Global:
debug: false
use_gpu: true
epoch_num: 500
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/ch_PP-OCR_v3_det/
save_epoch_step: 100
eval_batch_step:
- 0
- 400
cal_metric_during_train: false
pretrained_model: null
checkpoints: null
save_inference_dir: null
use_visualdl: false
infer_img: doc/imgs_en/img_10.jpg
save_res_path: ./checkpoints/det_db/predicts_db.txt
distributed: true
Architecture:
name: DistillationModel
algorithm: Distillation
model_type: det
Models:
Student:
model_type: det
algorithm: DB
Transform: null
Backbone:
name: MobileNetV3
scale: 0.5
model_name: large
disable_se: true
Neck:
name: RSEFPN
out_channels: 96
shortcut: True
Head:
name: DBHead
k: 50
Student2:
model_type: det
algorithm: DB
Transform: null
Backbone:
name: MobileNetV3
scale: 0.5
model_name: large
disable_se: true
Neck:
name: RSEFPN
out_channels: 96
shortcut: True
Head:
name: DBHead
k: 50
Teacher:
freeze_params: true
return_all_feats: false
model_type: det
algorithm: DB
Backbone:
name: ResNet
in_channels: 3
layers: 50
Neck:
name: LKPAN
out_channels: 256
Head:
name: DBHead
kernel_list: [7,2,2]
k: 50
Loss:
name: CombinedLoss
loss_config_list:
- DistillationDilaDBLoss:
weight: 1.0
model_name_pairs:
- ["Student", "Teacher"]
- ["Student2", "Teacher"]
key: maps
balance_loss: true
main_loss_type: DiceLoss
alpha: 5
beta: 10
ohem_ratio: 3
- DistillationDMLLoss:
model_name_pairs:
- ["Student", "Student2"]
maps_name: "thrink_maps"
weight: 1.0
# act: None
model_name_pairs: ["Student", "Student2"]
key: maps
- DistillationDBLoss:
weight: 1.0
model_name_list: ["Student", "Student2"]
# key: maps
# name: DBLoss
balance_loss: true
main_loss_type: DiceLoss
alpha: 5
beta: 10
ohem_ratio: 3
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Cosine
learning_rate: 0.001
warmup_epoch: 2
regularizer:
name: L2
factor: 5.0e-05
PostProcess:
name: DistillationDBPostProcess
model_name: ["Student"]
key: head_out
thresh: 0.3
box_thresh: 0.6
max_candidates: 1000
unclip_ratio: 1.5
Metric:
name: DistillationMetric
base_metric_name: DetMetric
main_indicator: hmean
key: "Student"
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
ratio_list: [1.0]
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- DetLabelEncode: null
- CopyPaste:
- IaaAugment:
augmenter_args:
- type: Fliplr
args:
p: 0.5
- type: Affine
args:
rotate:
- -10
- 10
- type: Resize
args:
size:
- 0.5
- 3
- EastRandomCropData:
size:
- 960
- 960
max_tries: 50
keep_ratio: true
- MakeBorderMap:
shrink_ratio: 0.4
thresh_min: 0.3
thresh_max: 0.7
- MakeShrinkMap:
shrink_ratio: 0.4
min_text_size: 8
- NormalizeImage:
scale: 1./255.
mean:
- 0.485
- 0.456
- 0.406
std:
- 0.229
- 0.224
- 0.225
order: hwc
- ToCHWImage: null
- KeepKeys:
keep_keys:
- image
- threshold_map
- threshold_mask
- shrink_map
- shrink_mask
loader:
shuffle: true
drop_last: false
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- DetLabelEncode: null
- DetResizeForTest: null
- NormalizeImage:
scale: 1./255.
mean:
- 0.485
- 0.456
- 0.406
std:
- 0.229
- 0.224
- 0.225
order: hwc
- ToCHWImage: null
- KeepKeys:
keep_keys:
- image
- shape
- polys
- ignore_tags
loader:
shuffle: false
drop_last: false
batch_size_per_card: 1
num_workers: 2
Global:
debug: false
use_gpu: true
epoch_num: 500
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/ch_PP-OCR_V3_det/
save_epoch_step: 100
eval_batch_step:
- 0
- 400
cal_metric_during_train: false
pretrained_model: null
checkpoints: null
save_inference_dir: null
use_visualdl: false
infer_img: doc/imgs_en/img_10.jpg
save_res_path: ./checkpoints/det_db/predicts_db.txt
distributed: true
Architecture:
model_type: det
algorithm: DB
Transform:
Backbone:
name: MobileNetV3
scale: 0.5
model_name: large
disable_se: True
Neck:
name: RSEFPN
out_channels: 96
shortcut: True
Head:
name: DBHead
k: 50
Loss:
name: DBLoss
balance_loss: true
main_loss_type: DiceLoss
alpha: 5
beta: 10
ohem_ratio: 3
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Cosine
learning_rate: 0.001
warmup_epoch: 2
regularizer:
name: L2
factor: 5.0e-05
PostProcess:
name: DBPostProcess
thresh: 0.3
box_thresh: 0.6
max_candidates: 1000
unclip_ratio: 1.5
Metric:
name: DetMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
ratio_list: [1.0]
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- DetLabelEncode: null
- IaaAugment:
augmenter_args:
- type: Fliplr
args:
p: 0.5
- type: Affine
args:
rotate:
- -10
- 10
- type: Resize
args:
size:
- 0.5
- 3
- EastRandomCropData:
size:
- 960
- 960
max_tries: 50
keep_ratio: true
- MakeBorderMap:
shrink_ratio: 0.4
thresh_min: 0.3
thresh_max: 0.7
- MakeShrinkMap:
shrink_ratio: 0.4
min_text_size: 8
- NormalizeImage:
scale: 1./255.
mean:
- 0.485
- 0.456
- 0.406
std:
- 0.229
- 0.224
- 0.225
order: hwc
- ToCHWImage: null
- KeepKeys:
keep_keys:
- image
- threshold_map
- threshold_mask
- shrink_map
- shrink_mask
loader:
shuffle: true
drop_last: false
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- DetLabelEncode: null
- DetResizeForTest: null
- NormalizeImage:
scale: 1./255.
mean:
- 0.485
- 0.456
- 0.406
std:
- 0.229
- 0.224
- 0.225
order: hwc
- ToCHWImage: null
- KeepKeys:
keep_keys:
- image
- shape
- polys
- ignore_tags
loader:
shuffle: false
drop_last: false
batch_size_per_card: 1
num_workers: 2
...@@ -31,13 +31,14 @@ def get_bias_attr(k): ...@@ -31,13 +31,14 @@ def get_bias_attr(k):
class Head(nn.Layer): class Head(nn.Layer):
def __init__(self, in_channels, name_list): def __init__(self, in_channels, name_list, kernel_list=[3, 2, 2], **kwargs):
super(Head, self).__init__() super(Head, self).__init__()
self.conv1 = nn.Conv2D( self.conv1 = nn.Conv2D(
in_channels=in_channels, in_channels=in_channels,
out_channels=in_channels // 4, out_channels=in_channels // 4,
kernel_size=3, kernel_size=kernel_list[0],
padding=1, padding=int(kernel_list[0] // 2),
weight_attr=ParamAttr(), weight_attr=ParamAttr(),
bias_attr=False) bias_attr=False)
self.conv_bn1 = nn.BatchNorm( self.conv_bn1 = nn.BatchNorm(
...@@ -50,7 +51,7 @@ class Head(nn.Layer): ...@@ -50,7 +51,7 @@ class Head(nn.Layer):
self.conv2 = nn.Conv2DTranspose( self.conv2 = nn.Conv2DTranspose(
in_channels=in_channels // 4, in_channels=in_channels // 4,
out_channels=in_channels // 4, out_channels=in_channels // 4,
kernel_size=2, kernel_size=kernel_list[1],
stride=2, stride=2,
weight_attr=ParamAttr( weight_attr=ParamAttr(
initializer=paddle.nn.initializer.KaimingUniform()), initializer=paddle.nn.initializer.KaimingUniform()),
...@@ -65,7 +66,7 @@ class Head(nn.Layer): ...@@ -65,7 +66,7 @@ class Head(nn.Layer):
self.conv3 = nn.Conv2DTranspose( self.conv3 = nn.Conv2DTranspose(
in_channels=in_channels // 4, in_channels=in_channels // 4,
out_channels=1, out_channels=1,
kernel_size=2, kernel_size=kernel_list[2],
stride=2, stride=2,
weight_attr=ParamAttr( weight_attr=ParamAttr(
initializer=paddle.nn.initializer.KaimingUniform()), initializer=paddle.nn.initializer.KaimingUniform()),
...@@ -100,8 +101,8 @@ class DBHead(nn.Layer): ...@@ -100,8 +101,8 @@ class DBHead(nn.Layer):
'conv2d_57', 'batch_norm_49', 'conv2d_transpose_2', 'batch_norm_50', 'conv2d_57', 'batch_norm_49', 'conv2d_transpose_2', 'batch_norm_50',
'conv2d_transpose_3', 'thresh' 'conv2d_transpose_3', 'thresh'
] ]
self.binarize = Head(in_channels, binarize_name_list) self.binarize = Head(in_channels, binarize_name_list, **kwargs)
self.thresh = Head(in_channels, thresh_name_list) self.thresh = Head(in_channels, thresh_name_list, **kwargs)
def step_function(self, x, y): def step_function(self, x, y):
return paddle.reciprocal(1 + paddle.exp(-self.k * (x - y))) return paddle.reciprocal(1 + paddle.exp(-self.k * (x - y)))
......
...@@ -16,7 +16,7 @@ __all__ = ['build_neck'] ...@@ -16,7 +16,7 @@ __all__ = ['build_neck']
def build_neck(config): def build_neck(config):
from .db_fpn import DBFPN from .db_fpn import DBFPN, RSEFPN, LKPAN
from .east_fpn import EASTFPN from .east_fpn import EASTFPN
from .sast_fpn import SASTFPN from .sast_fpn import SASTFPN
from .rnn import SequenceEncoder from .rnn import SequenceEncoder
...@@ -26,8 +26,8 @@ def build_neck(config): ...@@ -26,8 +26,8 @@ def build_neck(config):
from .fce_fpn import FCEFPN from .fce_fpn import FCEFPN
from .pren_fpn import PRENFPN from .pren_fpn import PRENFPN
support_dict = [ support_dict = [
'FPN', 'FCEFPN', 'DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'FPN', 'FCEFPN', 'LKPAN', 'DBFPN', 'RSEFPN', 'EASTFPN', 'SASTFPN',
'PGFPN', 'TableFPN', 'PRENFPN' 'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN'
] ]
module_name = config.pop('name') module_name = config.pop('name')
......
...@@ -20,6 +20,88 @@ import paddle ...@@ -20,6 +20,88 @@ import paddle
from paddle import nn from paddle import nn
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle import ParamAttr from paddle import ParamAttr
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../../..')))
from ppocr.modeling.backbones.det_mobilenet_v3 import SEModule
class DSConv(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
padding,
stride=1,
groups=None,
if_act=True,
act="relu",
**kwargs):
super(DSConv, self).__init__()
if groups == None:
groups = in_channels
self.if_act = if_act
self.act = act
self.conv1 = nn.Conv2D(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
bias_attr=False)
self.bn1 = nn.BatchNorm(num_channels=in_channels, act=None)
self.conv2 = nn.Conv2D(
in_channels=in_channels,
out_channels=int(in_channels * 4),
kernel_size=1,
stride=1,
bias_attr=False)
self.bn2 = nn.BatchNorm(num_channels=int(in_channels * 4), act=None)
self.conv3 = nn.Conv2D(
in_channels=int(in_channels * 4),
out_channels=out_channels,
kernel_size=1,
stride=1,
bias_attr=False)
self._c = [in_channels, out_channels]
if in_channels != out_channels:
self.conv_end = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
bias_attr=False)
def forward(self, inputs):
x = self.conv1(inputs)
x = self.bn1(x)
x = self.conv2(x)
x = self.bn2(x)
if self.if_act:
if self.act == "relu":
x = F.relu(x)
elif self.act == "hardswish":
x = F.hardswish(x)
else:
print("The activation function({}) is selected incorrectly.".
format(self.act))
exit()
x = self.conv3(x)
if self._c[0] != self._c[1]:
x = x + self.conv_end(inputs)
return x
class DBFPN(nn.Layer): class DBFPN(nn.Layer):
...@@ -106,3 +188,171 @@ class DBFPN(nn.Layer): ...@@ -106,3 +188,171 @@ class DBFPN(nn.Layer):
fuse = paddle.concat([p5, p4, p3, p2], axis=1) fuse = paddle.concat([p5, p4, p3, p2], axis=1)
return fuse return fuse
class RSELayer(nn.Layer):
def __init__(self, in_channels, out_channels, kernel_size, shortcut=True):
super(RSELayer, self).__init__()
weight_attr = paddle.nn.initializer.KaimingUniform()
self.out_channels = out_channels
self.in_conv = nn.Conv2D(
in_channels=in_channels,
out_channels=self.out_channels,
kernel_size=kernel_size,
padding=int(kernel_size // 2),
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.se_block = SEModule(self.out_channels)
self.shortcut = shortcut
def forward(self, ins):
x = self.in_conv(ins)
if self.shortcut:
out = x + self.se_block(x)
else:
out = self.se_block(x)
return out
class RSEFPN(nn.Layer):
def __init__(self, in_channels, out_channels, shortcut=True, **kwargs):
super(RSEFPN, self).__init__()
self.out_channels = out_channels
self.ins_conv = nn.LayerList()
self.inp_conv = nn.LayerList()
for i in range(len(in_channels)):
self.ins_conv.append(
RSELayer(
in_channels[i],
out_channels,
kernel_size=1,
shortcut=shortcut))
self.inp_conv.append(
RSELayer(
out_channels,
out_channels // 4,
kernel_size=3,
shortcut=shortcut))
def forward(self, x):
c2, c3, c4, c5 = x
in5 = self.ins_conv[3](c5)
in4 = self.ins_conv[2](c4)
in3 = self.ins_conv[1](c3)
in2 = self.ins_conv[0](c2)
out4 = in4 + F.upsample(
in5, scale_factor=2, mode="nearest", align_mode=1) # 1/16
out3 = in3 + F.upsample(
out4, scale_factor=2, mode="nearest", align_mode=1) # 1/8
out2 = in2 + F.upsample(
out3, scale_factor=2, mode="nearest", align_mode=1) # 1/4
p5 = self.inp_conv[3](in5)
p4 = self.inp_conv[2](out4)
p3 = self.inp_conv[1](out3)
p2 = self.inp_conv[0](out2)
p5 = F.upsample(p5, scale_factor=8, mode="nearest", align_mode=1)
p4 = F.upsample(p4, scale_factor=4, mode="nearest", align_mode=1)
p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1)
fuse = paddle.concat([p5, p4, p3, p2], axis=1)
return fuse
class LKPAN(nn.Layer):
def __init__(self, in_channels, out_channels, mode='large', **kwargs):
super(LKPAN, self).__init__()
self.out_channels = out_channels
weight_attr = paddle.nn.initializer.KaimingUniform()
self.ins_conv = nn.LayerList()
self.inp_conv = nn.LayerList()
# pan head
self.pan_head_conv = nn.LayerList()
self.pan_lat_conv = nn.LayerList()
if mode.lower() == 'lite':
p_layer = DSConv
elif mode.lower() == 'large':
p_layer = nn.Conv2D
else:
raise ValueError(
"mode can only be one of ['lite', 'large'], but received {}".
format(mode))
for i in range(len(in_channels)):
self.ins_conv.append(
nn.Conv2D(
in_channels=in_channels[i],
out_channels=self.out_channels,
kernel_size=1,
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False))
self.inp_conv.append(
p_layer(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=9,
padding=4,
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False))
if i > 0:
self.pan_head_conv.append(
nn.Conv2D(
in_channels=self.out_channels // 4,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
stride=2,
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False))
self.pan_lat_conv.append(
p_layer(
in_channels=self.out_channels // 4,
out_channels=self.out_channels // 4,
kernel_size=9,
padding=4,
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False))
def forward(self, x):
c2, c3, c4, c5 = x
in5 = self.ins_conv[3](c5)
in4 = self.ins_conv[2](c4)
in3 = self.ins_conv[1](c3)
in2 = self.ins_conv[0](c2)
out4 = in4 + F.upsample(
in5, scale_factor=2, mode="nearest", align_mode=1) # 1/16
out3 = in3 + F.upsample(
out4, scale_factor=2, mode="nearest", align_mode=1) # 1/8
out2 = in2 + F.upsample(
out3, scale_factor=2, mode="nearest", align_mode=1) # 1/4
f5 = self.inp_conv[3](in5)
f4 = self.inp_conv[2](out4)
f3 = self.inp_conv[1](out3)
f2 = self.inp_conv[0](out2)
pan3 = f3 + self.pan_head_conv[0](f2)
pan4 = f4 + self.pan_head_conv[1](pan3)
pan5 = f5 + self.pan_head_conv[2](pan4)
p2 = self.pan_lat_conv[0](f2)
p3 = self.pan_lat_conv[1](pan3)
p4 = self.pan_lat_conv[2](pan4)
p5 = self.pan_lat_conv[3](pan5)
p5 = F.upsample(p5, scale_factor=8, mode="nearest", align_mode=1)
p4 = F.upsample(p4, scale_factor=4, mode="nearest", align_mode=1)
p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1)
fuse = paddle.concat([p5, p4, p3, p2], axis=1)
return fuse
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册