Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
weixin_41840029
PaddleOCR
提交
140a50df
P
PaddleOCR
项目概览
weixin_41840029
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
140a50df
编写于
7月 01, 2022
作者:
D
Double_V
提交者:
GitHub
7月 01, 2022
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #6717 from wangjingyeye/dyg_db
add db++
上级
0b4ccc38
00e67e0b
变更
22
隐藏空白更改
内联
并排
Showing
22 changed file
with
741 addition
and
25 deletion
+741
-25
configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml
configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml
+1
-1
configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_distill.yml
configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_distill.yml
+1
-1
configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml
configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml
+1
-1
configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_dml.yml
configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_dml.yml
+2
-2
configs/det/ch_ppocr_v2.0/ch_det_res18_db_v2.0.yml
configs/det/ch_ppocr_v2.0/ch_det_res18_db_v2.0.yml
+1
-1
configs/det/det_r50_db++_ic15.yml
configs/det/det_r50_db++_ic15.yml
+163
-0
configs/det/det_r50_db++_td_tr.yml
configs/det/det_r50_db++_td_tr.yml
+166
-0
configs/det/det_r50_vd_db.yml
configs/det/det_r50_vd_db.yml
+1
-1
configs/det/det_r50_vd_dcn_fce_ctw.yml
configs/det/det_r50_vd_dcn_fce_ctw.yml
+1
-1
configs/det/det_r50_vd_east.yml
configs/det/det_r50_vd_east.yml
+1
-1
configs/det/det_r50_vd_pse.yml
configs/det/det_r50_vd_pse.yml
+1
-1
configs/det/det_res18_db_v2.0.yml
configs/det/det_res18_db_v2.0.yml
+1
-1
doc/doc_ch/algorithm_det_db.md
doc/doc_ch/algorithm_det_db.md
+23
-3
doc/doc_ch/dataset/ocr_datasets.md
doc/doc_ch/dataset/ocr_datasets.md
+1
-0
ppocr/data/imaug/operators.py
ppocr/data/imaug/operators.py
+7
-0
ppocr/modeling/backbones/__init__.py
ppocr/modeling/backbones/__init__.py
+3
-2
ppocr/modeling/backbones/det_resnet.py
ppocr/modeling/backbones/det_resnet.py
+236
-0
ppocr/modeling/backbones/det_resnet_vd.py
ppocr/modeling/backbones/det_resnet_vd.py
+7
-6
ppocr/modeling/necks/db_fpn.py
ppocr/modeling/necks/db_fpn.py
+70
-1
ppocr/optimizer/learning_rate.py
ppocr/optimizer/learning_rate.py
+35
-0
tools/infer/predict_det.py
tools/infer/predict_det.py
+18
-1
tools/program.py
tools/program.py
+1
-1
未找到文件。
configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml
浏览文件 @
140a50df
...
...
@@ -28,7 +28,7 @@ Architecture:
algorithm
:
DB
Transform
:
Backbone
:
name
:
ResNet
name
:
ResNet
_vd
layers
:
18
Neck
:
name
:
DBFPN
...
...
configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_distill.yml
浏览文件 @
140a50df
...
...
@@ -45,7 +45,7 @@ Architecture:
algorithm
:
DB
Transform
:
Backbone
:
name
:
ResNet
name
:
ResNet
_vd
layers
:
18
Neck
:
name
:
DBFPN
...
...
configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml
浏览文件 @
140a50df
...
...
@@ -61,7 +61,7 @@ Architecture:
model_type
:
det
algorithm
:
DB
Backbone
:
name
:
ResNet
name
:
ResNet
_vd
in_channels
:
3
layers
:
50
Neck
:
...
...
configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_dml.yml
浏览文件 @
140a50df
...
...
@@ -25,7 +25,7 @@ Architecture:
model_type
:
det
algorithm
:
DB
Backbone
:
name
:
ResNet
name
:
ResNet
_vd
in_channels
:
3
layers
:
50
Neck
:
...
...
@@ -40,7 +40,7 @@ Architecture:
model_type
:
det
algorithm
:
DB
Backbone
:
name
:
ResNet
name
:
ResNet
_vd
in_channels
:
3
layers
:
50
Neck
:
...
...
configs/det/ch_ppocr_v2.0/ch_det_res18_db_v2.0.yml
浏览文件 @
140a50df
...
...
@@ -20,7 +20,7 @@ Architecture:
algorithm
:
DB
Transform
:
Backbone
:
name
:
ResNet
name
:
ResNet
_vd
layers
:
18
disable_se
:
True
Neck
:
...
...
configs/det/det_r50_db++_ic15.yml
0 → 100644
浏览文件 @
140a50df
Global
:
debug
:
false
use_gpu
:
true
epoch_num
:
1000
log_smooth_window
:
20
print_batch_step
:
10
save_model_dir
:
./output/det_r50_icdar15/
save_epoch_step
:
200
eval_batch_step
:
-
0
-
2000
cal_metric_during_train
:
false
pretrained_model
:
./pretrain_models/ResNet50_dcn_asf_synthtext_pretrained
checkpoints
:
null
save_inference_dir
:
null
use_visualdl
:
false
infer_img
:
doc/imgs_en/img_10.jpg
save_res_path
:
./checkpoints/det_db/predicts_db.txt
Architecture
:
model_type
:
det
algorithm
:
DB++
Transform
:
null
Backbone
:
name
:
ResNet
layers
:
50
dcn_stage
:
[
False
,
True
,
True
,
True
]
Neck
:
name
:
DBFPN
out_channels
:
256
use_asf
:
True
Head
:
name
:
DBHead
k
:
50
Loss
:
name
:
DBLoss
balance_loss
:
true
main_loss_type
:
BCELoss
alpha
:
5
beta
:
10
ohem_ratio
:
3
Optimizer
:
name
:
Momentum
momentum
:
0.9
lr
:
name
:
DecayLearningRate
learning_rate
:
0.007
epochs
:
1000
factor
:
0.9
end_lr
:
0
weight_decay
:
0.0001
PostProcess
:
name
:
DBPostProcess
thresh
:
0.3
box_thresh
:
0.6
max_candidates
:
1000
unclip_ratio
:
1.5
Metric
:
name
:
DetMetric
main_indicator
:
hmean
Train
:
dataset
:
name
:
SimpleDataSet
data_dir
:
./train_data/icdar2015/text_localization/
label_file_list
:
-
./train_data/icdar2015/text_localization/train_icdar2015_label.txt
ratio_list
:
-
1.0
transforms
:
-
DecodeImage
:
img_mode
:
BGR
channel_first
:
false
-
DetLabelEncode
:
null
-
IaaAugment
:
augmenter_args
:
-
type
:
Fliplr
args
:
p
:
0.5
-
type
:
Affine
args
:
rotate
:
-
-10
-
10
-
type
:
Resize
args
:
size
:
-
0.5
-
3
-
EastRandomCropData
:
size
:
-
640
-
640
max_tries
:
10
keep_ratio
:
true
-
MakeShrinkMap
:
shrink_ratio
:
0.4
min_text_size
:
8
-
MakeBorderMap
:
shrink_ratio
:
0.4
thresh_min
:
0.3
thresh_max
:
0.7
-
NormalizeImage
:
scale
:
1./255.
mean
:
-
0.48109378172549
-
0.45752457890196
-
0.40787054090196
std
:
-
1.0
-
1.0
-
1.0
order
:
hwc
-
ToCHWImage
:
null
-
KeepKeys
:
keep_keys
:
-
image
-
threshold_map
-
threshold_mask
-
shrink_map
-
shrink_mask
loader
:
shuffle
:
true
drop_last
:
false
batch_size_per_card
:
4
num_workers
:
8
Eval
:
dataset
:
name
:
SimpleDataSet
data_dir
:
./train_data/icdar2015/text_localization
label_file_list
:
-
./train_data/icdar2015/text_localization/test_icdar2015_label.txt
transforms
:
-
DecodeImage
:
img_mode
:
BGR
channel_first
:
false
-
DetLabelEncode
:
null
-
DetResizeForTest
:
image_shape
:
-
1152
-
2048
-
NormalizeImage
:
scale
:
1./255.
mean
:
-
0.48109378172549
-
0.45752457890196
-
0.40787054090196
std
:
-
1.0
-
1.0
-
1.0
order
:
hwc
-
ToCHWImage
:
null
-
KeepKeys
:
keep_keys
:
-
image
-
shape
-
polys
-
ignore_tags
loader
:
shuffle
:
false
drop_last
:
false
batch_size_per_card
:
1
num_workers
:
2
profiler_options
:
null
configs/det/det_r50_db++_td_tr.yml
0 → 100644
浏览文件 @
140a50df
Global
:
debug
:
false
use_gpu
:
true
epoch_num
:
1000
log_smooth_window
:
20
print_batch_step
:
10
save_model_dir
:
./output/det_r50_td_tr/
save_epoch_step
:
200
eval_batch_step
:
-
0
-
2000
cal_metric_during_train
:
false
pretrained_model
:
./pretrain_models/ResNet50_dcn_asf_synthtext_pretrained
checkpoints
:
null
save_inference_dir
:
null
use_visualdl
:
false
infer_img
:
doc/imgs_en/img_10.jpg
save_res_path
:
./checkpoints/det_db/predicts_db.txt
Architecture
:
model_type
:
det
algorithm
:
DB++
Transform
:
null
Backbone
:
name
:
ResNet
layers
:
50
dcn_stage
:
[
False
,
True
,
True
,
True
]
Neck
:
name
:
DBFPN
out_channels
:
256
use_asf
:
True
Head
:
name
:
DBHead
k
:
50
Loss
:
name
:
DBLoss
balance_loss
:
true
main_loss_type
:
BCELoss
alpha
:
5
beta
:
10
ohem_ratio
:
3
Optimizer
:
name
:
Momentum
momentum
:
0.9
lr
:
name
:
DecayLearningRate
learning_rate
:
0.007
epochs
:
1000
factor
:
0.9
end_lr
:
0
weight_decay
:
0.0001
PostProcess
:
name
:
DBPostProcess
thresh
:
0.3
box_thresh
:
0.5
max_candidates
:
1000
unclip_ratio
:
1.5
Metric
:
name
:
DetMetric
main_indicator
:
hmean
Train
:
dataset
:
name
:
SimpleDataSet
data_dir
:
./train_data/
label_file_list
:
-
./train_data/TD_TR/TD500/train_gt_labels.txt
-
./train_data/TD_TR/TR400/gt_labels.txt
ratio_list
:
-
1.0
-
1.0
transforms
:
-
DecodeImage
:
img_mode
:
BGR
channel_first
:
false
-
DetLabelEncode
:
null
-
IaaAugment
:
augmenter_args
:
-
type
:
Fliplr
args
:
p
:
0.5
-
type
:
Affine
args
:
rotate
:
-
-10
-
10
-
type
:
Resize
args
:
size
:
-
0.5
-
3
-
EastRandomCropData
:
size
:
-
640
-
640
max_tries
:
10
keep_ratio
:
true
-
MakeShrinkMap
:
shrink_ratio
:
0.4
min_text_size
:
8
-
MakeBorderMap
:
shrink_ratio
:
0.4
thresh_min
:
0.3
thresh_max
:
0.7
-
NormalizeImage
:
scale
:
1./255.
mean
:
-
0.48109378172549
-
0.45752457890196
-
0.40787054090196
std
:
-
1.0
-
1.0
-
1.0
order
:
hwc
-
ToCHWImage
:
null
-
KeepKeys
:
keep_keys
:
-
image
-
threshold_map
-
threshold_mask
-
shrink_map
-
shrink_mask
loader
:
shuffle
:
true
drop_last
:
false
batch_size_per_card
:
4
num_workers
:
8
Eval
:
dataset
:
name
:
SimpleDataSet
data_dir
:
./train_data/
label_file_list
:
-
./train_data/TD_TR/TD500/test_gt_labels.txt
transforms
:
-
DecodeImage
:
img_mode
:
BGR
channel_first
:
false
-
DetLabelEncode
:
null
-
DetResizeForTest
:
image_shape
:
-
736
-
736
keep_ratio
:
True
-
NormalizeImage
:
scale
:
1./255.
mean
:
-
0.48109378172549
-
0.45752457890196
-
0.40787054090196
std
:
-
1.0
-
1.0
-
1.0
order
:
hwc
-
ToCHWImage
:
null
-
KeepKeys
:
keep_keys
:
-
image
-
shape
-
polys
-
ignore_tags
loader
:
shuffle
:
false
drop_last
:
false
batch_size_per_card
:
1
num_workers
:
2
profiler_options
:
null
configs/det/det_r50_vd_db.yml
浏览文件 @
140a50df
...
...
@@ -20,7 +20,7 @@ Architecture:
algorithm
:
DB
Transform
:
Backbone
:
name
:
ResNet
name
:
ResNet
_vd
layers
:
50
Neck
:
name
:
DBFPN
...
...
configs/det/det_r50_vd_dcn_fce_ctw.yml
浏览文件 @
140a50df
...
...
@@ -21,7 +21,7 @@ Architecture:
algorithm
:
FCE
Transform
:
Backbone
:
name
:
ResNet
name
:
ResNet
_vd
layers
:
50
dcn_stage
:
[
False
,
True
,
True
,
True
]
out_indices
:
[
1
,
2
,
3
]
...
...
configs/det/det_r50_vd_east.yml
浏览文件 @
140a50df
...
...
@@ -20,7 +20,7 @@ Architecture:
algorithm
:
EAST
Transform
:
Backbone
:
name
:
ResNet
name
:
ResNet
_vd
layers
:
50
Neck
:
name
:
EASTFPN
...
...
configs/det/det_r50_vd_pse.yml
浏览文件 @
140a50df
...
...
@@ -20,7 +20,7 @@ Architecture:
algorithm
:
PSE
Transform
:
Backbone
:
name
:
ResNet
name
:
ResNet
_vd
layers
:
50
Neck
:
name
:
FPN
...
...
configs/det/det_res18_db_v2.0.yml
浏览文件 @
140a50df
...
...
@@ -20,7 +20,7 @@ Architecture:
algorithm
:
DB
Transform
:
Backbone
:
name
:
ResNet
name
:
ResNet
_vd
layers
:
18
disable_se
:
True
Neck
:
...
...
doc/doc_ch/algorithm_det_db.md
浏览文件 @
140a50df
# DB
# DB
与DB++
-
[
1. 算法简介
](
#1
)
-
[
2. 环境配置
](
#2
)
...
...
@@ -21,12 +21,24 @@
> Liao, Minghui and Wan, Zhaoyi and Yao, Cong and Chen, Kai and Bai, Xiang
> AAAI, 2020
> [Real-Time Scene Text Detection with Differentiable Binarization and Adaptive Scale Fusion](https://arxiv.org/abs/2202.10304)
> Liao, Minghui and Zou, Zhisheng and Wan, Zhaoyi and Yao, Cong and Bai, Xiang
> TPAMI, 2022
在ICDAR2015文本检测公开数据集上,算法复现效果如下:
|模型|骨干网络|配置文件|precision|recall|Hmean|下载链接|
| --- | --- | --- | --- | --- | --- | --- |
|DB|ResNet50_vd|
[
configs/det/det_r50_vd_db.yml
](
../../configs/det/det_r50_vd_db.yml
)
|86.41%|78.72%|82.38%|
[
训练模型
](
https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar
)
|
|DB|MobileNetV3|
[
configs/det/det_mv3_db.yml
](
../../configs/det/det_mv3_db.yml
)
|77.29%|73.08%|75.12%|
[
训练模型
](
https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar
)
|
|DB++|ResNet50|
[
configs/det/det_r50_db++_ic15.yml
](
../../configs/det/det_r50_db++_ic15.yml
)
|90.89%|82.66%|86.58%|
[
合成数据预训练模型
](
https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/ResNet50_dcn_asf_synthtext_pretrained.pdparams
)
/
[
训练模型
](
https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_db%2B%2B_icdar15_train.tar
)
|
在TD_TR文本检测公开数据集上,算法复现效果如下:
|模型|骨干网络|配置文件|precision|recall|Hmean|下载链接|
| --- | --- | --- | --- | --- | --- | --- |
|DB++|ResNet50|
[
configs/det/det_r50_db++_td_tr.yml
](
../../configs/det/det_r50_db++_td_tr.yml
)
|92.92%|86.48%|89.58%|
[
合成数据预训练模型
](
https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/ResNet50_dcn_asf_synthtext_pretrained.pdparams
)
/
[
训练模型
](
https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_db%2B%2B_td_tr_train.tar
)
|
<a
name=
"2"
></a>
...
...
@@ -54,7 +66,7 @@ python3 tools/export_model.py -c configs/det/det_r50_vd_db.yml -o Global.pretrai
DB文本检测模型推理,可以执行如下命令:
```
shell
python3 tools/infer/predict_det.py
--image_dir
=
"./doc/imgs_en/img_10.jpg"
--det_model_dir
=
"./inference/det_db/"
python3 tools/infer/predict_det.py
--image_dir
=
"./doc/imgs_en/img_10.jpg"
--det_model_dir
=
"./inference/det_db/"
--det_algorithm
=
"DB"
```
可视化文本检测结果默认保存到
`./inference_results`
文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下:
...
...
@@ -96,4 +108,12 @@ DB模型还支持以下推理部署方式:
pages
=
{11474--11481}
,
year
=
{2020}
}
```
\ No newline at end of file
@article
{
liao2022real
,
title
=
{Real-Time Scene Text Detection with Differentiable Binarization and Adaptive Scale Fusion}
,
author
=
{Liao, Minghui and Zou, Zhisheng and Wan, Zhaoyi and Yao, Cong and Bai, Xiang}
,
journal
=
{IEEE Transactions on Pattern Analysis and Machine Intelligence}
,
year
=
{2022}
,
publisher
=
{IEEE}
}
```
doc/doc_ch/dataset/ocr_datasets.md
浏览文件 @
140a50df
...
...
@@ -34,6 +34,7 @@ json.dumps编码前的图像标注信息是包含多个字典的list,字典中
| ICDAR 2015 |https://rrc.cvc.uab.es/?ch=4&com=downloads|
[
train
](
https://paddleocr.bj.bcebos.com/dataset/train_icdar2015_label.txt
)
/
[
test
](
https://paddleocr.bj.bcebos.com/dataset/test_icdar2015_label.txt
)
|
| ctw1500 |https://paddleocr.bj.bcebos.com/dataset/ctw1500.zip| 图片下载地址中已包含 |
| total text |https://paddleocr.bj.bcebos.com/dataset/total_text.tar| 图片下载地址中已包含 |
| td tr |https://paddleocr.bj.bcebos.com/dataset/TD_TR.tar| 图片下载地址中已包含 |
#### 1.2.1 ICDAR 2015
ICDAR 2015 数据集包含1000张训练图像和500张测试图像。ICDAR 2015 数据集可以从上表中链接下载,首次下载需注册。
...
...
ppocr/data/imaug/operators.py
浏览文件 @
140a50df
...
...
@@ -205,9 +205,12 @@ class DetResizeForTest(object):
def
__init__
(
self
,
**
kwargs
):
super
(
DetResizeForTest
,
self
).
__init__
()
self
.
resize_type
=
0
self
.
keep_ratio
=
False
if
'image_shape'
in
kwargs
:
self
.
image_shape
=
kwargs
[
'image_shape'
]
self
.
resize_type
=
1
if
'keep_ratio'
in
kwargs
:
self
.
keep_ratio
=
kwargs
[
'keep_ratio'
]
elif
'limit_side_len'
in
kwargs
:
self
.
limit_side_len
=
kwargs
[
'limit_side_len'
]
self
.
limit_type
=
kwargs
.
get
(
'limit_type'
,
'min'
)
...
...
@@ -237,6 +240,10 @@ class DetResizeForTest(object):
def
resize_image_type1
(
self
,
img
):
resize_h
,
resize_w
=
self
.
image_shape
ori_h
,
ori_w
=
img
.
shape
[:
2
]
# (h, w, c)
if
self
.
keep_ratio
is
True
:
resize_w
=
ori_w
*
resize_h
/
ori_h
N
=
math
.
ceil
(
resize_w
/
32
)
resize_w
=
N
*
32
ratio_h
=
float
(
resize_h
)
/
ori_h
ratio_w
=
float
(
resize_w
)
/
ori_w
img
=
cv2
.
resize
(
img
,
(
int
(
resize_w
),
int
(
resize_h
)))
...
...
ppocr/modeling/backbones/__init__.py
浏览文件 @
140a50df
...
...
@@ -18,9 +18,10 @@ __all__ = ["build_backbone"]
def
build_backbone
(
config
,
model_type
):
if
model_type
==
"det"
or
model_type
==
"table"
:
from
.det_mobilenet_v3
import
MobileNetV3
from
.det_resnet_vd
import
ResNet
from
.det_resnet
import
ResNet
from
.det_resnet_vd
import
ResNet_vd
from
.det_resnet_vd_sast
import
ResNet_SAST
support_dict
=
[
"MobileNetV3"
,
"ResNet"
,
"ResNet_SAST"
]
support_dict
=
[
"MobileNetV3"
,
"ResNet"
,
"ResNet_
vd"
,
"ResNet_
SAST"
]
elif
model_type
==
"rec"
or
model_type
==
"cls"
:
from
.rec_mobilenet_v3
import
MobileNetV3
from
.rec_resnet_vd
import
ResNet
...
...
ppocr/modeling/backbones/det_resnet.py
0 → 100644
浏览文件 @
140a50df
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
paddle
from
paddle
import
ParamAttr
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddle.nn
import
Conv2D
,
BatchNorm
,
Linear
,
Dropout
from
paddle.nn
import
AdaptiveAvgPool2D
,
MaxPool2D
,
AvgPool2D
from
paddle.nn.initializer
import
Uniform
import
math
from
paddle.vision.ops
import
DeformConv2D
from
paddle.regularizer
import
L2Decay
from
paddle.nn.initializer
import
Normal
,
Constant
,
XavierUniform
from
.det_resnet_vd
import
DeformableConvV2
,
ConvBNLayer
class
BottleneckBlock
(
nn
.
Layer
):
def
__init__
(
self
,
num_channels
,
num_filters
,
stride
,
shortcut
=
True
,
is_dcn
=
False
):
super
(
BottleneckBlock
,
self
).
__init__
()
self
.
conv0
=
ConvBNLayer
(
in_channels
=
num_channels
,
out_channels
=
num_filters
,
kernel_size
=
1
,
act
=
"relu"
,
)
self
.
conv1
=
ConvBNLayer
(
in_channels
=
num_filters
,
out_channels
=
num_filters
,
kernel_size
=
3
,
stride
=
stride
,
act
=
"relu"
,
is_dcn
=
is_dcn
,
dcn_groups
=
1
,
)
self
.
conv2
=
ConvBNLayer
(
in_channels
=
num_filters
,
out_channels
=
num_filters
*
4
,
kernel_size
=
1
,
act
=
None
,
)
if
not
shortcut
:
self
.
short
=
ConvBNLayer
(
in_channels
=
num_channels
,
out_channels
=
num_filters
*
4
,
kernel_size
=
1
,
stride
=
stride
,
)
self
.
shortcut
=
shortcut
self
.
_num_channels_out
=
num_filters
*
4
def
forward
(
self
,
inputs
):
y
=
self
.
conv0
(
inputs
)
conv1
=
self
.
conv1
(
y
)
conv2
=
self
.
conv2
(
conv1
)
if
self
.
shortcut
:
short
=
inputs
else
:
short
=
self
.
short
(
inputs
)
y
=
paddle
.
add
(
x
=
short
,
y
=
conv2
)
y
=
F
.
relu
(
y
)
return
y
class
BasicBlock
(
nn
.
Layer
):
def
__init__
(
self
,
num_channels
,
num_filters
,
stride
,
shortcut
=
True
,
name
=
None
):
super
(
BasicBlock
,
self
).
__init__
()
self
.
stride
=
stride
self
.
conv0
=
ConvBNLayer
(
in_channels
=
num_channels
,
out_channels
=
num_filters
,
kernel_size
=
3
,
stride
=
stride
,
act
=
"relu"
)
self
.
conv1
=
ConvBNLayer
(
in_channels
=
num_filters
,
out_channels
=
num_filters
,
kernel_size
=
3
,
act
=
None
)
if
not
shortcut
:
self
.
short
=
ConvBNLayer
(
in_channels
=
num_channels
,
out_channels
=
num_filters
,
kernel_size
=
1
,
stride
=
stride
)
self
.
shortcut
=
shortcut
def
forward
(
self
,
inputs
):
y
=
self
.
conv0
(
inputs
)
conv1
=
self
.
conv1
(
y
)
if
self
.
shortcut
:
short
=
inputs
else
:
short
=
self
.
short
(
inputs
)
y
=
paddle
.
add
(
x
=
short
,
y
=
conv1
)
y
=
F
.
relu
(
y
)
return
y
class
ResNet
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
=
3
,
layers
=
50
,
out_indices
=
None
,
dcn_stage
=
None
):
super
(
ResNet
,
self
).
__init__
()
self
.
layers
=
layers
self
.
input_image_channel
=
in_channels
supported_layers
=
[
18
,
34
,
50
,
101
,
152
]
assert
layers
in
supported_layers
,
\
"supported layers are {} but input layer is {}"
.
format
(
supported_layers
,
layers
)
if
layers
==
18
:
depth
=
[
2
,
2
,
2
,
2
]
elif
layers
==
34
or
layers
==
50
:
depth
=
[
3
,
4
,
6
,
3
]
elif
layers
==
101
:
depth
=
[
3
,
4
,
23
,
3
]
elif
layers
==
152
:
depth
=
[
3
,
8
,
36
,
3
]
num_channels
=
[
64
,
256
,
512
,
1024
]
if
layers
>=
50
else
[
64
,
64
,
128
,
256
]
num_filters
=
[
64
,
128
,
256
,
512
]
self
.
dcn_stage
=
dcn_stage
if
dcn_stage
is
not
None
else
[
False
,
False
,
False
,
False
]
self
.
out_indices
=
out_indices
if
out_indices
is
not
None
else
[
0
,
1
,
2
,
3
]
self
.
conv
=
ConvBNLayer
(
in_channels
=
self
.
input_image_channel
,
out_channels
=
64
,
kernel_size
=
7
,
stride
=
2
,
act
=
"relu"
,
)
self
.
pool2d_max
=
MaxPool2D
(
kernel_size
=
3
,
stride
=
2
,
padding
=
1
,
)
self
.
stages
=
[]
self
.
out_channels
=
[]
if
layers
>=
50
:
for
block
in
range
(
len
(
depth
)):
shortcut
=
False
block_list
=
[]
is_dcn
=
self
.
dcn_stage
[
block
]
for
i
in
range
(
depth
[
block
]):
if
layers
in
[
101
,
152
]
and
block
==
2
:
if
i
==
0
:
conv_name
=
"res"
+
str
(
block
+
2
)
+
"a"
else
:
conv_name
=
"res"
+
str
(
block
+
2
)
+
"b"
+
str
(
i
)
else
:
conv_name
=
"res"
+
str
(
block
+
2
)
+
chr
(
97
+
i
)
bottleneck_block
=
self
.
add_sublayer
(
conv_name
,
BottleneckBlock
(
num_channels
=
num_channels
[
block
]
if
i
==
0
else
num_filters
[
block
]
*
4
,
num_filters
=
num_filters
[
block
],
stride
=
2
if
i
==
0
and
block
!=
0
else
1
,
shortcut
=
shortcut
,
is_dcn
=
is_dcn
))
block_list
.
append
(
bottleneck_block
)
shortcut
=
True
if
block
in
self
.
out_indices
:
self
.
out_channels
.
append
(
num_filters
[
block
]
*
4
)
self
.
stages
.
append
(
nn
.
Sequential
(
*
block_list
))
else
:
for
block
in
range
(
len
(
depth
)):
shortcut
=
False
block_list
=
[]
for
i
in
range
(
depth
[
block
]):
conv_name
=
"res"
+
str
(
block
+
2
)
+
chr
(
97
+
i
)
basic_block
=
self
.
add_sublayer
(
conv_name
,
BasicBlock
(
num_channels
=
num_channels
[
block
]
if
i
==
0
else
num_filters
[
block
],
num_filters
=
num_filters
[
block
],
stride
=
2
if
i
==
0
and
block
!=
0
else
1
,
shortcut
=
shortcut
))
block_list
.
append
(
basic_block
)
shortcut
=
True
if
block
in
self
.
out_indices
:
self
.
out_channels
.
append
(
num_filters
[
block
])
self
.
stages
.
append
(
nn
.
Sequential
(
*
block_list
))
def
forward
(
self
,
inputs
):
y
=
self
.
conv
(
inputs
)
y
=
self
.
pool2d_max
(
y
)
out
=
[]
for
i
,
block
in
enumerate
(
self
.
stages
):
y
=
block
(
y
)
if
i
in
self
.
out_indices
:
out
.
append
(
y
)
return
out
ppocr/modeling/backbones/det_resnet_vd.py
浏览文件 @
140a50df
...
...
@@ -25,7 +25,7 @@ from paddle.vision.ops import DeformConv2D
from
paddle.regularizer
import
L2Decay
from
paddle.nn.initializer
import
Normal
,
Constant
,
XavierUniform
__all__
=
[
"ResNet"
]
__all__
=
[
"ResNet
_vd"
,
"ConvBNLayer"
,
"DeformableConvV2
"
]
class
DeformableConvV2
(
nn
.
Layer
):
...
...
@@ -104,6 +104,7 @@ class ConvBNLayer(nn.Layer):
kernel_size
,
stride
=
1
,
groups
=
1
,
dcn_groups
=
1
,
is_vd_mode
=
False
,
act
=
None
,
is_dcn
=
False
):
...
...
@@ -128,7 +129,7 @@ class ConvBNLayer(nn.Layer):
kernel_size
=
kernel_size
,
stride
=
stride
,
padding
=
(
kernel_size
-
1
)
//
2
,
groups
=
2
,
#groups,
groups
=
dcn_groups
,
#groups,
bias_attr
=
False
)
self
.
_batch_norm
=
nn
.
BatchNorm
(
out_channels
,
act
=
act
)
...
...
@@ -162,7 +163,8 @@ class BottleneckBlock(nn.Layer):
kernel_size
=
3
,
stride
=
stride
,
act
=
'relu'
,
is_dcn
=
is_dcn
)
is_dcn
=
is_dcn
,
dcn_groups
=
2
)
self
.
conv2
=
ConvBNLayer
(
in_channels
=
out_channels
,
out_channels
=
out_channels
*
4
,
...
...
@@ -238,14 +240,14 @@ class BasicBlock(nn.Layer):
return
y
class
ResNet
(
nn
.
Layer
):
class
ResNet
_vd
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
=
3
,
layers
=
50
,
dcn_stage
=
None
,
out_indices
=
None
,
**
kwargs
):
super
(
ResNet
,
self
).
__init__
()
super
(
ResNet
_vd
,
self
).
__init__
()
self
.
layers
=
layers
supported_layers
=
[
18
,
34
,
50
,
101
,
152
,
200
]
...
...
@@ -321,7 +323,6 @@ class ResNet(nn.Layer):
for
block
in
range
(
len
(
depth
)):
block_list
=
[]
shortcut
=
False
# is_dcn = self.dcn_stage[block]
for
i
in
range
(
depth
[
block
]):
basic_block
=
self
.
add_sublayer
(
'bb_%d_%d'
%
(
block
,
i
),
...
...
ppocr/modeling/necks/db_fpn.py
浏览文件 @
140a50df
...
...
@@ -105,9 +105,10 @@ class DSConv(nn.Layer):
class
DBFPN
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
**
kwargs
):
def
__init__
(
self
,
in_channels
,
out_channels
,
use_asf
=
False
,
**
kwargs
):
super
(
DBFPN
,
self
).
__init__
()
self
.
out_channels
=
out_channels
self
.
use_asf
=
use_asf
weight_attr
=
paddle
.
nn
.
initializer
.
KaimingUniform
()
self
.
in2_conv
=
nn
.
Conv2D
(
...
...
@@ -163,6 +164,9 @@ class DBFPN(nn.Layer):
weight_attr
=
ParamAttr
(
initializer
=
weight_attr
),
bias_attr
=
False
)
if
self
.
use_asf
is
True
:
self
.
asf
=
ASFBlock
(
self
.
out_channels
,
self
.
out_channels
//
4
)
def
forward
(
self
,
x
):
c2
,
c3
,
c4
,
c5
=
x
...
...
@@ -187,6 +191,10 @@ class DBFPN(nn.Layer):
p3
=
F
.
upsample
(
p3
,
scale_factor
=
2
,
mode
=
"nearest"
,
align_mode
=
1
)
fuse
=
paddle
.
concat
([
p5
,
p4
,
p3
,
p2
],
axis
=
1
)
if
self
.
use_asf
is
True
:
fuse
=
self
.
asf
(
fuse
,
[
p5
,
p4
,
p3
,
p2
])
return
fuse
...
...
@@ -356,3 +364,64 @@ class LKPAN(nn.Layer):
fuse
=
paddle
.
concat
([
p5
,
p4
,
p3
,
p2
],
axis
=
1
)
return
fuse
class
ASFBlock
(
nn
.
Layer
):
"""
This code is refered from:
https://github.com/MhLiao/DB/blob/master/decoders/feature_attention.py
"""
def
__init__
(
self
,
in_channels
,
inter_channels
,
out_features_num
=
4
):
"""
Adaptive Scale Fusion (ASF) block of DBNet++
Args:
in_channels: the number of channels in the input data
inter_channels: the number of middle channels
out_features_num: the number of fused stages
"""
super
(
ASFBlock
,
self
).
__init__
()
weight_attr
=
paddle
.
nn
.
initializer
.
KaimingUniform
()
self
.
in_channels
=
in_channels
self
.
inter_channels
=
inter_channels
self
.
out_features_num
=
out_features_num
self
.
conv
=
nn
.
Conv2D
(
in_channels
,
inter_channels
,
3
,
padding
=
1
)
self
.
spatial_scale
=
nn
.
Sequential
(
#Nx1xHxW
nn
.
Conv2D
(
in_channels
=
1
,
out_channels
=
1
,
kernel_size
=
3
,
bias_attr
=
False
,
padding
=
1
,
weight_attr
=
ParamAttr
(
initializer
=
weight_attr
)),
nn
.
ReLU
(),
nn
.
Conv2D
(
in_channels
=
1
,
out_channels
=
1
,
kernel_size
=
1
,
bias_attr
=
False
,
weight_attr
=
ParamAttr
(
initializer
=
weight_attr
)),
nn
.
Sigmoid
())
self
.
channel_scale
=
nn
.
Sequential
(
nn
.
Conv2D
(
in_channels
=
inter_channels
,
out_channels
=
out_features_num
,
kernel_size
=
1
,
bias_attr
=
False
,
weight_attr
=
ParamAttr
(
initializer
=
weight_attr
)),
nn
.
Sigmoid
())
def
forward
(
self
,
fuse_features
,
features_list
):
fuse_features
=
self
.
conv
(
fuse_features
)
spatial_x
=
paddle
.
mean
(
fuse_features
,
axis
=
1
,
keepdim
=
True
)
attention_scores
=
self
.
spatial_scale
(
spatial_x
)
+
fuse_features
attention_scores
=
self
.
channel_scale
(
attention_scores
)
assert
len
(
features_list
)
==
self
.
out_features_num
out_list
=
[]
for
i
in
range
(
self
.
out_features_num
):
out_list
.
append
(
attention_scores
[:,
i
:
i
+
1
]
*
features_list
[
i
])
return
paddle
.
concat
(
out_list
,
axis
=
1
)
ppocr/optimizer/learning_rate.py
浏览文件 @
140a50df
...
...
@@ -308,3 +308,38 @@ class Const(object):
end_lr
=
self
.
learning_rate
,
last_epoch
=
self
.
last_epoch
)
return
learning_rate
class
DecayLearningRate
(
object
):
"""
DecayLearningRate learning rate decay
new_lr = (lr - end_lr) * (1 - epoch/decay_steps)**power + end_lr
Args:
learning_rate(float): initial learning rate
step_each_epoch(int): steps each epoch
epochs(int): total training epochs
factor(float): Power of polynomial, should greater than 0.0 to get learning rate decay. Default: 0.9
end_lr(float): The minimum final learning rate. Default: 0.0.
"""
def
__init__
(
self
,
learning_rate
,
step_each_epoch
,
epochs
,
factor
=
0.9
,
end_lr
=
0
,
**
kwargs
):
super
(
DecayLearningRate
,
self
).
__init__
()
self
.
learning_rate
=
learning_rate
self
.
epochs
=
epochs
+
1
self
.
factor
=
factor
self
.
end_lr
=
0
self
.
decay_steps
=
step_each_epoch
*
epochs
def
__call__
(
self
):
learning_rate
=
lr
.
PolynomialDecay
(
learning_rate
=
self
.
learning_rate
,
decay_steps
=
self
.
decay_steps
,
power
=
self
.
factor
,
end_lr
=
self
.
end_lr
)
return
learning_rate
tools/infer/predict_det.py
浏览文件 @
140a50df
...
...
@@ -67,6 +67,23 @@ class TextDetector(object):
postprocess_params
[
"unclip_ratio"
]
=
args
.
det_db_unclip_ratio
postprocess_params
[
"use_dilation"
]
=
args
.
use_dilation
postprocess_params
[
"score_mode"
]
=
args
.
det_db_score_mode
elif
self
.
det_algorithm
==
"DB++"
:
postprocess_params
[
'name'
]
=
'DBPostProcess'
postprocess_params
[
"thresh"
]
=
args
.
det_db_thresh
postprocess_params
[
"box_thresh"
]
=
args
.
det_db_box_thresh
postprocess_params
[
"max_candidates"
]
=
1000
postprocess_params
[
"unclip_ratio"
]
=
args
.
det_db_unclip_ratio
postprocess_params
[
"use_dilation"
]
=
args
.
use_dilation
postprocess_params
[
"score_mode"
]
=
args
.
det_db_score_mode
pre_process_list
[
1
]
=
{
'NormalizeImage'
:
{
'std'
:
[
1.0
,
1.0
,
1.0
],
'mean'
:
[
0.48109378172549
,
0.45752457890196
,
0.40787054090196
],
'scale'
:
'1./255.'
,
'order'
:
'hwc'
}
}
elif
self
.
det_algorithm
==
"EAST"
:
postprocess_params
[
'name'
]
=
'EASTPostProcess'
postprocess_params
[
"score_thresh"
]
=
args
.
det_east_score_thresh
...
...
@@ -231,7 +248,7 @@ class TextDetector(object):
preds
[
'f_score'
]
=
outputs
[
1
]
preds
[
'f_tco'
]
=
outputs
[
2
]
preds
[
'f_tvo'
]
=
outputs
[
3
]
elif
self
.
det_algorithm
in
[
'DB'
,
'PSE'
]:
elif
self
.
det_algorithm
in
[
'DB'
,
'PSE'
,
'DB++'
]:
preds
[
'maps'
]
=
outputs
[
0
]
elif
self
.
det_algorithm
==
'FCE'
:
for
i
,
output
in
enumerate
(
outputs
):
...
...
tools/program.py
浏览文件 @
140a50df
...
...
@@ -577,7 +577,7 @@ def preprocess(is_train=False):
'EAST'
,
'DB'
,
'SAST'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
,
'CLS'
,
'PGNet'
,
'Distillation'
,
'NRTR'
,
'TableAttn'
,
'SAR'
,
'PSE'
,
'SEED'
,
'SDMGR'
,
'LayoutXLM'
,
'LayoutLM'
,
'LayoutLMv2'
,
'PREN'
,
'FCE'
,
'SVTR'
,
'ViTSTR'
,
'ABINet'
'SVTR'
,
'ViTSTR'
,
'ABINet'
,
'DB++'
]
if
use_xpu
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录