Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
weixin_41840029
PaddleOCR
提交
349b7d38
P
PaddleOCR
项目概览
weixin_41840029
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
349b7d38
编写于
11月 16, 2021
作者:
天涯古巷
提交者:
GitHub
11月 16, 2021
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'PaddlePaddle:release/2.3' into release/2.3
上级
afe8ed19
31b06a2f
变更
28
隐藏空白更改
内联
并排
Showing
28 changed file
with
169 addition
and
70 deletion
+169
-70
configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml
configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml
+1
-0
configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_distill.yml
configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_distill.yml
+1
-1
configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_dml.yml
configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_dml.yml
+0
-0
configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_student.yml
configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_student.yml
+0
-0
configs/det/det_r50_vd_sast_icdar15.yml
configs/det/det_r50_vd_sast_icdar15.yml
+2
-2
configs/det/det_r50_vd_sast_totaltext.yml
configs/det/det_r50_vd_sast_totaltext.yml
+2
-2
configs/table/table_mv3.yml
configs/table/table_mv3.yml
+9
-8
deploy/hubserving/ocr_det/module.py
deploy/hubserving/ocr_det/module.py
+1
-1
doc/doc_ch/algorithm_overview.md
doc/doc_ch/algorithm_overview.md
+1
-2
doc/doc_ch/reference.md
doc/doc_ch/reference.md
+10
-0
ppocr/data/imaug/east_process.py
ppocr/data/imaug/east_process.py
+10
-12
ppocr/data/imaug/iaa_augment.py
ppocr/data/imaug/iaa_augment.py
+5
-0
ppocr/data/imaug/make_border_map.py
ppocr/data/imaug/make_border_map.py
+17
-2
ppocr/data/imaug/make_shrink_map.py
ppocr/data/imaug/make_shrink_map.py
+17
-2
ppocr/data/imaug/random_crop_data.py
ppocr/data/imaug/random_crop_data.py
+17
-2
ppocr/data/imaug/sast_process.py
ppocr/data/imaug/sast_process.py
+4
-1
ppocr/data/imaug/text_image_aug/augment.py
ppocr/data/imaug/text_image_aug/augment.py
+4
-1
ppocr/data/imaug/text_image_aug/warp_mls.py
ppocr/data/imaug/text_image_aug/warp_mls.py
+4
-1
ppocr/losses/det_basic_loss.py
ppocr/losses/det_basic_loss.py
+4
-1
ppocr/losses/det_db_loss.py
ppocr/losses/det_db_loss.py
+4
-0
ppocr/modeling/backbones/rec_mv1_enhance.py
ppocr/modeling/backbones/rec_mv1_enhance.py
+2
-0
ppocr/modeling/heads/rec_att_head.py
ppocr/modeling/heads/rec_att_head.py
+1
-1
ppocr/modeling/heads/table_att_head.py
ppocr/modeling/heads/table_att_head.py
+24
-16
ppocr/modeling/transforms/tps.py
ppocr/modeling/transforms/tps.py
+6
-1
ppocr/postprocess/db_postprocess.py
ppocr/postprocess/db_postprocess.py
+13
-8
ppocr/postprocess/locality_aware_nms.py
ppocr/postprocess/locality_aware_nms.py
+1
-0
ppocr/utils/logging.py
ppocr/utils/logging.py
+4
-1
tools/infer/utility.py
tools/infer/utility.py
+5
-5
未找到文件。
configs/det/ch_PP-OCRv2/ch_PP-OCR_det_cml.yml
→
configs/det/ch_PP-OCRv2/ch_PP-OCR
v2
_det_cml.yml
浏览文件 @
349b7d38
...
...
@@ -141,6 +141,7 @@ Train:
img_mode
:
BGR
channel_first
:
False
-
DetLabelEncode
:
# Class handling label
-
CopyPaste
:
-
IaaAugment
:
augmenter_args
:
-
{
'
type'
:
Fliplr
,
'
args'
:
{
'
p'
:
0.5
}
}
...
...
configs/det/ch_PP-OCRv2/ch_PP-OCR_det_distill.yml
→
configs/det/ch_PP-OCRv2/ch_PP-OCR
v2
_det_distill.yml
浏览文件 @
349b7d38
...
...
@@ -91,7 +91,7 @@ Optimizer:
PostProcess
:
name
:
DistillationDBPostProcess
model_name
:
[
"
Student"
,
"
Student2"
]
model_name
:
[
"
Student"
]
key
:
head_out
thresh
:
0.3
box_thresh
:
0.6
...
...
configs/det/ch_PP-OCRv2/ch_PP-OCR_det_dml.yml
→
configs/det/ch_PP-OCRv2/ch_PP-OCR
v2
_det_dml.yml
浏览文件 @
349b7d38
文件已移动
configs/det/ch_PP-OCRv2/ch_PP-OCR_det_student.yml
→
configs/det/ch_PP-OCRv2/ch_PP-OCR
v2
_det_student.yml
浏览文件 @
349b7d38
文件已移动
configs/det/det_r50_vd_sast_icdar15.yml
浏览文件 @
349b7d38
...
...
@@ -8,7 +8,7 @@ Global:
# evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step
:
[
4000
,
5000
]
cal_metric_during_train
:
False
pretrained_model
:
./pretrain_models/ResNet50_vd_ssld_pretrained
/
pretrained_model
:
./pretrain_models/ResNet50_vd_ssld_pretrained
checkpoints
:
save_inference_dir
:
use_visualdl
:
False
...
...
@@ -106,4 +106,4 @@ Eval:
shuffle
:
False
drop_last
:
False
batch_size_per_card
:
1
# must be 1
num_workers
:
2
\ No newline at end of file
num_workers
:
2
configs/det/det_r50_vd_sast_totaltext.yml
浏览文件 @
349b7d38
...
...
@@ -8,7 +8,7 @@ Global:
# evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step
:
[
4000
,
5000
]
cal_metric_during_train
:
False
pretrained_model
:
./pretrain_models/ResNet50_vd_ssld_pretrained
/
pretrained_model
:
./pretrain_models/ResNet50_vd_ssld_pretrained
checkpoints
:
save_inference_dir
:
use_visualdl
:
False
...
...
@@ -105,4 +105,4 @@ Eval:
shuffle
:
False
drop_last
:
False
batch_size_per_card
:
1
# must be 1
num_workers
:
2
\ No newline at end of file
num_workers
:
2
configs/table/table_mv3.yml
浏览文件 @
349b7d38
Global
:
use_gpu
:
true
epoch_num
:
5
0
epoch_num
:
40
0
log_smooth_window
:
20
print_batch_step
:
5
save_model_dir
:
./output/table_mv3/
save_epoch_step
:
5
save_epoch_step
:
3
# evaluation is run every 400 iterations after the 0th iteration
eval_batch_step
:
[
0
,
400
]
cal_metric_during_train
:
True
pretrained_model
:
pretrained_model
:
checkpoints
:
save_inference_dir
:
use_visualdl
:
False
infer_img
:
doc/
imgs_words/ch/word_1
.jpg
infer_img
:
doc/
table/table
.jpg
# for data or label process
character_dict_path
:
ppocr/utils/dict/table_structure_dict.txt
character_type
:
en
max_text_length
:
100
max_elem_length
:
5
00
max_elem_length
:
8
00
max_cell_num
:
500
infer_mode
:
False
process_total_num
:
0
process_cut_num
:
0
Optimizer
:
name
:
Adam
beta1
:
0.9
...
...
@@ -41,13 +40,15 @@ Architecture:
Backbone
:
name
:
MobileNetV3
scale
:
1.0
model_name
:
small
disable_se
:
True
model_name
:
large
Head
:
name
:
TableAttentionHead
hidden_size
:
256
l2_decay
:
0.00001
loc_type
:
2
max_text_length
:
100
max_elem_length
:
800
max_cell_num
:
500
Loss
:
name
:
TableAttentionLoss
...
...
deploy/hubserving/ocr_det/module.py
浏览文件 @
349b7d38
...
...
@@ -18,7 +18,7 @@ import paddlehub as hub
from
tools.infer.utility
import
base64_to_cv2
from
tools.infer.predict_det
import
TextDetector
from
tools.infer.utility
import
parse_args
from
deploy.hubserving.ocr_
system
.params
import
read_params
from
deploy.hubserving.ocr_
det
.params
import
read_params
@
moduleinfo
(
...
...
doc/doc_ch/algorithm_overview.md
浏览文件 @
349b7d38
...
...
@@ -50,7 +50,7 @@ PaddleOCR基于动态图开源的文本识别算法列表:
-
[
x] STAR-Net([paper
](
http://www.bmva.org/bmvc/2016/papers/paper043/index.html
)
)[11]
-
[
x] RARE([paper
](
https://arxiv.org/abs/1603.03915v1
)
)[12]
-
[
x] SRN([paper
](
https://arxiv.org/abs/2003.12294
)
)[5]
-
[
x] NRTR([paper
](
https://arxiv.org/abs/1806.00926v2
)
)
-
[
x] NRTR([paper
](
https://arxiv.org/abs/1806.00926v2
)
)
[13]
参考
[
DTRB
][
3
]
(https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
...
...
@@ -78,4 +78,3 @@ PaddleOCR文本检测算法的训练和使用请参考文档教程中[模型训
## 3. 模型推理
上述模型中除PP-OCR系列模型以外,其余模型仅支持基于Python引擎的推理,具体内容可参考
[
基于Python预测引擎推理
](
./inference.md
)
doc/doc_ch/reference.md
浏览文件 @
349b7d38
...
...
@@ -112,4 +112,14 @@
year={2016}
}
13.NRTR
@misc{sheng2019nrtr,
title={NRTR: A No-Recurrence Sequence-to-Sequence Model For Scene Text Recognition},
author={Fenfen Sheng and Zhineng Chen and Bo Xu},
year={2019},
eprint={1806.00926},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
ppocr/data/imaug/east_process.py
浏览文件 @
349b7d38
...
...
@@ -11,7 +11,10 @@
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
"""
This code is refered from:
https://github.com/songdejia/EAST/blob/master/data_utils.py
"""
import
math
import
cv2
import
numpy
as
np
...
...
@@ -24,10 +27,10 @@ __all__ = ['EASTProcessTrain']
class
EASTProcessTrain
(
object
):
def
__init__
(
self
,
image_shape
=
[
512
,
512
],
background_ratio
=
0.125
,
min_crop_side_ratio
=
0.1
,
min_text_size
=
10
,
image_shape
=
[
512
,
512
],
background_ratio
=
0.125
,
min_crop_side_ratio
=
0.1
,
min_text_size
=
10
,
**
kwargs
):
self
.
input_size
=
image_shape
[
1
]
self
.
random_scale
=
np
.
array
([
0.5
,
1
,
2.0
,
3.0
])
...
...
@@ -282,12 +285,7 @@ class EASTProcessTrain(object):
1.0
/
max
(
min
(
poly_h
,
poly_w
),
1.0
)
return
score_map
,
geo_map
,
training_mask
def
crop_area
(
self
,
im
,
polys
,
tags
,
crop_background
=
False
,
max_tries
=
50
):
def
crop_area
(
self
,
im
,
polys
,
tags
,
crop_background
=
False
,
max_tries
=
50
):
"""
make random crop from the input image
:param im:
...
...
@@ -436,4 +434,4 @@ class EASTProcessTrain(object):
data
[
'geo_map'
]
=
geo_map
data
[
'training_mask'
]
=
training_mask
# print(im.shape, score_map.shape, geo_map.shape, training_mask.shape)
return
data
\ No newline at end of file
return
data
ppocr/data/imaug/iaa_augment.py
浏览文件 @
349b7d38
...
...
@@ -11,6 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/iaa_augment.py
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
...
...
ppocr/data/imaug/make_border_map.py
浏览文件 @
349b7d38
# -*- coding:utf-8 -*-
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/make_border_map.py
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
...
...
ppocr/data/imaug/make_shrink_map.py
浏览文件 @
349b7d38
# -*- coding:utf-8 -*-
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/make_shrink_map.py
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
...
...
ppocr/data/imaug/random_crop_data.py
浏览文件 @
349b7d38
# -*- coding:utf-8 -*-
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/random_crop_data.py
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
...
...
ppocr/data/imaug/sast_process.py
浏览文件 @
349b7d38
...
...
@@ -11,7 +11,10 @@
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
"""
This part code is refered from:
https://github.com/songdejia/EAST/blob/master/data_utils.py
"""
import
math
import
cv2
import
numpy
as
np
...
...
ppocr/data/imaug/text_image_aug/augment.py
浏览文件 @
349b7d38
...
...
@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/RubanSeven/Text-Image-Augmentation-python/blob/master/augment.py
"""
import
numpy
as
np
from
.warp_mls
import
WarpMLS
...
...
ppocr/data/imaug/text_image_aug/warp_mls.py
浏览文件 @
349b7d38
...
...
@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/RubanSeven/Text-Image-Augmentation-python/blob/master/warp_mls.py
"""
import
numpy
as
np
...
...
ppocr/losses/det_basic_loss.py
浏览文件 @
349b7d38
...
...
@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/models/losses/basic_loss.py
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
...
...
ppocr/losses/det_db_loss.py
浏览文件 @
349b7d38
...
...
@@ -11,6 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/models/losses/DB_loss.py
"""
from
__future__
import
absolute_import
from
__future__
import
division
...
...
ppocr/modeling/backbones/rec_mv1_enhance.py
浏览文件 @
349b7d38
...
...
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# This code is refer from: https://github.com/PaddlePaddle/PaddleClas/blob/develop/ppcls/arch/backbone/legendary_models/pp_lcnet.py
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
...
...
ppocr/modeling/heads/rec_att_head.py
浏览文件 @
349b7d38
...
...
@@ -75,7 +75,7 @@ class AttentionHead(nn.Layer):
probs_step
,
axis
=
1
)],
axis
=
1
)
next_input
=
probs_step
.
argmax
(
axis
=
1
)
targets
=
next_input
probs
=
paddle
.
nn
.
functional
.
softmax
(
probs
,
axis
=
2
)
return
probs
...
...
ppocr/modeling/heads/table_att_head.py
浏览文件 @
349b7d38
...
...
@@ -23,32 +23,40 @@ import numpy as np
class
TableAttentionHead
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
hidden_size
,
loc_type
,
in_max_len
=
488
,
**
kwargs
):
def
__init__
(
self
,
in_channels
,
hidden_size
,
loc_type
,
in_max_len
=
488
,
max_text_length
=
100
,
max_elem_length
=
800
,
max_cell_num
=
500
,
**
kwargs
):
super
(
TableAttentionHead
,
self
).
__init__
()
self
.
input_size
=
in_channels
[
-
1
]
self
.
hidden_size
=
hidden_size
self
.
elem_num
=
30
self
.
max_text_length
=
100
self
.
max_elem_length
=
500
self
.
max_cell_num
=
500
self
.
max_text_length
=
max_text_length
self
.
max_elem_length
=
max_elem_length
self
.
max_cell_num
=
max_cell_num
self
.
structure_attention_cell
=
AttentionGRUCell
(
self
.
input_size
,
hidden_size
,
self
.
elem_num
,
use_gru
=
False
)
self
.
structure_generator
=
nn
.
Linear
(
hidden_size
,
self
.
elem_num
)
self
.
loc_type
=
loc_type
self
.
in_max_len
=
in_max_len
if
self
.
loc_type
==
1
:
self
.
loc_generator
=
nn
.
Linear
(
hidden_size
,
4
)
else
:
if
self
.
in_max_len
==
640
:
self
.
loc_fea_trans
=
nn
.
Linear
(
400
,
self
.
max_elem_length
+
1
)
self
.
loc_fea_trans
=
nn
.
Linear
(
400
,
self
.
max_elem_length
+
1
)
elif
self
.
in_max_len
==
800
:
self
.
loc_fea_trans
=
nn
.
Linear
(
625
,
self
.
max_elem_length
+
1
)
self
.
loc_fea_trans
=
nn
.
Linear
(
625
,
self
.
max_elem_length
+
1
)
else
:
self
.
loc_fea_trans
=
nn
.
Linear
(
256
,
self
.
max_elem_length
+
1
)
self
.
loc_fea_trans
=
nn
.
Linear
(
256
,
self
.
max_elem_length
+
1
)
self
.
loc_generator
=
nn
.
Linear
(
self
.
input_size
+
hidden_size
,
4
)
def
_char_to_onehot
(
self
,
input_char
,
onehot_dim
):
input_ont_hot
=
F
.
one_hot
(
input_char
,
onehot_dim
)
return
input_ont_hot
...
...
@@ -60,16 +68,16 @@ class TableAttentionHead(nn.Layer):
if
len
(
fea
.
shape
)
==
3
:
pass
else
:
last_shape
=
int
(
np
.
prod
(
fea
.
shape
[
2
:]))
# gry added
last_shape
=
int
(
np
.
prod
(
fea
.
shape
[
2
:]))
# gry added
fea
=
paddle
.
reshape
(
fea
,
[
fea
.
shape
[
0
],
fea
.
shape
[
1
],
last_shape
])
fea
=
fea
.
transpose
([
0
,
2
,
1
])
# (NTC)(batch, width, channels)
batch_size
=
fea
.
shape
[
0
]
hidden
=
paddle
.
zeros
((
batch_size
,
self
.
hidden_size
))
output_hiddens
=
[]
if
self
.
training
and
targets
is
not
None
:
structure
=
targets
[
0
]
for
i
in
range
(
self
.
max_elem_length
+
1
):
for
i
in
range
(
self
.
max_elem_length
+
1
):
elem_onehots
=
self
.
_char_to_onehot
(
structure
[:,
i
],
onehot_dim
=
self
.
elem_num
)
(
outputs
,
hidden
),
alpha
=
self
.
structure_attention_cell
(
...
...
@@ -96,7 +104,7 @@ class TableAttentionHead(nn.Layer):
alpha
=
None
max_elem_length
=
paddle
.
to_tensor
(
self
.
max_elem_length
)
i
=
0
while
i
<
max_elem_length
+
1
:
while
i
<
max_elem_length
+
1
:
elem_onehots
=
self
.
_char_to_onehot
(
temp_elem
,
onehot_dim
=
self
.
elem_num
)
(
outputs
,
hidden
),
alpha
=
self
.
structure_attention_cell
(
...
...
@@ -105,7 +113,7 @@ class TableAttentionHead(nn.Layer):
structure_probs_step
=
self
.
structure_generator
(
outputs
)
temp_elem
=
structure_probs_step
.
argmax
(
axis
=
1
,
dtype
=
"int32"
)
i
+=
1
output
=
paddle
.
concat
(
output_hiddens
,
axis
=
1
)
structure_probs
=
self
.
structure_generator
(
output
)
structure_probs
=
F
.
softmax
(
structure_probs
)
...
...
@@ -119,9 +127,9 @@ class TableAttentionHead(nn.Layer):
loc_concat
=
paddle
.
concat
([
output
,
loc_fea
],
axis
=
2
)
loc_preds
=
self
.
loc_generator
(
loc_concat
)
loc_preds
=
F
.
sigmoid
(
loc_preds
)
return
{
'structure_probs'
:
structure_probs
,
'loc_preds'
:
loc_preds
}
return
{
'structure_probs'
:
structure_probs
,
'loc_preds'
:
loc_preds
}
class
AttentionGRUCell
(
nn
.
Layer
):
def
__init__
(
self
,
input_size
,
hidden_size
,
num_embeddings
,
use_gru
=
False
):
super
(
AttentionGRUCell
,
self
).
__init__
()
...
...
ppocr/modeling/transforms/tps.py
浏览文件 @
349b7d38
...
...
@@ -11,6 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/clovaai/deep-text-recognition-benchmark/blob/master/modules/transformation.py
"""
from
__future__
import
absolute_import
from
__future__
import
division
...
...
@@ -231,7 +235,8 @@ class GridGenerator(nn.Layer):
""" Return inv_delta_C which is needed to calculate T """
F
=
self
.
F
hat_eye
=
paddle
.
eye
(
F
,
dtype
=
'float64'
)
# F x F
hat_C
=
paddle
.
norm
(
C
.
reshape
([
1
,
F
,
2
])
-
C
.
reshape
([
F
,
1
,
2
]),
axis
=
2
)
+
hat_eye
hat_C
=
paddle
.
norm
(
C
.
reshape
([
1
,
F
,
2
])
-
C
.
reshape
([
F
,
1
,
2
]),
axis
=
2
)
+
hat_eye
hat_C
=
(
hat_C
**
2
)
*
paddle
.
log
(
hat_C
)
delta_C
=
paddle
.
concat
(
# F+3 x F+3
[
...
...
ppocr/postprocess/db_postprocess.py
浏览文件 @
349b7d38
...
...
@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refered from:
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/post_processing/seg_detector_representer.py
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
...
...
@@ -190,7 +193,8 @@ class DBPostProcess(object):
class
DistillationDBPostProcess
(
object
):
def
__init__
(
self
,
model_name
=
[
"student"
],
def
__init__
(
self
,
model_name
=
[
"student"
],
key
=
None
,
thresh
=
0.3
,
box_thresh
=
0.6
,
...
...
@@ -201,12 +205,13 @@ class DistillationDBPostProcess(object):
**
kwargs
):
self
.
model_name
=
model_name
self
.
key
=
key
self
.
post_process
=
DBPostProcess
(
thresh
=
thresh
,
box_thresh
=
box_thresh
,
max_candidates
=
max_candidates
,
unclip_ratio
=
unclip_ratio
,
use_dilation
=
use_dilation
,
score_mode
=
score_mode
)
self
.
post_process
=
DBPostProcess
(
thresh
=
thresh
,
box_thresh
=
box_thresh
,
max_candidates
=
max_candidates
,
unclip_ratio
=
unclip_ratio
,
use_dilation
=
use_dilation
,
score_mode
=
score_mode
)
def
__call__
(
self
,
predicts
,
shape_list
):
results
=
{}
...
...
ppocr/postprocess/locality_aware_nms.py
浏览文件 @
349b7d38
"""
Locality aware nms.
This code is refered from: https://github.com/songdejia/EAST/blob/master/locality_aware_nms.py
"""
import
numpy
as
np
...
...
ppocr/utils/logging.py
浏览文件 @
349b7d38
...
...
@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/WenmuZhou/PytorchOCR/blob/master/torchocr/utils/logging.py
"""
import
os
import
sys
import
logging
...
...
tools/infer/utility.py
浏览文件 @
349b7d38
...
...
@@ -187,7 +187,7 @@ def create_predictor(args, mode, logger):
"nearest_interp_v2_0.tmp_0"
:
[
1
,
256
,
2
,
2
]
}
max_input_shape
=
{
"x"
:
[
1
,
3
,
2000
,
200
0
],
"x"
:
[
1
,
3
,
1280
,
128
0
],
"conv2d_92.tmp_0"
:
[
1
,
120
,
400
,
400
],
"conv2d_91.tmp_0"
:
[
1
,
24
,
200
,
200
],
"conv2d_59.tmp_0"
:
[
1
,
96
,
400
,
400
],
...
...
@@ -237,16 +237,16 @@ def create_predictor(args, mode, logger):
opt_input_shape
.
update
(
opt_pact_shape
)
elif
mode
==
"rec"
:
min_input_shape
=
{
"x"
:
[
1
,
3
,
32
,
10
]}
max_input_shape
=
{
"x"
:
[
args
.
rec_batch_num
,
3
,
32
,
2000
]}
max_input_shape
=
{
"x"
:
[
args
.
rec_batch_num
,
3
,
32
,
1024
]}
opt_input_shape
=
{
"x"
:
[
args
.
rec_batch_num
,
3
,
32
,
320
]}
elif
mode
==
"cls"
:
min_input_shape
=
{
"x"
:
[
1
,
3
,
48
,
10
]}
max_input_shape
=
{
"x"
:
[
args
.
rec_batch_num
,
3
,
48
,
2000
]}
max_input_shape
=
{
"x"
:
[
args
.
rec_batch_num
,
3
,
48
,
1024
]}
opt_input_shape
=
{
"x"
:
[
args
.
rec_batch_num
,
3
,
48
,
320
]}
else
:
min_input_shape
=
{
"x"
:
[
1
,
3
,
10
,
10
]}
max_input_shape
=
{
"x"
:
[
1
,
3
,
1000
,
1000
]}
opt_input_shape
=
{
"x"
:
[
1
,
3
,
500
,
500
]}
max_input_shape
=
{
"x"
:
[
1
,
3
,
512
,
512
]}
opt_input_shape
=
{
"x"
:
[
1
,
3
,
256
,
256
]}
config
.
set_trt_dynamic_shape_info
(
min_input_shape
,
max_input_shape
,
opt_input_shape
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录