Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
79640f5d
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
79640f5d
编写于
4月 28, 2022
作者:
D
Double_V
提交者:
GitHub
4月 28, 2022
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #6043 from LDOUBLEV/dygraph
add CAFPN and FEPAN
上级
da8991ef
d341e182
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
658 addition
and
10 deletion
+658
-10
configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml
configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml
+234
-0
configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml
configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml
+163
-0
ppocr/modeling/heads/det_db_head.py
ppocr/modeling/heads/det_db_head.py
+8
-7
ppocr/modeling/necks/__init__.py
ppocr/modeling/necks/__init__.py
+3
-3
ppocr/modeling/necks/db_fpn.py
ppocr/modeling/necks/db_fpn.py
+250
-0
未找到文件。
configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml
0 → 100644
浏览文件 @
79640f5d
Global
:
debug
:
false
use_gpu
:
true
epoch_num
:
500
log_smooth_window
:
20
print_batch_step
:
10
save_model_dir
:
./output/ch_PP-OCR_v3_det/
save_epoch_step
:
100
eval_batch_step
:
-
0
-
400
cal_metric_during_train
:
false
pretrained_model
:
null
checkpoints
:
null
save_inference_dir
:
null
use_visualdl
:
false
infer_img
:
doc/imgs_en/img_10.jpg
save_res_path
:
./checkpoints/det_db/predicts_db.txt
distributed
:
true
Architecture
:
name
:
DistillationModel
algorithm
:
Distillation
model_type
:
det
Models
:
Student
:
model_type
:
det
algorithm
:
DB
Transform
:
null
Backbone
:
name
:
MobileNetV3
scale
:
0.5
model_name
:
large
disable_se
:
true
Neck
:
name
:
RSEFPN
out_channels
:
96
shortcut
:
True
Head
:
name
:
DBHead
k
:
50
Student2
:
model_type
:
det
algorithm
:
DB
Transform
:
null
Backbone
:
name
:
MobileNetV3
scale
:
0.5
model_name
:
large
disable_se
:
true
Neck
:
name
:
RSEFPN
out_channels
:
96
shortcut
:
True
Head
:
name
:
DBHead
k
:
50
Teacher
:
freeze_params
:
true
return_all_feats
:
false
model_type
:
det
algorithm
:
DB
Backbone
:
name
:
ResNet
in_channels
:
3
layers
:
50
Neck
:
name
:
LKPAN
out_channels
:
256
Head
:
name
:
DBHead
kernel_list
:
[
7
,
2
,
2
]
k
:
50
Loss
:
name
:
CombinedLoss
loss_config_list
:
-
DistillationDilaDBLoss
:
weight
:
1.0
model_name_pairs
:
-
[
"
Student"
,
"
Teacher"
]
-
[
"
Student2"
,
"
Teacher"
]
key
:
maps
balance_loss
:
true
main_loss_type
:
DiceLoss
alpha
:
5
beta
:
10
ohem_ratio
:
3
-
DistillationDMLLoss
:
model_name_pairs
:
-
[
"
Student"
,
"
Student2"
]
maps_name
:
"
thrink_maps"
weight
:
1.0
# act: None
model_name_pairs
:
[
"
Student"
,
"
Student2"
]
key
:
maps
-
DistillationDBLoss
:
weight
:
1.0
model_name_list
:
[
"
Student"
,
"
Student2"
]
# key: maps
# name: DBLoss
balance_loss
:
true
main_loss_type
:
DiceLoss
alpha
:
5
beta
:
10
ohem_ratio
:
3
Optimizer
:
name
:
Adam
beta1
:
0.9
beta2
:
0.999
lr
:
name
:
Cosine
learning_rate
:
0.001
warmup_epoch
:
2
regularizer
:
name
:
L2
factor
:
5.0e-05
PostProcess
:
name
:
DistillationDBPostProcess
model_name
:
[
"
Student"
]
key
:
head_out
thresh
:
0.3
box_thresh
:
0.6
max_candidates
:
1000
unclip_ratio
:
1.5
Metric
:
name
:
DistillationMetric
base_metric_name
:
DetMetric
main_indicator
:
hmean
key
:
"
Student"
Train
:
dataset
:
name
:
SimpleDataSet
data_dir
:
./train_data/icdar2015/text_localization/
label_file_list
:
-
./train_data/icdar2015/text_localization/train_icdar2015_label.txt
ratio_list
:
[
1.0
]
transforms
:
-
DecodeImage
:
img_mode
:
BGR
channel_first
:
false
-
DetLabelEncode
:
null
-
CopyPaste
:
-
IaaAugment
:
augmenter_args
:
-
type
:
Fliplr
args
:
p
:
0.5
-
type
:
Affine
args
:
rotate
:
-
-10
-
10
-
type
:
Resize
args
:
size
:
-
0.5
-
3
-
EastRandomCropData
:
size
:
-
960
-
960
max_tries
:
50
keep_ratio
:
true
-
MakeBorderMap
:
shrink_ratio
:
0.4
thresh_min
:
0.3
thresh_max
:
0.7
-
MakeShrinkMap
:
shrink_ratio
:
0.4
min_text_size
:
8
-
NormalizeImage
:
scale
:
1./255.
mean
:
-
0.485
-
0.456
-
0.406
std
:
-
0.229
-
0.224
-
0.225
order
:
hwc
-
ToCHWImage
:
null
-
KeepKeys
:
keep_keys
:
-
image
-
threshold_map
-
threshold_mask
-
shrink_map
-
shrink_mask
loader
:
shuffle
:
true
drop_last
:
false
batch_size_per_card
:
8
num_workers
:
4
Eval
:
dataset
:
name
:
SimpleDataSet
data_dir
:
./train_data/icdar2015/text_localization/
label_file_list
:
-
./train_data/icdar2015/text_localization/test_icdar2015_label.txt
transforms
:
-
DecodeImage
:
img_mode
:
BGR
channel_first
:
false
-
DetLabelEncode
:
null
-
DetResizeForTest
:
null
-
NormalizeImage
:
scale
:
1./255.
mean
:
-
0.485
-
0.456
-
0.406
std
:
-
0.229
-
0.224
-
0.225
order
:
hwc
-
ToCHWImage
:
null
-
KeepKeys
:
keep_keys
:
-
image
-
shape
-
polys
-
ignore_tags
loader
:
shuffle
:
false
drop_last
:
false
batch_size_per_card
:
1
num_workers
:
2
configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml
0 → 100644
浏览文件 @
79640f5d
Global
:
debug
:
false
use_gpu
:
true
epoch_num
:
500
log_smooth_window
:
20
print_batch_step
:
10
save_model_dir
:
./output/ch_PP-OCR_V3_det/
save_epoch_step
:
100
eval_batch_step
:
-
0
-
400
cal_metric_during_train
:
false
pretrained_model
:
null
checkpoints
:
null
save_inference_dir
:
null
use_visualdl
:
false
infer_img
:
doc/imgs_en/img_10.jpg
save_res_path
:
./checkpoints/det_db/predicts_db.txt
distributed
:
true
Architecture
:
model_type
:
det
algorithm
:
DB
Transform
:
Backbone
:
name
:
MobileNetV3
scale
:
0.5
model_name
:
large
disable_se
:
True
Neck
:
name
:
RSEFPN
out_channels
:
96
shortcut
:
True
Head
:
name
:
DBHead
k
:
50
Loss
:
name
:
DBLoss
balance_loss
:
true
main_loss_type
:
DiceLoss
alpha
:
5
beta
:
10
ohem_ratio
:
3
Optimizer
:
name
:
Adam
beta1
:
0.9
beta2
:
0.999
lr
:
name
:
Cosine
learning_rate
:
0.001
warmup_epoch
:
2
regularizer
:
name
:
L2
factor
:
5.0e-05
PostProcess
:
name
:
DBPostProcess
thresh
:
0.3
box_thresh
:
0.6
max_candidates
:
1000
unclip_ratio
:
1.5
Metric
:
name
:
DetMetric
main_indicator
:
hmean
Train
:
dataset
:
name
:
SimpleDataSet
data_dir
:
./train_data/icdar2015/text_localization/
label_file_list
:
-
./train_data/icdar2015/text_localization/train_icdar2015_label.txt
ratio_list
:
[
1.0
]
transforms
:
-
DecodeImage
:
img_mode
:
BGR
channel_first
:
false
-
DetLabelEncode
:
null
-
IaaAugment
:
augmenter_args
:
-
type
:
Fliplr
args
:
p
:
0.5
-
type
:
Affine
args
:
rotate
:
-
-10
-
10
-
type
:
Resize
args
:
size
:
-
0.5
-
3
-
EastRandomCropData
:
size
:
-
960
-
960
max_tries
:
50
keep_ratio
:
true
-
MakeBorderMap
:
shrink_ratio
:
0.4
thresh_min
:
0.3
thresh_max
:
0.7
-
MakeShrinkMap
:
shrink_ratio
:
0.4
min_text_size
:
8
-
NormalizeImage
:
scale
:
1./255.
mean
:
-
0.485
-
0.456
-
0.406
std
:
-
0.229
-
0.224
-
0.225
order
:
hwc
-
ToCHWImage
:
null
-
KeepKeys
:
keep_keys
:
-
image
-
threshold_map
-
threshold_mask
-
shrink_map
-
shrink_mask
loader
:
shuffle
:
true
drop_last
:
false
batch_size_per_card
:
8
num_workers
:
4
Eval
:
dataset
:
name
:
SimpleDataSet
data_dir
:
./train_data/icdar2015/text_localization/
label_file_list
:
-
./train_data/icdar2015/text_localization/test_icdar2015_label.txt
transforms
:
-
DecodeImage
:
img_mode
:
BGR
channel_first
:
false
-
DetLabelEncode
:
null
-
DetResizeForTest
:
null
-
NormalizeImage
:
scale
:
1./255.
mean
:
-
0.485
-
0.456
-
0.406
std
:
-
0.229
-
0.224
-
0.225
order
:
hwc
-
ToCHWImage
:
null
-
KeepKeys
:
keep_keys
:
-
image
-
shape
-
polys
-
ignore_tags
loader
:
shuffle
:
false
drop_last
:
false
batch_size_per_card
:
1
num_workers
:
2
ppocr/modeling/heads/det_db_head.py
浏览文件 @
79640f5d
...
@@ -31,13 +31,14 @@ def get_bias_attr(k):
...
@@ -31,13 +31,14 @@ def get_bias_attr(k):
class
Head
(
nn
.
Layer
):
class
Head
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
name_list
):
def
__init__
(
self
,
in_channels
,
name_list
,
kernel_list
=
[
3
,
2
,
2
],
**
kwargs
):
super
(
Head
,
self
).
__init__
()
super
(
Head
,
self
).
__init__
()
self
.
conv1
=
nn
.
Conv2D
(
self
.
conv1
=
nn
.
Conv2D
(
in_channels
=
in_channels
,
in_channels
=
in_channels
,
out_channels
=
in_channels
//
4
,
out_channels
=
in_channels
//
4
,
kernel_size
=
3
,
kernel_size
=
kernel_list
[
0
]
,
padding
=
1
,
padding
=
int
(
kernel_list
[
0
]
//
2
)
,
weight_attr
=
ParamAttr
(),
weight_attr
=
ParamAttr
(),
bias_attr
=
False
)
bias_attr
=
False
)
self
.
conv_bn1
=
nn
.
BatchNorm
(
self
.
conv_bn1
=
nn
.
BatchNorm
(
...
@@ -50,7 +51,7 @@ class Head(nn.Layer):
...
@@ -50,7 +51,7 @@ class Head(nn.Layer):
self
.
conv2
=
nn
.
Conv2DTranspose
(
self
.
conv2
=
nn
.
Conv2DTranspose
(
in_channels
=
in_channels
//
4
,
in_channels
=
in_channels
//
4
,
out_channels
=
in_channels
//
4
,
out_channels
=
in_channels
//
4
,
kernel_size
=
2
,
kernel_size
=
kernel_list
[
1
]
,
stride
=
2
,
stride
=
2
,
weight_attr
=
ParamAttr
(
weight_attr
=
ParamAttr
(
initializer
=
paddle
.
nn
.
initializer
.
KaimingUniform
()),
initializer
=
paddle
.
nn
.
initializer
.
KaimingUniform
()),
...
@@ -65,7 +66,7 @@ class Head(nn.Layer):
...
@@ -65,7 +66,7 @@ class Head(nn.Layer):
self
.
conv3
=
nn
.
Conv2DTranspose
(
self
.
conv3
=
nn
.
Conv2DTranspose
(
in_channels
=
in_channels
//
4
,
in_channels
=
in_channels
//
4
,
out_channels
=
1
,
out_channels
=
1
,
kernel_size
=
2
,
kernel_size
=
kernel_list
[
2
]
,
stride
=
2
,
stride
=
2
,
weight_attr
=
ParamAttr
(
weight_attr
=
ParamAttr
(
initializer
=
paddle
.
nn
.
initializer
.
KaimingUniform
()),
initializer
=
paddle
.
nn
.
initializer
.
KaimingUniform
()),
...
@@ -100,8 +101,8 @@ class DBHead(nn.Layer):
...
@@ -100,8 +101,8 @@ class DBHead(nn.Layer):
'conv2d_57'
,
'batch_norm_49'
,
'conv2d_transpose_2'
,
'batch_norm_50'
,
'conv2d_57'
,
'batch_norm_49'
,
'conv2d_transpose_2'
,
'batch_norm_50'
,
'conv2d_transpose_3'
,
'thresh'
'conv2d_transpose_3'
,
'thresh'
]
]
self
.
binarize
=
Head
(
in_channels
,
binarize_name_list
)
self
.
binarize
=
Head
(
in_channels
,
binarize_name_list
,
**
kwargs
)
self
.
thresh
=
Head
(
in_channels
,
thresh_name_list
)
self
.
thresh
=
Head
(
in_channels
,
thresh_name_list
,
**
kwargs
)
def
step_function
(
self
,
x
,
y
):
def
step_function
(
self
,
x
,
y
):
return
paddle
.
reciprocal
(
1
+
paddle
.
exp
(
-
self
.
k
*
(
x
-
y
)))
return
paddle
.
reciprocal
(
1
+
paddle
.
exp
(
-
self
.
k
*
(
x
-
y
)))
...
...
ppocr/modeling/necks/__init__.py
浏览文件 @
79640f5d
...
@@ -16,7 +16,7 @@ __all__ = ['build_neck']
...
@@ -16,7 +16,7 @@ __all__ = ['build_neck']
def
build_neck
(
config
):
def
build_neck
(
config
):
from
.db_fpn
import
DBFPN
from
.db_fpn
import
DBFPN
,
RSEFPN
,
LKPAN
from
.east_fpn
import
EASTFPN
from
.east_fpn
import
EASTFPN
from
.sast_fpn
import
SASTFPN
from
.sast_fpn
import
SASTFPN
from
.rnn
import
SequenceEncoder
from
.rnn
import
SequenceEncoder
...
@@ -26,8 +26,8 @@ def build_neck(config):
...
@@ -26,8 +26,8 @@ def build_neck(config):
from
.fce_fpn
import
FCEFPN
from
.fce_fpn
import
FCEFPN
from
.pren_fpn
import
PRENFPN
from
.pren_fpn
import
PRENFPN
support_dict
=
[
support_dict
=
[
'FPN'
,
'FCEFPN'
,
'
DBFPN'
,
'EASTFPN'
,
'SASTFPN'
,
'SequenceEncoder
'
,
'FPN'
,
'FCEFPN'
,
'
LKPAN'
,
'DBFPN'
,
'RSEFPN'
,
'EASTFPN'
,
'SASTFPN
'
,
'PGFPN'
,
'TableFPN'
,
'PRENFPN'
'
SequenceEncoder'
,
'
PGFPN'
,
'TableFPN'
,
'PRENFPN'
]
]
module_name
=
config
.
pop
(
'name'
)
module_name
=
config
.
pop
(
'name'
)
...
...
ppocr/modeling/necks/db_fpn.py
浏览文件 @
79640f5d
...
@@ -20,6 +20,88 @@ import paddle
...
@@ -20,6 +20,88 @@ import paddle
from
paddle
import
nn
from
paddle
import
nn
import
paddle.nn.functional
as
F
import
paddle.nn.functional
as
F
from
paddle
import
ParamAttr
from
paddle
import
ParamAttr
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
insert
(
0
,
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../../..'
)))
from
ppocr.modeling.backbones.det_mobilenet_v3
import
SEModule
class
DSConv
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
padding
,
stride
=
1
,
groups
=
None
,
if_act
=
True
,
act
=
"relu"
,
**
kwargs
):
super
(
DSConv
,
self
).
__init__
()
if
groups
==
None
:
groups
=
in_channels
self
.
if_act
=
if_act
self
.
act
=
act
self
.
conv1
=
nn
.
Conv2D
(
in_channels
=
in_channels
,
out_channels
=
in_channels
,
kernel_size
=
kernel_size
,
stride
=
stride
,
padding
=
padding
,
groups
=
groups
,
bias_attr
=
False
)
self
.
bn1
=
nn
.
BatchNorm
(
num_channels
=
in_channels
,
act
=
None
)
self
.
conv2
=
nn
.
Conv2D
(
in_channels
=
in_channels
,
out_channels
=
int
(
in_channels
*
4
),
kernel_size
=
1
,
stride
=
1
,
bias_attr
=
False
)
self
.
bn2
=
nn
.
BatchNorm
(
num_channels
=
int
(
in_channels
*
4
),
act
=
None
)
self
.
conv3
=
nn
.
Conv2D
(
in_channels
=
int
(
in_channels
*
4
),
out_channels
=
out_channels
,
kernel_size
=
1
,
stride
=
1
,
bias_attr
=
False
)
self
.
_c
=
[
in_channels
,
out_channels
]
if
in_channels
!=
out_channels
:
self
.
conv_end
=
nn
.
Conv2D
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
1
,
stride
=
1
,
bias_attr
=
False
)
def
forward
(
self
,
inputs
):
x
=
self
.
conv1
(
inputs
)
x
=
self
.
bn1
(
x
)
x
=
self
.
conv2
(
x
)
x
=
self
.
bn2
(
x
)
if
self
.
if_act
:
if
self
.
act
==
"relu"
:
x
=
F
.
relu
(
x
)
elif
self
.
act
==
"hardswish"
:
x
=
F
.
hardswish
(
x
)
else
:
print
(
"The activation function({}) is selected incorrectly."
.
format
(
self
.
act
))
exit
()
x
=
self
.
conv3
(
x
)
if
self
.
_c
[
0
]
!=
self
.
_c
[
1
]:
x
=
x
+
self
.
conv_end
(
inputs
)
return
x
class
DBFPN
(
nn
.
Layer
):
class
DBFPN
(
nn
.
Layer
):
...
@@ -106,3 +188,171 @@ class DBFPN(nn.Layer):
...
@@ -106,3 +188,171 @@ class DBFPN(nn.Layer):
fuse
=
paddle
.
concat
([
p5
,
p4
,
p3
,
p2
],
axis
=
1
)
fuse
=
paddle
.
concat
([
p5
,
p4
,
p3
,
p2
],
axis
=
1
)
return
fuse
return
fuse
class
RSELayer
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
shortcut
=
True
):
super
(
RSELayer
,
self
).
__init__
()
weight_attr
=
paddle
.
nn
.
initializer
.
KaimingUniform
()
self
.
out_channels
=
out_channels
self
.
in_conv
=
nn
.
Conv2D
(
in_channels
=
in_channels
,
out_channels
=
self
.
out_channels
,
kernel_size
=
kernel_size
,
padding
=
int
(
kernel_size
//
2
),
weight_attr
=
ParamAttr
(
initializer
=
weight_attr
),
bias_attr
=
False
)
self
.
se_block
=
SEModule
(
self
.
out_channels
)
self
.
shortcut
=
shortcut
def
forward
(
self
,
ins
):
x
=
self
.
in_conv
(
ins
)
if
self
.
shortcut
:
out
=
x
+
self
.
se_block
(
x
)
else
:
out
=
self
.
se_block
(
x
)
return
out
class
RSEFPN
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
shortcut
=
True
,
**
kwargs
):
super
(
RSEFPN
,
self
).
__init__
()
self
.
out_channels
=
out_channels
self
.
ins_conv
=
nn
.
LayerList
()
self
.
inp_conv
=
nn
.
LayerList
()
for
i
in
range
(
len
(
in_channels
)):
self
.
ins_conv
.
append
(
RSELayer
(
in_channels
[
i
],
out_channels
,
kernel_size
=
1
,
shortcut
=
shortcut
))
self
.
inp_conv
.
append
(
RSELayer
(
out_channels
,
out_channels
//
4
,
kernel_size
=
3
,
shortcut
=
shortcut
))
def
forward
(
self
,
x
):
c2
,
c3
,
c4
,
c5
=
x
in5
=
self
.
ins_conv
[
3
](
c5
)
in4
=
self
.
ins_conv
[
2
](
c4
)
in3
=
self
.
ins_conv
[
1
](
c3
)
in2
=
self
.
ins_conv
[
0
](
c2
)
out4
=
in4
+
F
.
upsample
(
in5
,
scale_factor
=
2
,
mode
=
"nearest"
,
align_mode
=
1
)
# 1/16
out3
=
in3
+
F
.
upsample
(
out4
,
scale_factor
=
2
,
mode
=
"nearest"
,
align_mode
=
1
)
# 1/8
out2
=
in2
+
F
.
upsample
(
out3
,
scale_factor
=
2
,
mode
=
"nearest"
,
align_mode
=
1
)
# 1/4
p5
=
self
.
inp_conv
[
3
](
in5
)
p4
=
self
.
inp_conv
[
2
](
out4
)
p3
=
self
.
inp_conv
[
1
](
out3
)
p2
=
self
.
inp_conv
[
0
](
out2
)
p5
=
F
.
upsample
(
p5
,
scale_factor
=
8
,
mode
=
"nearest"
,
align_mode
=
1
)
p4
=
F
.
upsample
(
p4
,
scale_factor
=
4
,
mode
=
"nearest"
,
align_mode
=
1
)
p3
=
F
.
upsample
(
p3
,
scale_factor
=
2
,
mode
=
"nearest"
,
align_mode
=
1
)
fuse
=
paddle
.
concat
([
p5
,
p4
,
p3
,
p2
],
axis
=
1
)
return
fuse
class
LKPAN
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
mode
=
'large'
,
**
kwargs
):
super
(
LKPAN
,
self
).
__init__
()
self
.
out_channels
=
out_channels
weight_attr
=
paddle
.
nn
.
initializer
.
KaimingUniform
()
self
.
ins_conv
=
nn
.
LayerList
()
self
.
inp_conv
=
nn
.
LayerList
()
# pan head
self
.
pan_head_conv
=
nn
.
LayerList
()
self
.
pan_lat_conv
=
nn
.
LayerList
()
if
mode
.
lower
()
==
'lite'
:
p_layer
=
DSConv
elif
mode
.
lower
()
==
'large'
:
p_layer
=
nn
.
Conv2D
else
:
raise
ValueError
(
"mode can only be one of ['lite', 'large'], but received {}"
.
format
(
mode
))
for
i
in
range
(
len
(
in_channels
)):
self
.
ins_conv
.
append
(
nn
.
Conv2D
(
in_channels
=
in_channels
[
i
],
out_channels
=
self
.
out_channels
,
kernel_size
=
1
,
weight_attr
=
ParamAttr
(
initializer
=
weight_attr
),
bias_attr
=
False
))
self
.
inp_conv
.
append
(
p_layer
(
in_channels
=
self
.
out_channels
,
out_channels
=
self
.
out_channels
//
4
,
kernel_size
=
9
,
padding
=
4
,
weight_attr
=
ParamAttr
(
initializer
=
weight_attr
),
bias_attr
=
False
))
if
i
>
0
:
self
.
pan_head_conv
.
append
(
nn
.
Conv2D
(
in_channels
=
self
.
out_channels
//
4
,
out_channels
=
self
.
out_channels
//
4
,
kernel_size
=
3
,
padding
=
1
,
stride
=
2
,
weight_attr
=
ParamAttr
(
initializer
=
weight_attr
),
bias_attr
=
False
))
self
.
pan_lat_conv
.
append
(
p_layer
(
in_channels
=
self
.
out_channels
//
4
,
out_channels
=
self
.
out_channels
//
4
,
kernel_size
=
9
,
padding
=
4
,
weight_attr
=
ParamAttr
(
initializer
=
weight_attr
),
bias_attr
=
False
))
def
forward
(
self
,
x
):
c2
,
c3
,
c4
,
c5
=
x
in5
=
self
.
ins_conv
[
3
](
c5
)
in4
=
self
.
ins_conv
[
2
](
c4
)
in3
=
self
.
ins_conv
[
1
](
c3
)
in2
=
self
.
ins_conv
[
0
](
c2
)
out4
=
in4
+
F
.
upsample
(
in5
,
scale_factor
=
2
,
mode
=
"nearest"
,
align_mode
=
1
)
# 1/16
out3
=
in3
+
F
.
upsample
(
out4
,
scale_factor
=
2
,
mode
=
"nearest"
,
align_mode
=
1
)
# 1/8
out2
=
in2
+
F
.
upsample
(
out3
,
scale_factor
=
2
,
mode
=
"nearest"
,
align_mode
=
1
)
# 1/4
f5
=
self
.
inp_conv
[
3
](
in5
)
f4
=
self
.
inp_conv
[
2
](
out4
)
f3
=
self
.
inp_conv
[
1
](
out3
)
f2
=
self
.
inp_conv
[
0
](
out2
)
pan3
=
f3
+
self
.
pan_head_conv
[
0
](
f2
)
pan4
=
f4
+
self
.
pan_head_conv
[
1
](
pan3
)
pan5
=
f5
+
self
.
pan_head_conv
[
2
](
pan4
)
p2
=
self
.
pan_lat_conv
[
0
](
f2
)
p3
=
self
.
pan_lat_conv
[
1
](
pan3
)
p4
=
self
.
pan_lat_conv
[
2
](
pan4
)
p5
=
self
.
pan_lat_conv
[
3
](
pan5
)
p5
=
F
.
upsample
(
p5
,
scale_factor
=
8
,
mode
=
"nearest"
,
align_mode
=
1
)
p4
=
F
.
upsample
(
p4
,
scale_factor
=
4
,
mode
=
"nearest"
,
align_mode
=
1
)
p3
=
F
.
upsample
(
p3
,
scale_factor
=
2
,
mode
=
"nearest"
,
align_mode
=
1
)
fuse
=
paddle
.
concat
([
p5
,
p4
,
p3
,
p2
],
axis
=
1
)
return
fuse
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录