Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
26a89db7
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
26a89db7
编写于
6月 27, 2022
作者:
W
wangjingyeye
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add resnet
上级
142b5e9d
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
403 addition
and
0 deletion
+403
-0
configs/det/det_r50_db++_td_tr.yml
configs/det/det_r50_db++_td_tr.yml
+166
-0
ppocr/modeling/backbones/det_resnet.py
ppocr/modeling/backbones/det_resnet.py
+237
-0
未找到文件。
configs/det/det_r50_db++_td_tr.yml
0 → 100644
浏览文件 @
26a89db7
Global
:
debug
:
false
use_gpu
:
true
epoch_num
:
1000
log_smooth_window
:
20
print_batch_step
:
10
save_model_dir
:
./output/det_r50_td_tr/
save_epoch_step
:
200
eval_batch_step
:
-
0
-
2000
cal_metric_during_train
:
false
pretrained_model
:
./pretrain_models/synthtext_pretrained_res50_dcn_asf_spatial
checkpoints
:
null
save_inference_dir
:
null
use_visualdl
:
false
infer_img
:
doc/imgs_en/img_10.jpg
save_res_path
:
./checkpoints/det_db/predicts_db.txt
Architecture
:
model_type
:
det
algorithm
:
DB
Transform
:
null
Backbone
:
name
:
ResNet
layers
:
50
dcn_stage
:
[
False
,
True
,
True
,
True
]
Neck
:
name
:
DBFPN
out_channels
:
256
use_asf
:
True
Head
:
name
:
DBHead
k
:
50
Loss
:
name
:
DBLoss
balance_loss
:
true
main_loss_type
:
BCELoss
alpha
:
5
beta
:
10
ohem_ratio
:
3
Optimizer
:
name
:
Momentum
momentum
:
0.9
lr
:
name
:
DecayLearningRate
learning_rate
:
0.007
epochs
:
1000
factor
:
0.9
end_lr
:
0
weight_decay
:
0.0001
PostProcess
:
name
:
DBPostProcess
thresh
:
0.3
box_thresh
:
0.5
max_candidates
:
1000
unclip_ratio
:
1.5
Metric
:
name
:
DetMetric
main_indicator
:
hmean
Train
:
dataset
:
name
:
SimpleDataSet
data_dir
:
./train_data/
label_file_list
:
-
./train_data/TD_TR/TD500/train_gt_labels.txt
-
./train_data/TD_TR/TR400/gt_labels.txt
ratio_list
:
-
1.0
-
1.0
transforms
:
-
DecodeImage
:
img_mode
:
BGR
channel_first
:
false
-
DetLabelEncode
:
null
-
IaaAugment
:
augmenter_args
:
-
type
:
Fliplr
args
:
p
:
0.5
-
type
:
Affine
args
:
rotate
:
-
-10
-
10
-
type
:
Resize
args
:
size
:
-
0.5
-
3
-
EastRandomCropData
:
size
:
-
640
-
640
max_tries
:
10
keep_ratio
:
true
-
MakeShrinkMap
:
shrink_ratio
:
0.4
min_text_size
:
8
-
MakeBorderMap
:
shrink_ratio
:
0.4
thresh_min
:
0.3
thresh_max
:
0.7
-
NormalizeImage
:
scale
:
1./255.
mean
:
-
0.48109378172549
-
0.45752457890196
-
0.40787054090196
std
:
-
1.0
-
1.0
-
1.0
order
:
hwc
-
ToCHWImage
:
null
-
KeepKeys
:
keep_keys
:
-
image
-
threshold_map
-
threshold_mask
-
shrink_map
-
shrink_mask
loader
:
shuffle
:
true
drop_last
:
false
batch_size_per_card
:
4
num_workers
:
8
Eval
:
dataset
:
name
:
SimpleDataSet
data_dir
:
./train_data/
label_file_list
:
-
./train_data/TD_TR/TD500/test_gt_labels.txt
transforms
:
-
DecodeImage
:
img_mode
:
BGR
channel_first
:
false
-
DetLabelEncode
:
null
-
DetResizeForTest
:
image_shape
:
-
736
-
736
keep_ratio
:
True
-
NormalizeImage
:
scale
:
1./255.
mean
:
-
0.48109378172549
-
0.45752457890196
-
0.40787054090196
std
:
-
1.0
-
1.0
-
1.0
order
:
hwc
-
ToCHWImage
:
null
-
KeepKeys
:
keep_keys
:
-
image
-
shape
-
polys
-
ignore_tags
loader
:
shuffle
:
false
drop_last
:
false
batch_size_per_card
:
1
num_workers
:
2
profiler_options
:
null
ppocr/modeling/backbones/det_resnet.py
0 → 100644
浏览文件 @
26a89db7
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
paddle
from
paddle
import
ParamAttr
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddle.nn
import
Conv2D
,
BatchNorm
,
Linear
,
Dropout
from
paddle.nn
import
AdaptiveAvgPool2D
,
MaxPool2D
,
AvgPool2D
from
paddle.nn.initializer
import
Uniform
import
math
from
paddle.vision.ops
import
DeformConv2D
from
paddle.regularizer
import
L2Decay
from
paddle.nn.initializer
import
Normal
,
Constant
,
XavierUniform
from
.det_resnet_vd
import
DeformableConvV2
,
ConvBNLayer
class
BottleneckBlock
(
nn
.
Layer
):
def
__init__
(
self
,
num_channels
,
num_filters
,
stride
,
shortcut
=
True
,
is_dcn
=
False
):
super
(
BottleneckBlock
,
self
).
__init__
()
self
.
conv0
=
ConvBNLayer
(
in_channels
=
num_channels
,
out_channels
=
num_filters
,
kernel_size
=
1
,
act
=
"relu"
,
)
self
.
conv1
=
ConvBNLayer
(
in_channels
=
num_filters
,
out_channels
=
num_filters
,
kernel_size
=
3
,
stride
=
stride
,
act
=
"relu"
,
is_dcn
=
is_dcn
,
dcn_groups
=
1
,
)
self
.
conv2
=
ConvBNLayer
(
in_channels
=
num_filters
,
out_channels
=
num_filters
*
4
,
kernel_size
=
1
,
act
=
None
,
)
if
not
shortcut
:
self
.
short
=
ConvBNLayer
(
in_channels
=
num_channels
,
out_channels
=
num_filters
*
4
,
kernel_size
=
1
,
stride
=
stride
,
)
self
.
shortcut
=
shortcut
self
.
_num_channels_out
=
num_filters
*
4
def
forward
(
self
,
inputs
):
y
=
self
.
conv0
(
inputs
)
conv1
=
self
.
conv1
(
y
)
conv2
=
self
.
conv2
(
conv1
)
if
self
.
shortcut
:
short
=
inputs
else
:
short
=
self
.
short
(
inputs
)
y
=
paddle
.
add
(
x
=
short
,
y
=
conv2
)
y
=
F
.
relu
(
y
)
return
y
class
BasicBlock
(
nn
.
Layer
):
def
__init__
(
self
,
num_channels
,
num_filters
,
stride
,
shortcut
=
True
,
name
=
None
):
super
(
BasicBlock
,
self
).
__init__
()
self
.
stride
=
stride
self
.
conv0
=
ConvBNLayer
(
in_channels
=
num_channels
,
out_channels
=
num_filters
,
kernel_size
=
3
,
stride
=
stride
,
act
=
"relu"
)
self
.
conv1
=
ConvBNLayer
(
in_channels
=
num_filters
,
out_channels
=
num_filters
,
kernel_size
=
3
,
act
=
None
)
if
not
shortcut
:
self
.
short
=
ConvBNLayer
(
in_channels
=
num_channels
,
out_channels
=
num_filters
,
kernel_size
=
1
,
stride
=
stride
)
self
.
shortcut
=
shortcut
def
forward
(
self
,
inputs
):
y
=
self
.
conv0
(
inputs
)
conv1
=
self
.
conv1
(
y
)
if
self
.
shortcut
:
short
=
inputs
else
:
short
=
self
.
short
(
inputs
)
y
=
paddle
.
add
(
x
=
short
,
y
=
conv1
)
y
=
F
.
relu
(
y
)
return
y
class
ResNet
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
=
3
,
layers
=
50
,
out_indices
=
None
,
dcn_stage
=
None
):
super
(
ResNet
,
self
).
__init__
()
self
.
layers
=
layers
self
.
input_image_channel
=
in_channels
supported_layers
=
[
18
,
34
,
50
,
101
,
152
]
assert
layers
in
supported_layers
,
\
"supported layers are {} but input layer is {}"
.
format
(
supported_layers
,
layers
)
if
layers
==
18
:
depth
=
[
2
,
2
,
2
,
2
]
elif
layers
==
34
or
layers
==
50
:
depth
=
[
3
,
4
,
6
,
3
]
elif
layers
==
101
:
depth
=
[
3
,
4
,
23
,
3
]
elif
layers
==
152
:
depth
=
[
3
,
8
,
36
,
3
]
num_channels
=
[
64
,
256
,
512
,
1024
]
if
layers
>=
50
else
[
64
,
64
,
128
,
256
]
num_filters
=
[
64
,
128
,
256
,
512
]
self
.
dcn_stage
=
dcn_stage
if
dcn_stage
is
not
None
else
[
False
,
False
,
False
,
False
]
self
.
out_indices
=
out_indices
if
out_indices
is
not
None
else
[
0
,
1
,
2
,
3
]
self
.
conv
=
ConvBNLayer
(
in_channels
=
self
.
input_image_channel
,
out_channels
=
64
,
kernel_size
=
7
,
stride
=
2
,
act
=
"relu"
,
)
self
.
pool2d_max
=
MaxPool2D
(
kernel_size
=
3
,
stride
=
2
,
padding
=
1
,
)
self
.
stages
=
[]
self
.
out_channels
=
[]
if
layers
>=
50
:
for
block
in
range
(
len
(
depth
)):
shortcut
=
False
block_list
=
[]
is_dcn
=
self
.
dcn_stage
[
block
]
for
i
in
range
(
depth
[
block
]):
if
layers
in
[
101
,
152
]
and
block
==
2
:
if
i
==
0
:
conv_name
=
"res"
+
str
(
block
+
2
)
+
"a"
else
:
conv_name
=
"res"
+
str
(
block
+
2
)
+
"b"
+
str
(
i
)
else
:
conv_name
=
"res"
+
str
(
block
+
2
)
+
chr
(
97
+
i
)
bottleneck_block
=
self
.
add_sublayer
(
conv_name
,
BottleneckBlock
(
num_channels
=
num_channels
[
block
]
if
i
==
0
else
num_filters
[
block
]
*
4
,
num_filters
=
num_filters
[
block
],
stride
=
2
if
i
==
0
and
block
!=
0
else
1
,
shortcut
=
shortcut
,
is_dcn
=
is_dcn
))
block_list
.
append
(
bottleneck_block
)
shortcut
=
True
if
block
in
self
.
out_indices
:
self
.
out_channels
.
append
(
num_filters
[
block
]
*
4
)
self
.
stages
.
append
(
nn
.
Sequential
(
*
block_list
))
else
:
for
block
in
range
(
len
(
depth
)):
shortcut
=
False
block_list
=
[]
# is_dcn = self.dcn_stage[block]
for
i
in
range
(
depth
[
block
]):
conv_name
=
"res"
+
str
(
block
+
2
)
+
chr
(
97
+
i
)
basic_block
=
self
.
add_sublayer
(
conv_name
,
BasicBlock
(
num_channels
=
num_channels
[
block
]
if
i
==
0
else
num_filters
[
block
],
num_filters
=
num_filters
[
block
],
stride
=
2
if
i
==
0
and
block
!=
0
else
1
,
shortcut
=
shortcut
))
block_list
.
append
(
basic_block
)
shortcut
=
True
if
block
in
self
.
out_indices
:
self
.
out_channels
.
append
(
num_filters
[
block
])
self
.
stages
.
append
(
nn
.
Sequential
(
*
block_list
))
def
forward
(
self
,
inputs
):
y
=
self
.
conv
(
inputs
)
y
=
self
.
pool2d_max
(
y
)
out
=
[]
for
i
,
block
in
enumerate
(
self
.
stages
):
y
=
block
(
y
)
if
i
in
self
.
out_indices
:
out
.
append
(
y
)
return
out
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录