Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
ac4cef10
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ac4cef10
编写于
4月 10, 2021
作者:
D
Double_V
提交者:
GitHub
4月 10, 2021
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #2437 from JetHong/pgnet-readme
fix eval score
上级
46ad3c47
f67e6e13
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
28 addition
and
79 deletion
+28
-79
configs/e2e/e2e_r50_vd_pg.yml
configs/e2e/e2e_r50_vd_pg.yml
+2
-1
doc/doc_ch/pgnet.md
doc/doc_ch/pgnet.md
+5
-3
ppocr/data/pgnet_dataset.py
ppocr/data/pgnet_dataset.py
+5
-10
ppocr/metrics/e2e_metric.py
ppocr/metrics/e2e_metric.py
+9
-38
ppocr/postprocess/pg_postprocess.py
ppocr/postprocess/pg_postprocess.py
+1
-0
ppocr/utils/e2e_metric/Deteval.py
ppocr/utils/e2e_metric/Deteval.py
+6
-27
未找到文件。
configs/e2e/e2e_r50_vd_pg.yml
浏览文件 @
ac4cef10
...
...
@@ -61,6 +61,7 @@ PostProcess:
score_thresh
:
0.5
Metric
:
name
:
E2EMetric
gt_mat_dir
:
# the dir of gt_mat
character_dict_path
:
ppocr/utils/ic15_dict.txt
main_indicator
:
f_score_e2e
...
...
@@ -106,7 +107,7 @@ Eval:
order
:
'
hwc'
-
ToCHWImage
:
-
KeepKeys
:
keep_keys
:
[
'
image'
,
'
shape'
,
'
polys'
,
'
strs'
,
'
tags'
]
keep_keys
:
[
'
image'
,
'
shape'
,
'
polys'
,
'
strs'
,
'
tags'
,
'
img_id'
]
loader
:
shuffle
:
False
drop_last
:
False
...
...
doc/doc_ch/pgnet.md
浏览文件 @
ac4cef10
...
...
@@ -2,7 +2,7 @@
-
[
一、简介
](
#简介
)
-
[
二、环境配置
](
#环境配置
)
-
[
三、快速使用
](
#快速使用
)
-
[
四、模型训练、评估、推理
](
#
快速训练
)
-
[
四、模型训练、评估、推理
](
#
模型训练、评估、推理
)
<a
name=
"简介"
></a>
## 一、简介
...
...
@@ -20,7 +20,9 @@ PGNet算法细节详见[论文](https://www.aaai.org/AAAI21Papers/AAAI-2885.Wang
![](
../pgnet_framework.png
)
输入图像经过特征提取送入四个分支,分别是:文本边缘偏移量预测TBO模块,文本中心线预测TCL模块,文本方向偏移量预测TDO模块,以及文本字符分类图预测TCC模块。
其中TBO以及TCL的输出经过后处理后可以得到文本的检测结果,TCL、TDO、TCC负责文本识别。
其检测识别效果图如下:
![](
../imgs_results/e2e_res_img293_pgnet.png
)
![](
../imgs_results/e2e_res_img295_pgnet.png
)
...
...
@@ -61,12 +63,12 @@ python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/im
可视化文本检测结果默认保存到./inference_results文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
![](
../imgs_results/e2e_res_img623_pgnet.jpg
)
<a
name=
"
快速训练
"
></a>
<a
name=
"
模型训练、评估、推理
"
></a>
## 四、模型训练、评估、推理
本节以totaltext数据集为例,介绍PaddleOCR中端到端模型的训练、评估与测试。
### 准备数据
下载解压
[
totaltext
](
https://github.com/cs-chan/Total-Text-Dataset/blob/master/Dataset/README.md
)
数据集到PaddleOCR/train_data/目录,数据集组织结构:
下载解压
[
totaltext
](
https://github.com/cs-chan/Total-Text-Dataset/blob/master/Dataset/README.md
)
数据集到PaddleOCR/train_data/目录,数据集组织结构:
```
/PaddleOCR/train_data/total_text/train/
|- rgb/ # total_text数据集的训练数据
...
...
ppocr/data/pgnet_dataset.py
浏览文件 @
ac4cef10
...
...
@@ -64,9 +64,6 @@ class PGDataSet(Dataset):
for
line
in
f
.
readlines
():
poly_str
,
txt
=
line
.
strip
().
split
(
'
\t
'
)
poly
=
list
(
map
(
float
,
poly_str
.
split
(
','
)))
if
self
.
mode
.
lower
()
==
"eval"
:
while
len
(
poly
)
<
100
:
poly
.
append
(
-
1
)
text_polys
.
append
(
np
.
array
(
poly
,
dtype
=
np
.
float32
).
reshape
(
-
1
,
2
))
...
...
@@ -139,23 +136,21 @@ class PGDataSet(Dataset):
try
:
if
self
.
data_format
==
'icdar'
:
im_path
=
os
.
path
.
join
(
data_path
,
'rgb'
,
data_line
)
if
self
.
mode
.
lower
()
==
"eval"
:
poly_path
=
os
.
path
.
join
(
data_path
,
'poly_gt'
,
data_line
.
split
(
'.'
)[
0
]
+
'.txt'
)
else
:
poly_path
=
os
.
path
.
join
(
data_path
,
'poly'
,
data_line
.
split
(
'.'
)[
0
]
+
'.txt'
)
poly_path
=
os
.
path
.
join
(
data_path
,
'poly'
,
data_line
.
split
(
'.'
)[
0
]
+
'.txt'
)
text_polys
,
text_tags
,
text_strs
=
self
.
extract_polys
(
poly_path
)
else
:
image_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
data_path
),
'image'
)
im_path
,
text_polys
,
text_tags
,
text_strs
=
self
.
extract_info_textnet
(
data_line
,
image_dir
)
img_id
=
int
(
data_line
.
split
(
"."
)[
0
][
3
:])
data
=
{
'img_path'
:
im_path
,
'polys'
:
text_polys
,
'tags'
:
text_tags
,
'strs'
:
text_strs
'strs'
:
text_strs
,
'img_id'
:
img_id
}
with
open
(
data
[
'img_path'
],
'rb'
)
as
f
:
img
=
f
.
read
()
...
...
ppocr/metrics/e2e_metric.py
浏览文件 @
ac4cef10
...
...
@@ -24,53 +24,24 @@ from ppocr.utils.e2e_utils.extract_textpoint import get_dict
class
E2EMetric
(
object
):
def
__init__
(
self
,
gt_mat_dir
,
character_dict_path
,
main_indicator
=
'f_score_e2e'
,
**
kwargs
):
self
.
gt_mat_dir
=
gt_mat_dir
self
.
label_list
=
get_dict
(
character_dict_path
)
self
.
max_index
=
len
(
self
.
label_list
)
self
.
main_indicator
=
main_indicator
self
.
reset
()
def
__call__
(
self
,
preds
,
batch
,
**
kwargs
):
temp_gt_polyons_batch
=
batch
[
2
]
temp_gt_strs_batch
=
batch
[
3
]
ignore_tags_batch
=
batch
[
4
]
gt_polyons_batch
=
[]
gt_strs_batch
=
[]
temp_gt_polyons_batch
=
temp_gt_polyons_batch
[
0
].
tolist
()
for
temp_list
in
temp_gt_polyons_batch
:
t
=
[]
for
index
in
temp_list
:
if
index
[
0
]
!=
-
1
and
index
[
1
]
!=
-
1
:
t
.
append
(
index
)
gt_polyons_batch
.
append
(
t
)
temp_gt_strs_batch
=
temp_gt_strs_batch
[
0
].
tolist
()
for
temp_list
in
temp_gt_strs_batch
:
t
=
""
for
index
in
temp_list
:
if
index
<
self
.
max_index
:
t
+=
self
.
label_list
[
index
]
gt_strs_batch
.
append
(
t
)
for
pred
,
gt_polyons
,
gt_strs
,
ignore_tags
in
zip
(
[
preds
],
[
gt_polyons_batch
],
[
gt_strs_batch
],
ignore_tags_batch
):
# prepare gt
gt_info_list
=
[{
'points'
:
gt_polyon
,
'text'
:
gt_str
,
'ignore'
:
ignore_tag
}
for
gt_polyon
,
gt_str
,
ignore_tag
in
zip
(
gt_polyons
,
gt_strs
,
ignore_tags
)]
# prepare det
e2e_info_list
=
[{
'points'
:
det_polyon
,
'text'
:
pred_str
}
for
det_polyon
,
pred_str
in
zip
(
pred
[
'points'
],
pred
[
'strs'
])]
result
=
get_socre
(
gt_info_list
,
e2e_info_list
)
self
.
results
.
append
(
result
)
img_id
=
batch
[
5
][
0
]
e2e_info_list
=
[{
'points'
:
det_polyon
,
'text'
:
pred_str
}
for
det_polyon
,
pred_str
in
zip
(
preds
[
'points'
],
preds
[
'strs'
])]
result
=
get_socre
(
self
.
gt_mat_dir
,
img_id
,
e2e_info_list
)
self
.
results
.
append
(
result
)
def
get_metric
(
self
):
metircs
=
combine_results
(
self
.
results
)
...
...
ppocr/postprocess/pg_postprocess.py
浏览文件 @
ac4cef10
...
...
@@ -138,6 +138,7 @@ class PGPostProcess(object):
continue
keep_str_list
.
append
(
keep_str
)
detected_poly
=
np
.
round
(
detected_poly
).
astype
(
'int32'
)
if
self
.
valid_set
==
'partvgg'
:
middle_point
=
len
(
detected_poly
)
//
2
detected_poly
=
detected_poly
[
...
...
ppocr/utils/e2e_metric/Deteval.py
浏览文件 @
ac4cef10
...
...
@@ -13,10 +13,11 @@
# limitations under the License.
import
numpy
as
np
import
scipy.io
as
io
from
ppocr.utils.e2e_metric.polygon_fast
import
iod
,
area_of_intersection
,
area
def
get_socre
(
gt_di
ct
,
pred_dict
):
def
get_socre
(
gt_di
r
,
img_id
,
pred_dict
):
allInputs
=
1
def
input_reading_mod
(
pred_dict
):
...
...
@@ -30,31 +31,9 @@ def get_socre(gt_dict, pred_dict):
det
.
append
([
point
,
text
])
return
det
def
gt_reading_mod
(
gt_dict
):
"""This helper reads groundtruths from mat files"""
gt
=
[]
n
=
len
(
gt_dict
)
for
i
in
range
(
n
):
points
=
gt_dict
[
i
][
'points'
]
h
=
len
(
points
)
text
=
gt_dict
[
i
][
'text'
]
xx
=
[
np
.
array
(
[
'x:'
],
dtype
=
'<U2'
),
0
,
np
.
array
(
[
'y:'
],
dtype
=
'<U2'
),
0
,
np
.
array
(
[
'#'
],
dtype
=
'<U1'
),
np
.
array
(
[
'#'
],
dtype
=
'<U1'
)
]
t_x
,
t_y
=
[],
[]
for
j
in
range
(
h
):
t_x
.
append
(
points
[
j
][
0
])
t_y
.
append
(
points
[
j
][
1
])
xx
[
1
]
=
np
.
array
([
t_x
],
dtype
=
'int16'
)
xx
[
3
]
=
np
.
array
([
t_y
],
dtype
=
'int16'
)
if
text
!=
""
and
"#"
not
in
text
:
xx
[
4
]
=
np
.
array
([
text
],
dtype
=
'U{}'
.
format
(
len
(
text
)))
xx
[
5
]
=
np
.
array
([
'c'
],
dtype
=
'<U1'
)
gt
.
append
(
xx
)
def
gt_reading_mod
(
gt_dir
,
gt_id
):
gt
=
io
.
loadmat
(
'%s/poly_gt_img%s.mat'
%
(
gt_dir
,
gt_id
))
gt
=
gt
[
'polygt'
]
return
gt
def
detection_filtering
(
detections
,
groundtruths
,
threshold
=
0.5
):
...
...
@@ -101,7 +80,7 @@ def get_socre(gt_dict, pred_dict):
input_id
!=
'Deteval_result.txt'
)
and
(
input_id
!=
'Deteval_result_curved.txt'
)
\
and
(
input_id
!=
'Deteval_result_non_curved.txt'
):
detections
=
input_reading_mod
(
pred_dict
)
groundtruths
=
gt_reading_mod
(
gt_di
ct
)
groundtruths
=
gt_reading_mod
(
gt_di
r
,
img_id
).
tolist
(
)
detections
=
detection_filtering
(
detections
,
groundtruths
)
# filters detections overlapping with DC area
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录