Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
88964dc9
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
88964dc9
编写于
3月 19, 2021
作者:
J
Jethong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add visual png
上级
97111112
变更
16
隐藏空白更改
内联
并排
Showing
16 changed file
with
296 addition
and
641 deletion
+296
-641
configs/e2e/e2e_r50_vd_pg.yml
configs/e2e/e2e_r50_vd_pg.yml
+11
-7
doc/doc_ch/e2e.md
doc/doc_ch/e2e.md
+12
-6
doc/doc_ch/inference.md
doc/doc_ch/inference.md
+7
-9
doc/imgs_results/e2e_res_img623_pgnet.jpg
doc/imgs_results/e2e_res_img623_pgnet.jpg
+0
-0
doc/imgs_results/e2e_res_img_10_pgnet.jpg
doc/imgs_results/e2e_res_img_10_pgnet.jpg
+0
-0
ppocr/data/imaug/label_ops.py
ppocr/data/imaug/label_ops.py
+3
-3
ppocr/data/imaug/pg_process.py
ppocr/data/imaug/pg_process.py
+18
-10
ppocr/data/pgnet_dataset.py
ppocr/data/pgnet_dataset.py
+1
-0
ppocr/losses/e2e_pg_loss.py
ppocr/losses/e2e_pg_loss.py
+19
-93
ppocr/modeling/heads/e2e_pg_head.py
ppocr/modeling/heads/e2e_pg_head.py
+1
-2
ppocr/postprocess/pg_postprocess.py
ppocr/postprocess/pg_postprocess.py
+7
-95
ppocr/utils/e2e_utils/extract_batchsize.py
ppocr/utils/e2e_utils/extract_batchsize.py
+87
-0
ppocr/utils/e2e_utils/extract_textpoint.py
ppocr/utils/e2e_utils/extract_textpoint.py
+123
-361
ppocr/utils/pgnet_dict.txt
ppocr/utils/pgnet_dict.txt
+0
-36
tools/infer/predict_e2e.py
tools/infer/predict_e2e.py
+5
-15
tools/infer/utility.py
tools/infer/utility.py
+2
-4
未找到文件。
configs/e2e/e2e_r50_vd_pg.yml
浏览文件 @
88964dc9
...
...
@@ -18,11 +18,13 @@ Global:
save_inference_dir
:
use_visualdl
:
False
infer_img
:
valid_set
:
totaltext
#two mode: totaltext valid curved words, partvgg valid non-curved words
valid_set
:
totaltext
#
two mode: totaltext valid curved words, partvgg valid non-curved words
save_res_path
:
./output/pgnet_r50_vd_totaltext/predicts_pgnet.txt
character_dict_path
:
ppocr/utils/
pgnet
_dict.txt
character_dict_path
:
ppocr/utils/
ic15
_dict.txt
character_type
:
EN
max_text_length
:
50
max_text_length
:
50
# the max length in seq
max_text_nums
:
30
# the max seq nums in a pic
tcl_len
:
64
Architecture
:
model_type
:
e2e
...
...
@@ -33,13 +35,15 @@ Architecture:
layers
:
50
Neck
:
name
:
PGFPN
model_name
:
large
Head
:
name
:
PGHead
model_name
:
large
Loss
:
name
:
PGLoss
tcl_bs
:
64
max_text_length
:
50
# the same as Global: max_text_length
max_text_nums
:
30
# the same as Global:max_text_nums
pad_num
:
36
# the length of dict for pad
Optimizer
:
name
:
Adam
...
...
@@ -54,10 +58,10 @@ Optimizer:
PostProcess
:
name
:
PGPostProcess
score_thresh
:
0.
8
score_thresh
:
0.
5
Metric
:
name
:
E2EMetric
character_dict_path
:
ppocr/utils/
pgnet
_dict.txt
character_dict_path
:
ppocr/utils/
ic15
_dict.txt
main_indicator
:
f_score_e2e
Train
:
...
...
doc/doc_ch/e2e.md
浏览文件 @
88964dc9
...
...
@@ -9,8 +9,10 @@
解压数据集和下载标注文件后,PaddleOCR/train_data/part_vgg_synth/train/ 有一个文件夹和一个文件,分别是:
```
/PaddleOCR/train_data/part_vgg_synth/train/
└─ image/ partvgg数据集的训练数据
└─ train_annotation_info.txt partvgg数据集的测试标注
|- image/ partvgg数据集的训练数据
|- 119_nile_110_31.png
| ...
|- train_annotation_info.txt partvgg数据集的测试标注
```
提供的标注文件格式如下,中间用"
\t
"分隔:
...
...
@@ -18,7 +20,7 @@
" 图像文件名 图像标注信息--四点标注 图像标注信息--识别标注
119_nile_110_31 140.2 222.5 266.0 194.6 278.7 251.8 152.9 279.7 Path: 32.9 133.1 106.0 130.8 106.4 143.8 33.3 146.1 were 21.8 81.9 106.9 80.4 107.7 123.2 22.6 124.7 why
```
标注文件txt当中,其中每一行代表一组数据,以第一行为例。第一个代表同级目录image/下面的文件名, 后面每9个代表一组标注信息,前8个代表文本框的四个点坐标(x,y),从左上角的点开始顺时针排列。
标注文件txt当中,其中每一行代表一组数据,以第一行为例。第一个代表同级目录image/下面的文件名
前缀
, 后面每9个代表一组标注信息,前8个代表文本框的四个点坐标(x,y),从左上角的点开始顺时针排列。
最后一个代表文字的识别结果,
**当其内容为“###”时,表示该文本框无效,在训练时会跳过。**
...
...
@@ -26,8 +28,12 @@
解压数据集和下载标注文件后,PaddleOCR/train_data/total_text/train/ 有两个文件夹,分别是:
```
/PaddleOCR/train_data/total_text/train/
└─ rgb/ total_text数据集的训练数据
└─ poly/ total_text数据集的测试标注
|- rgb/ total_text数据集的训练数据
|- gt_0.png
| ...
|-poly/ total_text数据集的测试标注
|- gt_0.txt
| ...
```
提供的标注文件格式如下,中间用"
\t
"分隔:
...
...
@@ -36,7 +42,7 @@
1004.0,689.0,1019.0,698.0,1034.0,708.0,1049.0,718.0,1064.0,728.0,1079.0,738.0,1095.0,748.0,1094.0,774.0,1079.0,765.0,1065.0,756.0,1050.0,747.0,1036.0,738.0,1021.0,729.0,1007.0,721.0 EST
1102.0,755.0,1116.0,764.0,1131.0,773.0,1146.0,783.0,1161.0,792.0,1176.0,801.0,1191.0,811.0,1193.0,837.0,1178.0,828.0,1164.0,819.0,1150.0,810.0,1135.0,801.0,1121.0,792.0,1107.0,784.0 1972
```
标注文件当中,其中每一个txt文件代表一组数据,文件名同级目录rgb/下面的文件名。以第一行为例,前面28个代表文本框的十四个点坐标(x,y),从左上角的点开始顺时针排列。
标注文件当中,其中每一个txt文件代表一组数据,文件名
就是
同级目录rgb/下面的文件名。以第一行为例,前面28个代表文本框的十四个点坐标(x,y),从左上角的点开始顺时针排列。
最后一个代表文字的识别结果,
**当其内容为“###”时,表示该文本框无效,在训练时会跳过。**
如果您想在其他数据集上训练,可以按照上述形式构建标注文件。
...
...
doc/doc_ch/inference.md
浏览文件 @
88964dc9
...
...
@@ -29,7 +29,7 @@ inference 模型(`paddle.jit.save`保存的模型)
-
[
5. 多语言模型的推理
](
#多语言模型的推理
)
-
[
四、端到端模型推理
](
#端到端模型推理
)
-
[
1. PGNet端到端模型推理
](
#
SAST文本检测
模型推理
)
-
[
1. PGNet端到端模型推理
](
#
PGNet端到端
模型推理
)
-
[
五、方向分类模型推理
](
#方向识别模型推理
)
-
[
1. 方向分类模型推理
](
#方向分类模型推理
)
...
...
@@ -366,7 +366,7 @@ Predicts of ./doc/imgs_words/korean/1.jpg:('바탕으로', 0.9948904)
## 四、端到端模型推理
端到端模型推理,默认使用PGNet模型的配置参数。当不使用PGNet模型时,在推理时,需要通过传入相应的参数进行算法适配,细节参考下文。
<a
name=
"
SAST文本检测
模型推理"
></a>
<a
name=
"
PGNet端到端
模型推理"
></a>
### 1. PGNet端到端模型推理
#### (1). 四边形文本检测模型(ICDAR2015)
首先将PGNet端到端训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例(
[
模型下载地址
](
https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar
)
),可以使用如下命令进行转换:
...
...
@@ -375,28 +375,26 @@ python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrai
```
**PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"`**
,可以执行如下命令:
```
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img_10.jpg" --e2e_model_dir="./inference/e2e
_pgnet_ic15/"
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img_10.jpg" --e2e_model_dir="./inference/e2e
/" --e2e_pgnet_polygon=False
```
可视化文本检测结果默认保存到
`./inference_results`
文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
![](
../imgs_results/
det_res_img_10_sas
t.jpg
)
![](
../imgs_results/
e2e_res_img_10_pgne
t.jpg
)
#### (2). 弯曲文本检测模型(Total-Text)
首先将PGNet端到端训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在Total-Text英文数据集训练的模型为例(
[
模型下载地址
](
https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar
)
),可以使用如下命令进行转换:
```
python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./det_r50_vd_sast_totaltext_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/e2e
_pgnet_tt
python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./det_r50_vd_sast_totaltext_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/e2e
```
**PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"`,同时,还需要增加参数`--e2e_pgnet_polygon=True`,**
可以执行如下命令:
```
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e
_pgnet_tt
/" --e2e_pgnet_polygon=True
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True
```
可视化文本端到端结果默认保存到
`./inference_results`
文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
![](
../imgs_results/e2e_res_img623_pg.jpg
)
**注意**
:本代码库中,SAST后处理Locality-Aware NMS有python和c++两种版本,c++版速度明显快于python版。由于c++版本nms编译版本问题,只有python3.5环境下会调用c++版nms,其他情况将调用python版nms。
![](
../imgs_results/e2e_res_img623_pgnet.jpg
)
<a
name=
"方向分类模型推理"
></a>
...
...
doc/imgs_results/e2e_res_img623_pg.jpg
→
doc/imgs_results/e2e_res_img623_pg
net
.jpg
查看替换文件 @
97111112
浏览文件 @
88964dc9
135.6 KB
|
W:
|
H:
133.6 KB
|
W:
|
H:
2-up
Swipe
Onion skin
doc/imgs_results/e2e_res_img_10_pgnet.jpg
0 → 100644
浏览文件 @
88964dc9
337.2 KB
ppocr/data/imaug/label_ops.py
浏览文件 @
88964dc9
...
...
@@ -197,17 +197,17 @@ class E2ELabelEncode(BaseRecLabelEncode):
super
(
E2ELabelEncode
,
self
).
__init__
(
max_text_length
,
character_dict_path
,
character_type
,
use_space_char
)
self
.
pad_num
=
len
(
self
.
dict
)
# the length to pad
def
__call__
(
self
,
data
):
texts
=
data
[
'strs'
]
temp_texts
=
[]
for
text
in
texts
:
text
=
text
.
upp
er
()
text
=
text
.
low
er
()
text
=
self
.
encode
(
text
)
if
text
is
None
:
return
None
text
=
text
+
[
36
]
*
(
self
.
max_text_len
-
len
(
text
)
)
# use 36 to pad
text
=
text
+
[
self
.
pad_num
]
*
(
self
.
max_text_len
-
len
(
text
))
temp_texts
.
append
(
text
)
data
[
'strs'
]
=
np
.
array
(
temp_texts
)
return
data
...
...
ppocr/data/imaug/pg_process.py
浏览文件 @
88964dc9
...
...
@@ -22,16 +22,23 @@ __all__ = ['PGProcessTrain']
class
PGProcessTrain
(
object
):
def
__init__
(
self
,
character_dict_path
,
max_text_length
,
max_text_nums
,
tcl_len
,
batch_size
=
14
,
min_crop_size
=
24
,
min_text_size
=
10
,
max_text_size
=
512
,
**
kwargs
):
self
.
tcl_len
=
tcl_len
self
.
max_text_length
=
max_text_length
self
.
max_text_nums
=
max_text_nums
self
.
batch_size
=
batch_size
self
.
min_crop_size
=
min_crop_size
self
.
min_text_size
=
min_text_size
self
.
max_text_size
=
max_text_size
self
.
Lexicon_Table
=
self
.
get_dict
(
character_dict_path
)
self
.
pad_num
=
len
(
self
.
Lexicon_Table
)
self
.
img_id
=
0
def
get_dict
(
self
,
character_dict_path
):
...
...
@@ -290,7 +297,7 @@ class PGProcessTrain(object):
height_list
.
append
(
quad_h
)
norm_width
=
max
(
sum
(
width_list
)
/
n_char
,
1.0
)
average_height
=
max
(
sum
(
height_list
)
/
len
(
height_list
),
1.0
)
k
=
1
for
quad
in
poly_quads
:
direct_vector_full
=
(
(
quad
[
1
]
+
quad
[
2
])
-
(
quad
[
0
]
+
quad
[
3
]))
/
2.0
...
...
@@ -302,6 +309,8 @@ class PGProcessTrain(object):
cv2
.
fillPoly
(
direction_map
,
quad
.
round
().
astype
(
np
.
int32
)[
np
.
newaxis
,
:,
:],
direction_label
)
cv2
.
imwrite
(
"output/{}.png"
.
format
(
k
),
direction_map
*
255.0
)
k
+=
1
return
direction_map
def
calculate_average_height
(
self
,
poly_quads
):
...
...
@@ -371,7 +380,6 @@ class PGProcessTrain(object):
continue
if
tag
:
# continue
cv2
.
fillPoly
(
training_mask
,
poly
.
astype
(
np
.
int32
)[
np
.
newaxis
,
:,
:],
0.15
)
else
:
...
...
@@ -577,7 +585,7 @@ class PGProcessTrain(object):
Prepare text lablel by given Lexicon_Table.
"""
if
len
(
Lexicon_Table
)
==
36
:
return
label_str
.
upp
er
()
return
label_str
.
low
er
()
else
:
return
label_str
...
...
@@ -846,23 +854,23 @@ class PGProcessTrain(object):
return
None
pos_list_temp
=
np
.
zeros
([
64
,
3
])
pos_mask_temp
=
np
.
zeros
([
64
,
1
])
label_list_temp
=
np
.
zeros
([
50
,
1
])
+
36
label_list_temp
=
np
.
zeros
([
self
.
max_text_length
,
1
])
+
self
.
pad_num
for
i
,
label
in
enumerate
(
label_list
):
n
=
len
(
label
)
if
n
>
50
:
label_list
[
i
]
=
label
[:
50
]
if
n
>
self
.
max_text_length
:
label_list
[
i
]
=
label
[:
self
.
max_text_length
]
continue
while
n
<
50
:
label
.
append
([
36
])
while
n
<
self
.
max_text_length
:
label
.
append
([
self
.
pad_num
])
n
+=
1
for
i
in
range
(
len
(
label_list
)):
label_list
[
i
]
=
np
.
array
(
label_list
[
i
])
if
len
(
pos_list
)
<=
0
or
len
(
pos_list
)
>
30
:
#一张图片中最多存在30行文本
if
len
(
pos_list
)
<=
0
or
len
(
pos_list
)
>
self
.
max_text_nums
:
return
None
for
__
in
range
(
30
-
len
(
pos_list
),
0
,
-
1
):
for
__
in
range
(
self
.
max_text_nums
-
len
(
pos_list
),
0
,
-
1
):
pos_list
.
append
(
pos_list_temp
)
pos_mask
.
append
(
pos_mask_temp
)
label_list
.
append
(
label_list_temp
)
...
...
ppocr/data/pgnet_dataset.py
浏览文件 @
88964dc9
...
...
@@ -156,6 +156,7 @@ class PGDataSet(Dataset):
img
=
f
.
read
()
data
[
'image'
]
=
img
outs
=
transform
(
data
,
self
.
ops
)
except
Exception
as
e
:
self
.
logger
.
error
(
"When parsing line {}, error happened with msg: {}"
.
format
(
...
...
ppocr/losses/e2e_pg_loss.py
浏览文件 @
88964dc9
...
...
@@ -18,102 +18,26 @@ from __future__ import print_function
from
paddle
import
nn
import
paddle
import
numpy
as
np
import
copy
from
.det_basic_loss
import
DiceLoss
from
ppocr.utils.e2e_utils.extract_batchsize
import
*
class
PGLoss
(
nn
.
Layer
):
def
__init__
(
self
,
eps
=
1e-6
,
**
kwargs
):
def
__init__
(
self
,
tcl_bs
,
max_text_length
,
max_text_nums
,
pad_num
,
eps
=
1e-6
,
**
kwargs
):
super
(
PGLoss
,
self
).
__init__
()
self
.
tcl_bs
=
tcl_bs
self
.
max_text_nums
=
max_text_nums
self
.
max_text_length
=
max_text_length
self
.
pad_num
=
pad_num
self
.
dice_loss
=
DiceLoss
(
eps
=
eps
)
def
org_tcl_rois
(
self
,
batch_size
,
pos_lists
,
pos_masks
,
label_lists
):
"""
"""
pos_lists_
,
pos_masks_
,
label_lists_
=
[],
[],
[]
img_bs
=
batch_size
tcl_bs
=
64
ngpu
=
int
(
batch_size
/
img_bs
)
img_ids
=
np
.
array
(
pos_lists
,
dtype
=
np
.
int32
)[:,
0
,
0
].
copy
()
pos_lists_split
,
pos_masks_split
,
label_lists_split
=
[],
[],
[]
for
i
in
range
(
ngpu
):
pos_lists_split
.
append
([])
pos_masks_split
.
append
([])
label_lists_split
.
append
([])
for
i
in
range
(
img_ids
.
shape
[
0
]):
img_id
=
img_ids
[
i
]
gpu_id
=
int
(
img_id
/
img_bs
)
img_id
=
img_id
%
img_bs
pos_list
=
pos_lists
[
i
].
copy
()
pos_list
[:,
0
]
=
img_id
pos_lists_split
[
gpu_id
].
append
(
pos_list
)
pos_masks_split
[
gpu_id
].
append
(
pos_masks
[
i
].
copy
())
label_lists_split
[
gpu_id
].
append
(
copy
.
deepcopy
(
label_lists
[
i
]))
# repeat or delete
for
i
in
range
(
ngpu
):
vp_len
=
len
(
pos_lists_split
[
i
])
if
vp_len
<=
tcl_bs
:
for
j
in
range
(
0
,
tcl_bs
-
vp_len
):
pos_list
=
pos_lists_split
[
i
][
j
].
copy
()
pos_lists_split
[
i
].
append
(
pos_list
)
pos_mask
=
pos_masks_split
[
i
][
j
].
copy
()
pos_masks_split
[
i
].
append
(
pos_mask
)
label_list
=
copy
.
deepcopy
(
label_lists_split
[
i
][
j
])
label_lists_split
[
i
].
append
(
label_list
)
else
:
for
j
in
range
(
0
,
vp_len
-
tcl_bs
):
c_len
=
len
(
pos_lists_split
[
i
])
pop_id
=
np
.
random
.
permutation
(
c_len
)[
0
]
pos_lists_split
[
i
].
pop
(
pop_id
)
pos_masks_split
[
i
].
pop
(
pop_id
)
label_lists_split
[
i
].
pop
(
pop_id
)
# merge
for
i
in
range
(
ngpu
):
pos_lists_
.
extend
(
pos_lists_split
[
i
])
pos_masks_
.
extend
(
pos_masks_split
[
i
])
label_lists_
.
extend
(
label_lists_split
[
i
])
return
pos_lists_
,
pos_masks_
,
label_lists_
def
pre_process
(
self
,
label_list
,
pos_list
,
pos_mask
):
max_len
=
30
# the max texts in a single image
max_str_len
=
50
# the max len in a single text
pad_num
=
36
# padding num
label_list
=
label_list
.
numpy
()
batch
,
_
,
_
,
_
=
label_list
.
shape
pos_list
=
pos_list
.
numpy
()
pos_mask
=
pos_mask
.
numpy
()
pos_list_t
=
[]
pos_mask_t
=
[]
label_list_t
=
[]
for
i
in
range
(
batch
):
for
j
in
range
(
max_len
):
if
pos_mask
[
i
,
j
].
any
():
pos_list_t
.
append
(
pos_list
[
i
][
j
])
pos_mask_t
.
append
(
pos_mask
[
i
][
j
])
label_list_t
.
append
(
label_list
[
i
][
j
])
pos_list
,
pos_mask
,
label_list
=
self
.
org_tcl_rois
(
batch
,
pos_list_t
,
pos_mask_t
,
label_list_t
)
label
=
[]
tt
=
[
l
.
tolist
()
for
l
in
label_list
]
for
i
in
range
(
batch
):
k
=
0
for
j
in
range
(
max_str_len
):
if
tt
[
i
][
j
][
0
]
!=
pad_num
:
k
+=
1
else
:
break
label
.
append
(
k
)
label
=
paddle
.
to_tensor
(
label
)
label
=
paddle
.
cast
(
label
,
dtype
=
'int64'
)
pos_list
=
paddle
.
to_tensor
(
pos_list
)
pos_mask
=
paddle
.
to_tensor
(
pos_mask
)
label_list
=
paddle
.
squeeze
(
paddle
.
to_tensor
(
label_list
),
axis
=
2
)
label_list
=
paddle
.
cast
(
label_list
,
dtype
=
'int32'
)
return
pos_list
,
pos_mask
,
label_list
,
label
def
border_loss
(
self
,
f_border
,
l_border
,
l_score
,
l_mask
):
l_border_split
,
l_border_norm
=
paddle
.
tensor
.
split
(
l_border
,
num_or_sections
=
[
4
,
1
],
axis
=
1
)
...
...
@@ -183,7 +107,7 @@ class PGLoss(nn.Layer):
labels
=
tcl_label
,
input_lengths
=
input_lengths
,
label_lengths
=
label_t
,
blank
=
36
,
blank
=
self
.
pad_num
,
reduction
=
'none'
)
cost
=
cost
.
mean
()
return
cost
...
...
@@ -192,12 +116,14 @@ class PGLoss(nn.Layer):
images
,
tcl_maps
,
tcl_label_maps
,
border_maps
\
,
direction_maps
,
training_masks
,
label_list
,
pos_list
,
pos_mask
=
labels
# for all the batch_size
pos_list
,
pos_mask
,
label_list
,
label_t
=
self
.
pre_process
(
label_list
,
pos_list
,
pos_mask
)
pos_list
,
pos_mask
,
label_list
,
label_t
=
pre_process
(
label_list
,
pos_list
,
pos_mask
,
self
.
max_text_length
,
self
.
max_text_nums
,
self
.
pad_num
,
self
.
tcl_bs
)
f_score
,
f_boder
,
f_direction
,
f_char
=
predicts
f_score
,
f_border
,
f_direction
,
f_char
=
predicts
[
'f_score'
],
predicts
[
'f_border'
],
predicts
[
'f_direction'
],
\
predicts
[
'f_char'
]
score_loss
=
self
.
dice_loss
(
f_score
,
tcl_maps
,
training_masks
)
border_loss
=
self
.
border_loss
(
f_boder
,
border_maps
,
tcl_maps
,
border_loss
=
self
.
border_loss
(
f_bo
r
der
,
border_maps
,
tcl_maps
,
training_masks
)
direction_loss
=
self
.
direction_loss
(
f_direction
,
direction_maps
,
tcl_maps
,
training_masks
)
...
...
ppocr/modeling/heads/e2e_pg_head.py
浏览文件 @
88964dc9
...
...
@@ -66,9 +66,8 @@ class PGHead(nn.Layer):
"""
"""
def
__init__
(
self
,
in_channels
,
model_name
,
**
kwargs
):
def
__init__
(
self
,
in_channels
,
**
kwargs
):
super
(
PGHead
,
self
).
__init__
()
self
.
model_name
=
model_name
self
.
conv_f_score1
=
ConvBNLayer
(
in_channels
=
in_channels
,
out_channels
=
64
,
...
...
ppocr/postprocess/pg_postprocess.py
浏览文件 @
88964dc9
...
...
@@ -23,8 +23,7 @@ __dir__ = os.path.dirname(__file__)
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
join
(
__dir__
,
'..'
))
from
ppocr.utils.e2e_utils.extract_textpoint
import
*
from
ppocr.utils.e2e_utils.visual
import
*
from
ppocr.utils.e2e_utils.extract_textpoint
import
get_dict
,
generate_pivot_list
,
restore_poly
import
paddle
...
...
@@ -34,16 +33,10 @@ class PGPostProcess(object):
"""
def
__init__
(
self
,
character_dict_path
,
valid_set
,
score_thresh
,
**
kwargs
):
self
.
Lexicon_Table
=
get_dict
(
character_dict_path
)
self
.
valid_set
=
valid_set
self
.
score_thresh
=
score_thresh
# c++ la-nms is faster, but only support python 3.5
self
.
is_python35
=
False
if
sys
.
version_info
.
major
==
3
and
sys
.
version_info
.
minor
==
5
:
self
.
is_python35
=
True
def
__call__
(
self
,
outs_dict
,
shape_list
):
p_score
=
outs_dict
[
'f_score'
]
p_border
=
outs_dict
[
'f_border'
]
...
...
@@ -61,96 +54,15 @@ class PGPostProcess(object):
p_char
=
p_char
[
0
]
src_h
,
src_w
,
ratio_h
,
ratio_w
=
shape_list
[
0
]
is_curved
=
self
.
valid_set
==
"totaltext"
instance_yxs_list
=
generate_pivot_list
(
instance_yxs_list
,
seq_strs
=
generate_pivot_list
(
p_score
,
p_char
,
p_direction
,
score_thresh
=
self
.
score_thresh
,
is_backbone
=
True
,
is_curved
=
is_curved
)
p_char
=
np
.
expand_dims
(
p_char
,
axis
=
0
)
p_char
=
paddle
.
to_tensor
(
p_char
)
char_seq_idx_set
=
[]
for
i
in
range
(
len
(
instance_yxs_list
)):
gather_info_lod
=
paddle
.
to_tensor
(
instance_yxs_list
[
i
])
f_char_map
=
paddle
.
transpose
(
p_char
,
[
0
,
2
,
3
,
1
])
featyre_seq
=
paddle
.
gather_nd
(
f_char_map
,
gather_info_lod
)
featyre_seq
=
np
.
expand_dims
(
featyre_seq
.
numpy
(),
axis
=
0
)
t
=
len
(
featyre_seq
[
0
])
featyre_seq
=
paddle
.
to_tensor
(
featyre_seq
)
l
=
np
.
array
([[
t
]]).
astype
(
np
.
int64
)
length
=
paddle
.
to_tensor
(
l
)
seq_pred
=
paddle
.
fluid
.
layers
.
ctc_greedy_decoder
(
input
=
featyre_seq
,
blank
=
36
,
input_length
=
length
)
seq_pred1
=
seq_pred
[
0
].
numpy
().
tolist
()[
0
]
seq_len
=
seq_pred
[
1
].
numpy
()[
0
][
0
]
temp_t
=
[]
for
x
in
seq_pred1
[:
seq_len
]:
temp_t
.
append
(
x
)
char_seq_idx_set
.
append
(
temp_t
)
seq_strs
=
[]
for
char_idx_set
in
char_seq_idx_set
:
pr_str
=
''
.
join
([
self
.
Lexicon_Table
[
pos
]
for
pos
in
char_idx_set
])
seq_strs
.
append
(
pr_str
)
poly_list
=
[]
keep_str_list
=
[]
all_point_list
=
[]
all_point_pair_list
=
[]
for
yx_center_line
,
keep_str
in
zip
(
instance_yxs_list
,
seq_strs
):
if
len
(
yx_center_line
)
==
1
:
yx_center_line
.
append
(
yx_center_line
[
-
1
])
offset_expand
=
1.0
if
self
.
valid_set
==
'totaltext'
:
offset_expand
=
1.2
point_pair_list
=
[]
for
batch_id
,
y
,
x
in
yx_center_line
:
offset
=
p_border
[:,
y
,
x
].
reshape
(
2
,
2
)
if
offset_expand
!=
1.0
:
offset_length
=
np
.
linalg
.
norm
(
offset
,
axis
=
1
,
keepdims
=
True
)
expand_length
=
np
.
clip
(
offset_length
*
(
offset_expand
-
1
),
a_min
=
0.5
,
a_max
=
3.0
)
offset_detal
=
offset
/
offset_length
*
expand_length
offset
=
offset
+
offset_detal
ori_yx
=
np
.
array
([
y
,
x
],
dtype
=
np
.
float32
)
point_pair
=
(
ori_yx
+
offset
)[:,
::
-
1
]
*
4.0
/
np
.
array
(
[
ratio_w
,
ratio_h
]).
reshape
(
-
1
,
2
)
point_pair_list
.
append
(
point_pair
)
all_point_list
.
append
([
int
(
round
(
x
*
4.0
/
ratio_w
)),
int
(
round
(
y
*
4.0
/
ratio_h
))
])
all_point_pair_list
.
append
(
point_pair
.
round
().
astype
(
np
.
int32
)
.
tolist
())
detected_poly
,
pair_length_info
=
point_pair2poly
(
point_pair_list
)
detected_poly
=
expand_poly_along_width
(
detected_poly
,
shrink_ratio_of_width
=
0.2
)
detected_poly
[:,
0
]
=
np
.
clip
(
detected_poly
[:,
0
],
a_min
=
0
,
a_max
=
src_w
)
detected_poly
[:,
1
]
=
np
.
clip
(
detected_poly
[:,
1
],
a_min
=
0
,
a_max
=
src_h
)
if
len
(
keep_str
)
<
2
:
continue
keep_str_list
.
append
(
keep_str
)
if
self
.
valid_set
==
'partvgg'
:
middle_point
=
len
(
detected_poly
)
//
2
detected_poly
=
detected_poly
[
[
0
,
middle_point
-
1
,
middle_point
,
-
1
],
:]
poly_list
.
append
(
detected_poly
)
elif
self
.
valid_set
==
'totaltext'
:
poly_list
.
append
(
detected_poly
)
else
:
print
(
'--> Not supported format.'
)
exit
(
-
1
)
self
.
Lexicon_Table
,
score_thresh
=
self
.
score_thresh
)
poly_list
,
keep_str_list
=
restore_poly
(
instance_yxs_list
,
seq_strs
,
p_border
,
ratio_w
,
ratio_h
,
src_w
,
src_h
,
self
.
valid_set
)
data
=
{
'points'
:
poly_list
,
'strs'
:
keep_str_list
,
...
...
ppocr/utils/e2e_utils/extract_batchsize.py
0 → 100644
浏览文件 @
88964dc9
import
paddle
import
numpy
as
np
import
copy
def
org_tcl_rois
(
batch_size
,
pos_lists
,
pos_masks
,
label_lists
,
tcl_bs
):
"""
"""
pos_lists_
,
pos_masks_
,
label_lists_
=
[],
[],
[]
img_bs
=
batch_size
ngpu
=
int
(
batch_size
/
img_bs
)
img_ids
=
np
.
array
(
pos_lists
,
dtype
=
np
.
int32
)[:,
0
,
0
].
copy
()
pos_lists_split
,
pos_masks_split
,
label_lists_split
=
[],
[],
[]
for
i
in
range
(
ngpu
):
pos_lists_split
.
append
([])
pos_masks_split
.
append
([])
label_lists_split
.
append
([])
for
i
in
range
(
img_ids
.
shape
[
0
]):
img_id
=
img_ids
[
i
]
gpu_id
=
int
(
img_id
/
img_bs
)
img_id
=
img_id
%
img_bs
pos_list
=
pos_lists
[
i
].
copy
()
pos_list
[:,
0
]
=
img_id
pos_lists_split
[
gpu_id
].
append
(
pos_list
)
pos_masks_split
[
gpu_id
].
append
(
pos_masks
[
i
].
copy
())
label_lists_split
[
gpu_id
].
append
(
copy
.
deepcopy
(
label_lists
[
i
]))
# repeat or delete
for
i
in
range
(
ngpu
):
vp_len
=
len
(
pos_lists_split
[
i
])
if
vp_len
<=
tcl_bs
:
for
j
in
range
(
0
,
tcl_bs
-
vp_len
):
pos_list
=
pos_lists_split
[
i
][
j
].
copy
()
pos_lists_split
[
i
].
append
(
pos_list
)
pos_mask
=
pos_masks_split
[
i
][
j
].
copy
()
pos_masks_split
[
i
].
append
(
pos_mask
)
label_list
=
copy
.
deepcopy
(
label_lists_split
[
i
][
j
])
label_lists_split
[
i
].
append
(
label_list
)
else
:
for
j
in
range
(
0
,
vp_len
-
tcl_bs
):
c_len
=
len
(
pos_lists_split
[
i
])
pop_id
=
np
.
random
.
permutation
(
c_len
)[
0
]
pos_lists_split
[
i
].
pop
(
pop_id
)
pos_masks_split
[
i
].
pop
(
pop_id
)
label_lists_split
[
i
].
pop
(
pop_id
)
# merge
for
i
in
range
(
ngpu
):
pos_lists_
.
extend
(
pos_lists_split
[
i
])
pos_masks_
.
extend
(
pos_masks_split
[
i
])
label_lists_
.
extend
(
label_lists_split
[
i
])
return
pos_lists_
,
pos_masks_
,
label_lists_
def
pre_process
(
label_list
,
pos_list
,
pos_mask
,
max_text_length
,
max_text_nums
,
pad_num
,
tcl_bs
):
label_list
=
label_list
.
numpy
()
batch
,
_
,
_
,
_
=
label_list
.
shape
pos_list
=
pos_list
.
numpy
()
pos_mask
=
pos_mask
.
numpy
()
pos_list_t
=
[]
pos_mask_t
=
[]
label_list_t
=
[]
for
i
in
range
(
batch
):
for
j
in
range
(
max_text_nums
):
if
pos_mask
[
i
,
j
].
any
():
pos_list_t
.
append
(
pos_list
[
i
][
j
])
pos_mask_t
.
append
(
pos_mask
[
i
][
j
])
label_list_t
.
append
(
label_list
[
i
][
j
])
pos_list
,
pos_mask
,
label_list
=
org_tcl_rois
(
batch
,
pos_list_t
,
pos_mask_t
,
label_list_t
,
tcl_bs
)
label
=
[]
tt
=
[
l
.
tolist
()
for
l
in
label_list
]
for
i
in
range
(
tcl_bs
):
k
=
0
for
j
in
range
(
max_text_length
):
if
tt
[
i
][
j
][
0
]
!=
pad_num
:
k
+=
1
else
:
break
label
.
append
(
k
)
label
=
paddle
.
to_tensor
(
label
)
label
=
paddle
.
cast
(
label
,
dtype
=
'int64'
)
pos_list
=
paddle
.
to_tensor
(
pos_list
)
pos_mask
=
paddle
.
to_tensor
(
pos_mask
)
label_list
=
paddle
.
squeeze
(
paddle
.
to_tensor
(
label_list
),
axis
=
2
)
label_list
=
paddle
.
cast
(
label_list
,
dtype
=
'int32'
)
return
pos_list
,
pos_mask
,
label_list
,
label
ppocr/utils/e2e_utils/extract_textpoint.py
浏览文件 @
88964dc9
...
...
@@ -17,11 +17,9 @@ from __future__ import division
from
__future__
import
print_function
import
cv2
import
math
import
numpy
as
np
from
itertools
import
groupby
from
skimage.morphology._skeletonize
import
thin
from
cv2.ximgproc
import
thinning
as
thin
def
get_dict
(
character_dict_path
):
...
...
@@ -35,87 +33,39 @@ def get_dict(character_dict_path):
return
dict_character
def
softmax
(
logits
):
"""
logits: N x d
"""
max_value
=
np
.
max
(
logits
,
axis
=
1
,
keepdims
=
True
)
exp
=
np
.
exp
(
logits
-
max_value
)
exp_sum
=
np
.
sum
(
exp
,
axis
=
1
,
keepdims
=
True
)
dist
=
exp
/
exp_sum
return
dist
def
get_keep_pos_idxs
(
labels
,
remove_blank
=
None
):
"""
Remove duplicate and get pos idxs of keep items.
The value of keep_blank should be [None, 95].
"""
duplicate_len_list
=
[]
keep_pos_idx_list
=
[]
keep_char_idx_list
=
[]
for
k
,
v_
in
groupby
(
labels
):
current_len
=
len
(
list
(
v_
))
if
k
!=
remove_blank
:
current_idx
=
int
(
sum
(
duplicate_len_list
)
+
current_len
//
2
)
keep_pos_idx_list
.
append
(
current_idx
)
keep_char_idx_list
.
append
(
k
)
duplicate_len_list
.
append
(
current_len
)
return
keep_char_idx_list
,
keep_pos_idx_list
def
remove_blank
(
labels
,
blank
=
0
):
new_labels
=
[
x
for
x
in
labels
if
x
!=
blank
]
return
new_labels
def
insert_blank
(
labels
,
blank
=
0
):
new_labels
=
[
blank
]
for
l
in
labels
:
new_labels
+=
[
l
,
blank
]
return
new_labels
def
ctc_greedy_decoder
(
probs_seq
,
blank
=
95
,
keep_blank_in_idxs
=
True
):
"""
CTC greedy (best path) decoder.
"""
raw_str
=
np
.
argmax
(
np
.
array
(
probs_seq
),
axis
=
1
)
remove_blank_in_pos
=
None
if
keep_blank_in_idxs
else
blank
dedup_str
,
keep_idx_list
=
get_keep_pos_idxs
(
raw_str
,
remove_blank
=
remove_blank_in_pos
)
dst_str
=
remove_blank
(
dedup_str
,
blank
=
blank
)
return
dst_str
,
keep_idx_list
def
instance_ctc_greedy_decoder
(
gather_info
,
logits_map
,
keep_blank_in_idxs
=
True
):
"""
gather_info: [[x, y], [x, y] ...]
logits_map: H x W X (n_chars + 1)
"""
def
instance_ctc_greedy_decoder
(
gather_info
,
logits_map
,
pts_num
=
4
):
_
,
_
,
C
=
logits_map
.
shape
ys
,
xs
=
zip
(
*
gather_info
)
logits_seq
=
logits_map
[
list
(
ys
),
list
(
xs
)]
# n x 96
probs_seq
=
softmax
(
logits_seq
)
dst_str
,
keep_idx_list
=
ctc_greedy_decoder
(
probs_seq
,
blank
=
C
-
1
,
keep_blank_in_idxs
=
keep_blank_in_idxs
)
logits_seq
=
logits_map
[
list
(
ys
),
list
(
xs
)]
probs_seq
=
logits_seq
labels
=
np
.
argmax
(
probs_seq
,
axis
=
1
)
dst_str
=
[
k
for
k
,
v_
in
groupby
(
labels
)
if
k
!=
C
-
1
]
detal
=
len
(
gather_info
)
//
(
pts_num
-
1
)
keep_idx_list
=
[
0
]
+
[
detal
*
(
i
+
1
)
for
i
in
range
(
pts_num
-
2
)]
+
[
-
1
]
keep_gather_list
=
[
gather_info
[
idx
]
for
idx
in
keep_idx_list
]
return
dst_str
,
keep_gather_list
def
ctc_decoder_for_image
(
gather_info_list
,
logits_map
,
keep_blank_in_idxs
=
True
):
def
ctc_decoder_for_image
(
gather_info_list
,
logits_map
,
Lexicon_Table
,
pts_num
=
6
):
"""
CTC decoder using multiple processes.
"""
decoder_results
=
[]
decoder_str
=
[]
decoder_xys
=
[]
for
gather_info
in
gather_info_list
:
res
=
instance_ctc_greedy_decoder
(
gather_info
,
logits_map
,
keep_blank_in_idxs
=
keep_blank_in_idxs
)
decoder_results
.
append
(
res
)
return
decoder_results
if
len
(
gather_info
)
<
pts_num
:
continue
dst_str
,
xys_list
=
instance_ctc_greedy_decoder
(
gather_info
,
logits_map
,
pts_num
=
pts_num
)
dst_str_readable
=
''
.
join
([
Lexicon_Table
[
idx
]
for
idx
in
dst_str
])
if
len
(
dst_str_readable
)
<
2
:
continue
decoder_str
.
append
(
dst_str_readable
)
decoder_xys
.
append
(
xys_list
)
return
decoder_str
,
decoder_xys
def
sort_with_direction
(
pos_list
,
f_direction
):
...
...
@@ -157,57 +107,6 @@ def sort_with_direction(pos_list, f_direction):
return
sorted_point
,
np
.
array
(
sorted_direction
)
def
add_id
(
pos_list
,
image_id
=
0
):
"""
Add id for gather feature, for inference.
"""
new_list
=
[]
for
item
in
pos_list
:
new_list
.
append
((
image_id
,
item
[
0
],
item
[
1
]))
return
new_list
def
sort_and_expand_with_direction
(
pos_list
,
f_direction
):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
h
,
w
,
_
=
f_direction
.
shape
sorted_list
,
point_direction
=
sort_with_direction
(
pos_list
,
f_direction
)
point_num
=
len
(
sorted_list
)
sub_direction_len
=
max
(
point_num
//
3
,
2
)
left_direction
=
point_direction
[:
sub_direction_len
,
:]
right_dirction
=
point_direction
[
point_num
-
sub_direction_len
:,
:]
left_average_direction
=
-
np
.
mean
(
left_direction
,
axis
=
0
,
keepdims
=
True
)
left_average_len
=
np
.
linalg
.
norm
(
left_average_direction
)
left_start
=
np
.
array
(
sorted_list
[
0
])
left_step
=
left_average_direction
/
(
left_average_len
+
1e-6
)
right_average_direction
=
np
.
mean
(
right_dirction
,
axis
=
0
,
keepdims
=
True
)
right_average_len
=
np
.
linalg
.
norm
(
right_average_direction
)
right_step
=
right_average_direction
/
(
right_average_len
+
1e-6
)
right_start
=
np
.
array
(
sorted_list
[
-
1
])
append_num
=
max
(
int
((
left_average_len
+
right_average_len
)
/
2.0
*
0.15
),
1
)
left_list
=
[]
right_list
=
[]
for
i
in
range
(
append_num
):
ly
,
lx
=
np
.
round
(
left_start
+
left_step
*
(
i
+
1
)).
flatten
().
astype
(
'int32'
).
tolist
()
if
ly
<
h
and
lx
<
w
and
(
ly
,
lx
)
not
in
left_list
:
left_list
.
append
((
ly
,
lx
))
ry
,
rx
=
np
.
round
(
right_start
+
right_step
*
(
i
+
1
)).
flatten
().
astype
(
'int32'
).
tolist
()
if
ry
<
h
and
rx
<
w
and
(
ry
,
rx
)
not
in
right_list
:
right_list
.
append
((
ry
,
rx
))
all_list
=
left_list
[::
-
1
]
+
sorted_list
+
right_list
return
all_list
def
sort_and_expand_with_direction_v2
(
pos_list
,
f_direction
,
binary_tcl_map
):
"""
f_direction: h x w x 2
...
...
@@ -260,262 +159,125 @@ def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
return
all_list
def
generate_pivot_list_curved
(
p_score
,
p_char_maps
,
f_direction
,
score_thresh
=
0.5
,
is_expand
=
True
,
is_backbone
=
False
,
image_id
=
0
):
"""
return center point and end point of TCL instance; filter with the char maps;
"""
p_score
=
p_score
[
0
]
f_direction
=
f_direction
.
transpose
(
1
,
2
,
0
)
p_tcl_map
=
(
p_score
>
score_thresh
)
*
1.0
skeleton_map
=
thin
(
p_tcl_map
)
instance_count
,
instance_label_map
=
cv2
.
connectedComponents
(
skeleton_map
.
astype
(
np
.
uint8
),
connectivity
=
8
)
all_pos_yxs
=
[]
center_pos_yxs
=
[]
end_points_yxs
=
[]
instance_center_pos_yxs
=
[]
if
instance_count
>
0
:
for
instance_id
in
range
(
1
,
instance_count
):
pos_list
=
[]
ys
,
xs
=
np
.
where
(
instance_label_map
==
instance_id
)
pos_list
=
list
(
zip
(
ys
,
xs
))
if
len
(
pos_list
)
<
3
:
continue
if
is_expand
:
pos_list_sorted
=
sort_and_expand_with_direction_v2
(
pos_list
,
f_direction
,
p_tcl_map
)
else
:
pos_list_sorted
,
_
=
sort_with_direction
(
pos_list
,
f_direction
)
all_pos_yxs
.
append
(
pos_list_sorted
)
p_char_maps
=
p_char_maps
.
transpose
([
1
,
2
,
0
])
decode_res
=
ctc_decoder_for_image
(
all_pos_yxs
,
logits_map
=
p_char_maps
,
keep_blank_in_idxs
=
True
)
for
decoded_str
,
keep_yxs_list
in
decode_res
:
if
is_backbone
:
keep_yxs_list_with_id
=
add_id
(
keep_yxs_list
,
image_id
=
image_id
)
instance_center_pos_yxs
.
append
(
keep_yxs_list_with_id
)
else
:
end_points_yxs
.
extend
((
keep_yxs_list
[
0
],
keep_yxs_list
[
-
1
]))
center_pos_yxs
.
extend
(
keep_yxs_list
)
if
is_backbone
:
return
instance_center_pos_yxs
else
:
return
center_pos_yxs
,
end_points_yxs
def
generate_pivot_list_horizontal
(
p_score
,
p_char_maps
,
f_direction
,
score_thresh
=
0.5
,
is_backbone
=
False
,
image_id
=
0
):
"""
return center point and end point of TCL instance; filter with the char maps;
"""
p_score
=
p_score
[
0
]
f_direction
=
f_direction
.
transpose
(
1
,
2
,
0
)
p_tcl_map_bi
=
(
p_score
>
score_thresh
)
*
1.0
instance_count
,
instance_label_map
=
cv2
.
connectedComponents
(
p_tcl_map_bi
.
astype
(
np
.
uint8
),
connectivity
=
8
)
# get TCL Instance
all_pos_yxs
=
[]
center_pos_yxs
=
[]
end_points_yxs
=
[]
instance_center_pos_yxs
=
[]
if
instance_count
>
0
:
for
instance_id
in
range
(
1
,
instance_count
):
pos_list
=
[]
ys
,
xs
=
np
.
where
(
instance_label_map
==
instance_id
)
pos_list
=
list
(
zip
(
ys
,
xs
))
if
len
(
pos_list
)
<
5
:
continue
main_direction
=
extract_main_direction
(
pos_list
,
f_direction
)
# y x
reference_directin
=
np
.
array
([
0
,
1
]).
reshape
([
-
1
,
2
])
# y x
is_h_angle
=
abs
(
np
.
sum
(
main_direction
*
reference_directin
))
<
math
.
cos
(
math
.
pi
/
180
*
70
)
point_yxs
=
np
.
array
(
pos_list
)
max_y
,
max_x
=
np
.
max
(
point_yxs
,
axis
=
0
)
min_y
,
min_x
=
np
.
min
(
point_yxs
,
axis
=
0
)
is_h_len
=
(
max_y
-
min_y
)
<
1.5
*
(
max_x
-
min_x
)
pos_list_final
=
[]
if
is_h_len
:
xs
=
np
.
unique
(
xs
)
for
x
in
xs
:
ys
=
instance_label_map
[:,
x
].
copy
().
reshape
((
-
1
,
))
y
=
int
(
np
.
where
(
ys
==
instance_id
)[
0
].
mean
())
pos_list_final
.
append
((
y
,
x
))
else
:
ys
=
np
.
unique
(
ys
)
for
y
in
ys
:
xs
=
instance_label_map
[
y
,
:].
copy
().
reshape
((
-
1
,
))
x
=
int
(
np
.
where
(
xs
==
instance_id
)[
0
].
mean
())
pos_list_final
.
append
((
y
,
x
))
pos_list_sorted
,
_
=
sort_with_direction
(
pos_list_final
,
f_direction
)
all_pos_yxs
.
append
(
pos_list_sorted
)
p_char_maps
=
p_char_maps
.
transpose
([
1
,
2
,
0
])
decode_res
=
ctc_decoder_for_image
(
all_pos_yxs
,
logits_map
=
p_char_maps
,
keep_blank_in_idxs
=
True
)
for
decoded_str
,
keep_yxs_list
in
decode_res
:
if
is_backbone
:
keep_yxs_list_with_id
=
add_id
(
keep_yxs_list
,
image_id
=
image_id
)
instance_center_pos_yxs
.
append
(
keep_yxs_list_with_id
)
def
point_pair2poly
(
point_pair_list
):
"""
Transfer vertical point_pairs into poly point in clockwise.
"""
point_num
=
len
(
point_pair_list
)
*
2
point_list
=
[
0
]
*
point_num
for
idx
,
point_pair
in
enumerate
(
point_pair_list
):
point_list
[
idx
]
=
point_pair
[
0
]
point_list
[
point_num
-
1
-
idx
]
=
point_pair
[
1
]
return
np
.
array
(
point_list
).
reshape
(
-
1
,
2
)
def
shrink_quad_along_width
(
quad
,
begin_width_ratio
=
0.
,
end_width_ratio
=
1.
):
ratio_pair
=
np
.
array
(
[[
begin_width_ratio
],
[
end_width_ratio
]],
dtype
=
np
.
float32
)
p0_1
=
quad
[
0
]
+
(
quad
[
1
]
-
quad
[
0
])
*
ratio_pair
p3_2
=
quad
[
3
]
+
(
quad
[
2
]
-
quad
[
3
])
*
ratio_pair
return
np
.
array
([
p0_1
[
0
],
p0_1
[
1
],
p3_2
[
1
],
p3_2
[
0
]])
def
expand_poly_along_width
(
poly
,
shrink_ratio_of_width
=
0.3
):
"""
expand poly along width.
"""
point_num
=
poly
.
shape
[
0
]
left_quad
=
np
.
array
(
[
poly
[
0
],
poly
[
1
],
poly
[
-
2
],
poly
[
-
1
]],
dtype
=
np
.
float32
)
left_ratio
=
-
shrink_ratio_of_width
*
np
.
linalg
.
norm
(
left_quad
[
0
]
-
left_quad
[
3
])
/
\
(
np
.
linalg
.
norm
(
left_quad
[
0
]
-
left_quad
[
1
])
+
1e-6
)
left_quad_expand
=
shrink_quad_along_width
(
left_quad
,
left_ratio
,
1.0
)
right_quad
=
np
.
array
(
[
poly
[
point_num
//
2
-
2
],
poly
[
point_num
//
2
-
1
],
poly
[
point_num
//
2
],
poly
[
point_num
//
2
+
1
]
],
dtype
=
np
.
float32
)
right_ratio
=
1.0
+
shrink_ratio_of_width
*
np
.
linalg
.
norm
(
right_quad
[
0
]
-
right_quad
[
3
])
/
\
(
np
.
linalg
.
norm
(
right_quad
[
0
]
-
right_quad
[
1
])
+
1e-6
)
right_quad_expand
=
shrink_quad_along_width
(
right_quad
,
0.0
,
right_ratio
)
poly
[
0
]
=
left_quad_expand
[
0
]
poly
[
-
1
]
=
left_quad_expand
[
-
1
]
poly
[
point_num
//
2
-
1
]
=
right_quad_expand
[
1
]
poly
[
point_num
//
2
]
=
right_quad_expand
[
2
]
return
poly
def
restore_poly
(
instance_yxs_list
,
seq_strs
,
p_border
,
ratio_w
,
ratio_h
,
src_w
,
src_h
,
valid_set
):
poly_list
=
[]
keep_str_list
=
[]
for
yx_center_line
,
keep_str
in
zip
(
instance_yxs_list
,
seq_strs
):
if
len
(
keep_str
)
<
2
:
print
(
'--> too short, {}'
.
format
(
keep_str
))
continue
offset_expand
=
1.0
if
valid_set
==
'totaltext'
:
offset_expand
=
1.2
point_pair_list
=
[]
for
y
,
x
in
yx_center_line
:
offset
=
p_border
[:,
y
,
x
].
reshape
(
2
,
2
)
*
offset_expand
ori_yx
=
np
.
array
([
y
,
x
],
dtype
=
np
.
float32
)
point_pair
=
(
ori_yx
+
offset
)[:,
::
-
1
]
*
4.0
/
np
.
array
(
[
ratio_w
,
ratio_h
]).
reshape
(
-
1
,
2
)
point_pair_list
.
append
(
point_pair
)
detected_poly
=
point_pair2poly
(
point_pair_list
)
detected_poly
=
expand_poly_along_width
(
detected_poly
,
shrink_ratio_of_width
=
0.2
)
detected_poly
[:,
0
]
=
np
.
clip
(
detected_poly
[:,
0
],
a_min
=
0
,
a_max
=
src_w
)
detected_poly
[:,
1
]
=
np
.
clip
(
detected_poly
[:,
1
],
a_min
=
0
,
a_max
=
src_h
)
keep_str_list
.
append
(
keep_str
)
if
valid_set
==
'partvgg'
:
middle_point
=
len
(
detected_poly
)
//
2
detected_poly
=
detected_poly
[
[
0
,
middle_point
-
1
,
middle_point
,
-
1
],
:]
poly_list
.
append
(
detected_poly
)
elif
valid_set
==
'totaltext'
:
poly_list
.
append
(
detected_poly
)
else
:
end_points_yxs
.
extend
((
keep_yxs_list
[
0
],
keep_yxs_list
[
-
1
]))
center_pos_yxs
.
extend
(
keep_yxs_list
)
if
is_backbone
:
return
instance_center_pos_yxs
else
:
return
center_pos_yxs
,
end_points_yxs
print
(
'--> Not supported format.'
)
exit
(
-
1
)
return
poly_list
,
keep_str_list
def
generate_pivot_list
(
p_score
,
p_char_maps
,
f_direction
,
score_thresh
=
0.5
,
is_backbone
=
False
,
is_curved
=
True
,
image_id
=
0
):
"""
Warp all the function together.
"""
if
is_curved
:
return
generate_pivot_list_curved
(
p_score
,
p_char_maps
,
f_direction
,
score_thresh
=
score_thresh
,
is_expand
=
True
,
is_backbone
=
is_backbone
,
image_id
=
image_id
)
else
:
return
generate_pivot_list_horizontal
(
p_score
,
p_char_maps
,
f_direction
,
score_thresh
=
score_thresh
,
is_backbone
=
is_backbone
,
image_id
=
image_id
)
def
extract_main_direction
(
pos_list
,
f_direction
):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
pos_list
=
np
.
array
(
pos_list
)
point_direction
=
f_direction
[
pos_list
[:,
0
],
pos_list
[:,
1
]]
point_direction
=
point_direction
[:,
::
-
1
]
# x, y -> y, x
average_direction
=
np
.
mean
(
point_direction
,
axis
=
0
,
keepdims
=
True
)
average_direction
=
average_direction
/
(
np
.
linalg
.
norm
(
average_direction
)
+
1e-6
)
return
average_direction
def
sort_by_direction_with_image_id_deprecated
(
pos_list
,
f_direction
):
"""
f_direction: h x w x 2
pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
"""
pos_list_full
=
np
.
array
(
pos_list
).
reshape
(
-
1
,
3
)
pos_list
=
pos_list_full
[:,
1
:]
point_direction
=
f_direction
[
pos_list
[:,
0
],
pos_list
[:,
1
]]
# x, y
point_direction
=
point_direction
[:,
::
-
1
]
# x, y -> y, x
average_direction
=
np
.
mean
(
point_direction
,
axis
=
0
,
keepdims
=
True
)
pos_proj_leng
=
np
.
sum
(
pos_list
*
average_direction
,
axis
=
1
)
sorted_list
=
pos_list_full
[
np
.
argsort
(
pos_proj_leng
)].
tolist
()
return
sorted_list
def
sort_by_direction_with_image_id
(
pos_list
,
f_direction
):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
def
sort_part_with_direction
(
pos_list_full
,
point_direction
):
pos_list_full
=
np
.
array
(
pos_list_full
).
reshape
(
-
1
,
3
)
pos_list
=
pos_list_full
[:,
1
:]
point_direction
=
np
.
array
(
point_direction
).
reshape
(
-
1
,
2
)
average_direction
=
np
.
mean
(
point_direction
,
axis
=
0
,
keepdims
=
True
)
pos_proj_leng
=
np
.
sum
(
pos_list
*
average_direction
,
axis
=
1
)
sorted_list
=
pos_list_full
[
np
.
argsort
(
pos_proj_leng
)].
tolist
()
sorted_direction
=
point_direction
[
np
.
argsort
(
pos_proj_leng
)].
tolist
()
return
sorted_list
,
sorted_direction
pos_list
=
np
.
array
(
pos_list
).
reshape
(
-
1
,
3
)
point_direction
=
f_direction
[
pos_list
[:,
1
],
pos_list
[:,
2
]]
# x, y
point_direction
=
point_direction
[:,
::
-
1
]
# x, y -> y, x
sorted_point
,
sorted_direction
=
sort_part_with_direction
(
pos_list
,
point_direction
)
point_num
=
len
(
sorted_point
)
if
point_num
>=
16
:
middle_num
=
point_num
//
2
first_part_point
=
sorted_point
[:
middle_num
]
first_point_direction
=
sorted_direction
[:
middle_num
]
sorted_fist_part_point
,
sorted_fist_part_direction
=
sort_part_with_direction
(
first_part_point
,
first_point_direction
)
last_part_point
=
sorted_point
[
middle_num
:]
last_point_direction
=
sorted_direction
[
middle_num
:]
sorted_last_part_point
,
sorted_last_part_direction
=
sort_part_with_direction
(
last_part_point
,
last_point_direction
)
sorted_point
=
sorted_fist_part_point
+
sorted_last_part_point
sorted_direction
=
sorted_fist_part_direction
+
sorted_last_part_direction
return
sorted_point
def
generate_pivot_list_tt_inference
(
p_score
,
p_char_maps
,
f_direction
,
score_thresh
=
0.5
,
is_backbone
=
False
,
is_curved
=
True
,
image_id
=
0
):
Lexicon_Table
,
score_thresh
=
0.5
):
"""
return center point and end point of TCL instance; filter with the char maps;
"""
p_score
=
p_score
[
0
]
f_direction
=
f_direction
.
transpose
(
1
,
2
,
0
)
p_tcl_map
=
(
p_score
>
score_thresh
)
*
1.0
skeleton_map
=
thin
(
p_tcl_map
)
ret
,
p_tcl_map
=
cv2
.
threshold
(
p_score
,
score_thresh
,
255
,
cv2
.
THRESH_BINARY
)
skeleton_map
=
thin
(
p_tcl_map
.
astype
(
'uint8'
))
instance_count
,
instance_label_map
=
cv2
.
connectedComponents
(
skeleton_map
.
astype
(
np
.
uint8
)
,
connectivity
=
8
)
skeleton_map
,
connectivity
=
8
)
# get TCL Instance
all_pos_yxs
=
[]
if
instance_count
>
0
:
for
instance_id
in
range
(
1
,
instance_count
):
pos_list
=
[]
ys
,
xs
=
np
.
where
(
instance_label_map
==
instance_id
)
pos_list
=
list
(
zip
(
ys
,
xs
))
if
len
(
pos_list
)
<
3
:
continue
pos_list_sorted
=
sort_and_expand_with_direction_v2
(
pos_list
,
f_direction
,
p_tcl_map
)
pos_list_sorted_with_id
=
add_id
(
pos_list_sorted
,
image_id
=
image_id
)
all_pos_yxs
.
append
(
pos_list_sorted_with_id
)
return
all_pos_yxs
all_pos_yxs
.
append
(
pos_list_sorted
)
p_char_maps
=
p_char_maps
.
transpose
([
1
,
2
,
0
])
decoded_str
,
keep_yxs_list
=
ctc_decoder_for_image
(
all_pos_yxs
,
logits_map
=
p_char_maps
,
Lexicon_Table
=
Lexicon_Table
)
return
keep_yxs_list
,
decoded_str
ppocr/utils/pgnet_dict.txt
已删除
100644 → 0
浏览文件 @
97111112
0
1
2
3
4
5
6
7
8
9
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
\ No newline at end of file
tools/infer/predict_e2e.py
浏览文件 @
88964dc9
...
...
@@ -39,10 +39,7 @@ class TextE2e(object):
self
.
args
=
args
self
.
e2e_algorithm
=
args
.
e2e_algorithm
pre_process_list
=
[{
'E2EResizeForTest'
:
{
'max_side_len'
:
768
,
'valid_set'
:
'totaltext'
}
'E2EResizeForTest'
:
{}
},
{
'NormalizeImage'
:
{
'std'
:
[
0.229
,
0.224
,
0.225
],
...
...
@@ -70,12 +67,6 @@ class TextE2e(object):
postprocess_params
[
"character_dict_path"
]
=
args
.
e2e_char_dict_path
postprocess_params
[
"valid_set"
]
=
args
.
e2e_pgnet_valid_set
self
.
e2e_pgnet_polygon
=
args
.
e2e_pgnet_polygon
if
self
.
e2e_pgnet_polygon
:
postprocess_params
[
"expand_scale"
]
=
1.2
postprocess_params
[
"shrink_ratio_of_width"
]
=
0.2
else
:
postprocess_params
[
"expand_scale"
]
=
1.0
postprocess_params
[
"shrink_ratio_of_width"
]
=
0.3
else
:
logger
.
info
(
"unknown e2e_algorithm:{}"
.
format
(
self
.
e2e_algorithm
))
sys
.
exit
(
0
)
...
...
@@ -102,6 +93,7 @@ class TextE2e(object):
return
dt_boxes
def
__call__
(
self
,
img
):
ori_im
=
img
.
copy
()
data
=
{
'image'
:
img
}
data
=
transform
(
data
,
self
.
preprocess_op
)
...
...
@@ -109,7 +101,6 @@ class TextE2e(object):
if
img
is
None
:
return
None
,
0
img
=
np
.
expand_dims
(
img
,
axis
=
0
)
print
(
img
.
shape
)
shape_list
=
np
.
expand_dims
(
shape_list
,
axis
=
0
)
img
=
img
.
copy
()
starttime
=
time
.
time
()
...
...
@@ -123,13 +114,12 @@ class TextE2e(object):
preds
=
{}
if
self
.
e2e_algorithm
==
'PGNet'
:
preds
[
'f_
score
'
]
=
outputs
[
0
]
preds
[
'f_
borde
r'
]
=
outputs
[
1
]
preds
[
'f_
border
'
]
=
outputs
[
0
]
preds
[
'f_
cha
r'
]
=
outputs
[
1
]
preds
[
'f_direction'
]
=
outputs
[
2
]
preds
[
'f_
char
'
]
=
outputs
[
3
]
preds
[
'f_
score
'
]
=
outputs
[
3
]
else
:
raise
NotImplementedError
post_result
=
self
.
postprocess_op
(
preds
,
shape_list
)
points
,
strs
=
post_result
[
'points'
],
post_result
[
'strs'
]
dt_boxes
=
self
.
filter_tag_det_res_only_clip
(
points
,
ori_im
.
shape
)
...
...
tools/infer/utility.py
浏览文件 @
88964dc9
...
...
@@ -83,11 +83,9 @@ def parse_args():
# PGNet parmas
parser
.
add_argument
(
"--e2e_pgnet_score_thresh"
,
type
=
float
,
default
=
0.5
)
parser
.
add_argument
(
"--e2e_char_dict_path"
,
type
=
str
,
default
=
"./ppocr/utils/pgnet_dict.txt"
)
"--e2e_char_dict_path"
,
type
=
str
,
default
=
"./ppocr/utils/ic15_dict.txt"
)
parser
.
add_argument
(
"--e2e_pgnet_valid_set"
,
type
=
str
,
default
=
'totaltext'
)
parser
.
add_argument
(
"--e2e_pgnet_polygon"
,
type
=
bool
,
default
=
Fals
e
)
parser
.
add_argument
(
"--e2e_pgnet_polygon"
,
type
=
bool
,
default
=
Tru
e
)
# params for text classifier
parser
.
add_argument
(
"--use_angle_cls"
,
type
=
str2bool
,
default
=
False
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录