Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
a0d1f923
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
1 年多 前同步成功
通知
1533
Star
32963
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a0d1f923
编写于
4月 09, 2021
作者:
J
Jethong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add different post process
上级
0cd48c35
变更
10
展开全部
隐藏空白更改
内联
并排
Showing
10 changed file
with
516 addition
and
434 deletion
+516
-434
configs/e2e/e2e_r50_vd_pg.yml
configs/e2e/e2e_r50_vd_pg.yml
+2
-2
ppocr/data/imaug/label_ops.py
ppocr/data/imaug/label_ops.py
+9
-7
ppocr/data/pgnet_dataset.py
ppocr/data/pgnet_dataset.py
+12
-4
ppocr/metrics/e2e_metric.py
ppocr/metrics/e2e_metric.py
+12
-2
ppocr/postprocess/pg_postprocess.py
ppocr/postprocess/pg_postprocess.py
+93
-8
ppocr/utils/e2e_metric/Deteval.py
ppocr/utils/e2e_metric/Deteval.py
+16
-286
ppocr/utils/e2e_utils/extract_textpoint.py
ppocr/utils/e2e_utils/extract_textpoint.py
+371
-122
tools/infer/predict_e2e.py
tools/infer/predict_e2e.py
+1
-1
train_data/total_text/train/poly/2.txt
train_data/total_text/train/poly/2.txt
+0
-2
train_data/total_text/train/rgb/2.jpg
train_data/total_text/train/rgb/2.jpg
+0
-0
未找到文件。
configs/e2e/e2e_r50_vd_pg.yml
浏览文件 @
a0d1f923
...
...
@@ -11,7 +11,7 @@ Global:
# from static branch, load_static_weights must be set as True.
# 2. If you want to finetune the pretrained models we provide in the docs,
# you should set load_static_weights as False.
load_static_weights
:
Tru
e
load_static_weights
:
Fals
e
cal_metric_during_train
:
False
pretrained_model
:
checkpoints
:
...
...
@@ -94,7 +94,7 @@ Eval:
label_file_list
:
[
./train_data/total_text/test/
]
transforms
:
-
DecodeImage
:
# load image
img_mode
:
BGR
img_mode
:
RGB
channel_first
:
False
-
E2ELabelEncode
:
-
E2EResizeForTest
:
...
...
ppocr/data/imaug/label_ops.py
浏览文件 @
a0d1f923
...
...
@@ -200,16 +200,18 @@ class E2ELabelEncode(BaseRecLabelEncode):
self
.
pad_num
=
len
(
self
.
dict
)
# the length to pad
def
__call__
(
self
,
data
):
text_label_index_list
,
temp_text
=
[],
[]
texts
=
data
[
'strs'
]
temp_texts
=
[]
for
text
in
texts
:
text
=
text
.
lower
()
text
=
self
.
encode
(
text
)
if
text
is
None
:
return
None
text
=
text
+
[
self
.
pad_num
]
*
(
self
.
max_text_len
-
len
(
text
))
temp_texts
.
append
(
text
)
data
[
'strs'
]
=
np
.
array
(
temp_texts
)
temp_text
=
[]
for
c_
in
text
:
if
c_
in
self
.
dict
:
temp_text
.
append
(
self
.
dict
[
c_
])
temp_text
=
temp_text
+
[
self
.
pad_num
]
*
(
self
.
max_text_len
-
len
(
temp_text
))
text_label_index_list
.
append
(
temp_text
)
data
[
'strs'
]
=
np
.
array
(
text_label_index_list
)
return
data
...
...
ppocr/data/pgnet_dataset.py
浏览文件 @
a0d1f923
...
...
@@ -24,6 +24,7 @@ class PGDataSet(Dataset):
self
.
logger
=
logger
self
.
seed
=
seed
self
.
mode
=
mode
global_config
=
config
[
'Global'
]
dataset_config
=
config
[
mode
][
'dataset'
]
loader_config
=
config
[
mode
][
'loader'
]
...
...
@@ -62,10 +63,13 @@ class PGDataSet(Dataset):
with
open
(
poly_txt_path
)
as
f
:
for
line
in
f
.
readlines
():
poly_str
,
txt
=
line
.
strip
().
split
(
'
\t
'
)
poly
=
map
(
float
,
poly_str
.
split
(
','
))
poly
=
list
(
map
(
float
,
poly_str
.
split
(
','
)))
if
self
.
mode
.
lower
()
==
"eval"
:
while
len
(
poly
)
<
100
:
poly
.
append
(
-
1
)
text_polys
.
append
(
np
.
array
(
list
(
poly
)
,
dtype
=
np
.
float32
).
reshape
(
-
1
,
2
))
poly
,
dtype
=
np
.
float32
).
reshape
(
-
1
,
2
))
txts
.
append
(
txt
)
txt_tags
.
append
(
txt
==
'###'
)
...
...
@@ -135,8 +139,12 @@ class PGDataSet(Dataset):
try
:
if
self
.
data_format
==
'icdar'
:
im_path
=
os
.
path
.
join
(
data_path
,
'rgb'
,
data_line
)
poly_path
=
os
.
path
.
join
(
data_path
,
'poly'
,
data_line
.
split
(
'.'
)[
0
]
+
'.txt'
)
if
self
.
mode
.
lower
()
==
"eval"
:
poly_path
=
os
.
path
.
join
(
data_path
,
'poly_gt'
,
data_line
.
split
(
'.'
)[
0
]
+
'.txt'
)
else
:
poly_path
=
os
.
path
.
join
(
data_path
,
'poly'
,
data_line
.
split
(
'.'
)[
0
]
+
'.txt'
)
text_polys
,
text_tags
,
text_strs
=
self
.
extract_polys
(
poly_path
)
else
:
image_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
data_path
),
'image'
)
...
...
ppocr/metrics/e2e_metric.py
浏览文件 @
a0d1f923
...
...
@@ -33,10 +33,20 @@ class E2EMetric(object):
self
.
reset
()
def
__call__
(
self
,
preds
,
batch
,
**
kwargs
):
gt_polyons_batch
=
batch
[
2
]
temp_
gt_polyons_batch
=
batch
[
2
]
temp_gt_strs_batch
=
batch
[
3
]
ignore_tags_batch
=
batch
[
4
]
gt_polyons_batch
=
[]
gt_strs_batch
=
[]
temp_gt_polyons_batch
=
temp_gt_polyons_batch
[
0
].
tolist
()
for
temp_list
in
temp_gt_polyons_batch
:
t
=
[]
for
index
in
temp_list
:
if
index
[
0
]
!=
-
1
and
index
[
1
]
!=
-
1
:
t
.
append
(
index
)
gt_polyons_batch
.
append
(
t
)
temp_gt_strs_batch
=
temp_gt_strs_batch
[
0
].
tolist
()
for
temp_list
in
temp_gt_strs_batch
:
t
=
""
...
...
@@ -46,7 +56,7 @@ class E2EMetric(object):
gt_strs_batch
.
append
(
t
)
for
pred
,
gt_polyons
,
gt_strs
,
ignore_tags
in
zip
(
[
preds
],
gt_polyons_batch
,
[
gt_strs_batch
],
ignore_tags_batch
):
[
preds
],
[
gt_polyons_batch
]
,
[
gt_strs_batch
],
ignore_tags_batch
):
# prepare gt
gt_info_list
=
[{
'points'
:
gt_polyon
,
...
...
ppocr/postprocess/pg_postprocess.py
浏览文件 @
a0d1f923
...
...
@@ -23,7 +23,8 @@ __dir__ = os.path.dirname(__file__)
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
join
(
__dir__
,
'..'
))
from
ppocr.utils.e2e_utils.extract_textpoint
import
get_dict
,
generate_pivot_list
,
restore_poly
from
ppocr.utils.e2e_utils.extract_textpoint
import
*
from
ppocr.utils.e2e_utils.visual
import
*
import
paddle
...
...
@@ -37,6 +38,11 @@ class PGPostProcess(object):
self
.
valid_set
=
valid_set
self
.
score_thresh
=
score_thresh
# c++ la-nms is faster, but only support python 3.5
self
.
is_python35
=
False
if
sys
.
version_info
.
major
==
3
and
sys
.
version_info
.
minor
==
5
:
self
.
is_python35
=
True
def
__call__
(
self
,
outs_dict
,
shape_list
):
p_score
=
outs_dict
[
'f_score'
]
p_border
=
outs_dict
[
'f_border'
]
...
...
@@ -52,17 +58,96 @@ class PGPostProcess(object):
p_border
=
p_border
[
0
]
p_direction
=
p_direction
[
0
]
p_char
=
p_char
[
0
]
src_h
,
src_w
,
ratio_h
,
ratio_w
=
shape_list
[
0
]
instance_yxs_list
,
seq_strs
=
generate_pivot_list
(
is_curved
=
self
.
valid_set
==
"totaltext"
instance_yxs_list
=
generate_pivot_list
(
p_score
,
p_char
,
p_direction
,
self
.
Lexicon_Table
,
score_thresh
=
self
.
score_thresh
)
poly_list
,
keep_str_list
=
restore_poly
(
instance_yxs_list
,
seq_strs
,
p_border
,
ratio_w
,
ratio_h
,
src_w
,
src_h
,
self
.
valid_set
)
score_thresh
=
self
.
score_thresh
,
is_backbone
=
True
,
is_curved
=
is_curved
)
p_char
=
paddle
.
to_tensor
(
np
.
expand_dims
(
p_char
,
axis
=
0
))
char_seq_idx_set
=
[]
for
i
in
range
(
len
(
instance_yxs_list
)):
gather_info_lod
=
paddle
.
to_tensor
(
instance_yxs_list
[
i
])
f_char_map
=
paddle
.
transpose
(
p_char
,
[
0
,
2
,
3
,
1
])
feature_seq
=
paddle
.
gather_nd
(
f_char_map
,
gather_info_lod
)
feature_seq
=
np
.
expand_dims
(
feature_seq
.
numpy
(),
axis
=
0
)
feature_len
=
[
len
(
feature_seq
[
0
])]
featyre_seq
=
paddle
.
to_tensor
(
feature_seq
)
feature_len
=
np
.
array
([
feature_len
]).
astype
(
np
.
int64
)
length
=
paddle
.
to_tensor
(
feature_len
)
seq_pred
=
paddle
.
fluid
.
layers
.
ctc_greedy_decoder
(
input
=
featyre_seq
,
blank
=
36
,
input_length
=
length
)
seq_pred_str
=
seq_pred
[
0
].
numpy
().
tolist
()[
0
]
seq_len
=
seq_pred
[
1
].
numpy
()[
0
][
0
]
temp_t
=
[]
for
c
in
seq_pred_str
[:
seq_len
]:
temp_t
.
append
(
c
)
char_seq_idx_set
.
append
(
temp_t
)
seq_strs
=
[]
for
char_idx_set
in
char_seq_idx_set
:
pr_str
=
''
.
join
([
self
.
Lexicon_Table
[
pos
]
for
pos
in
char_idx_set
])
seq_strs
.
append
(
pr_str
)
poly_list
=
[]
keep_str_list
=
[]
all_point_list
=
[]
all_point_pair_list
=
[]
for
yx_center_line
,
keep_str
in
zip
(
instance_yxs_list
,
seq_strs
):
if
len
(
yx_center_line
)
==
1
:
yx_center_line
.
append
(
yx_center_line
[
-
1
])
offset_expand
=
1.0
if
self
.
valid_set
==
'totaltext'
:
offset_expand
=
1.2
point_pair_list
=
[]
for
batch_id
,
y
,
x
in
yx_center_line
:
offset
=
p_border
[:,
y
,
x
].
reshape
(
2
,
2
)
if
offset_expand
!=
1.0
:
offset_length
=
np
.
linalg
.
norm
(
offset
,
axis
=
1
,
keepdims
=
True
)
expand_length
=
np
.
clip
(
offset_length
*
(
offset_expand
-
1
),
a_min
=
0.5
,
a_max
=
3.0
)
offset_detal
=
offset
/
offset_length
*
expand_length
offset
=
offset
+
offset_detal
ori_yx
=
np
.
array
([
y
,
x
],
dtype
=
np
.
float32
)
point_pair
=
(
ori_yx
+
offset
)[:,
::
-
1
]
*
4.0
/
np
.
array
(
[
ratio_w
,
ratio_h
]).
reshape
(
-
1
,
2
)
point_pair_list
.
append
(
point_pair
)
all_point_list
.
append
([
int
(
round
(
x
*
4.0
/
ratio_w
)),
int
(
round
(
y
*
4.0
/
ratio_h
))
])
all_point_pair_list
.
append
(
point_pair
.
round
().
astype
(
np
.
int32
)
.
tolist
())
detected_poly
,
pair_length_info
=
point_pair2poly
(
point_pair_list
)
detected_poly
=
expand_poly_along_width
(
detected_poly
,
shrink_ratio_of_width
=
0.2
)
detected_poly
[:,
0
]
=
np
.
clip
(
detected_poly
[:,
0
],
a_min
=
0
,
a_max
=
src_w
)
detected_poly
[:,
1
]
=
np
.
clip
(
detected_poly
[:,
1
],
a_min
=
0
,
a_max
=
src_h
)
if
len
(
keep_str
)
<
2
:
continue
keep_str_list
.
append
(
keep_str
)
if
self
.
valid_set
==
'partvgg'
:
middle_point
=
len
(
detected_poly
)
//
2
detected_poly
=
detected_poly
[
[
0
,
middle_point
-
1
,
middle_point
,
-
1
],
:]
poly_list
.
append
(
detected_poly
)
elif
self
.
valid_set
==
'totaltext'
:
poly_list
.
append
(
detected_poly
)
else
:
print
(
'--> Not supported format.'
)
exit
(
-
1
)
data
=
{
'points'
:
poly_list
,
'strs'
:
keep_str_list
,
...
...
ppocr/utils/e2e_metric/Deteval.py
浏览文件 @
a0d1f923
...
...
@@ -35,7 +35,7 @@ def get_socre(gt_dict, pred_dict):
gt
=
[]
n
=
len
(
gt_dict
)
for
i
in
range
(
n
):
points
=
gt_dict
[
i
][
'points'
]
.
tolist
()
points
=
gt_dict
[
i
][
'points'
]
h
=
len
(
points
)
text
=
gt_dict
[
i
][
'text'
]
xx
=
[
...
...
@@ -51,7 +51,7 @@ def get_socre(gt_dict, pred_dict):
t_y
.
append
(
points
[
j
][
1
])
xx
[
1
]
=
np
.
array
([
t_x
],
dtype
=
'int16'
)
xx
[
3
]
=
np
.
array
([
t_y
],
dtype
=
'int16'
)
if
text
!=
""
:
if
text
!=
""
and
"#"
not
in
text
:
xx
[
4
]
=
np
.
array
([
text
],
dtype
=
'U{}'
.
format
(
len
(
text
)))
xx
[
5
]
=
np
.
array
([
'c'
],
dtype
=
'<U1'
)
gt
.
append
(
xx
)
...
...
@@ -89,17 +89,10 @@ def get_socre(gt_dict, pred_dict):
area
(
det_x
,
det_y
)),
2
)
##############################Initialization###################################
global_tp
=
0
global_fp
=
0
global_fn
=
0
global_sigma
=
[]
global_tau
=
[]
tr
=
0.7
tp
=
0.6
fsc_k
=
0.8
k
=
2
global_pred_str
=
[]
global_gt_str
=
[]
# global_sigma = []
# global_tau = []
# global_pred_str = []
# global_gt_str = []
###############################################################################
for
input_id
in
range
(
allInputs
):
...
...
@@ -147,281 +140,16 @@ def get_socre(gt_dict, pred_dict):
local_pred_str
[
det_id
]
=
pred_seq_str
local_gt_str
[
gt_id
]
=
gt_seq_str
global_sigma
.
append
(
local_sigma_table
)
global_tau
.
append
(
local_tau_table
)
global_pred_str
.
append
(
local_pred_str
)
global_gt_str
.
append
(
local_gt_str
)
global_accumulative_recall
=
0
global_accumulative_precision
=
0
total_num_gt
=
0
total_num_det
=
0
hit_str_count
=
0
hit_count
=
0
def
one_to_one
(
local_sigma_table
,
local_tau_table
,
local_accumulative_recall
,
local_accumulative_precision
,
global_accumulative_recall
,
global_accumulative_precision
,
gt_flag
,
det_flag
,
idy
):
hit_str_num
=
0
for
gt_id
in
range
(
num_gt
):
gt_matching_qualified_sigma_candidates
=
np
.
where
(
local_sigma_table
[
gt_id
,
:]
>
tr
)
gt_matching_num_qualified_sigma_candidates
=
gt_matching_qualified_sigma_candidates
[
0
].
shape
[
0
]
gt_matching_qualified_tau_candidates
=
np
.
where
(
local_tau_table
[
gt_id
,
:]
>
tp
)
gt_matching_num_qualified_tau_candidates
=
gt_matching_qualified_tau_candidates
[
0
].
shape
[
0
]
det_matching_qualified_sigma_candidates
=
np
.
where
(
local_sigma_table
[:,
gt_matching_qualified_sigma_candidates
[
0
]]
>
tr
)
det_matching_num_qualified_sigma_candidates
=
det_matching_qualified_sigma_candidates
[
0
].
shape
[
0
]
det_matching_qualified_tau_candidates
=
np
.
where
(
local_tau_table
[:,
gt_matching_qualified_tau_candidates
[
0
]]
>
tp
)
det_matching_num_qualified_tau_candidates
=
det_matching_qualified_tau_candidates
[
0
].
shape
[
0
]
if
(
gt_matching_num_qualified_sigma_candidates
==
1
)
and
(
gt_matching_num_qualified_tau_candidates
==
1
)
and
\
(
det_matching_num_qualified_sigma_candidates
==
1
)
and
(
det_matching_num_qualified_tau_candidates
==
1
):
global_accumulative_recall
=
global_accumulative_recall
+
1.0
global_accumulative_precision
=
global_accumulative_precision
+
1.0
local_accumulative_recall
=
local_accumulative_recall
+
1.0
local_accumulative_precision
=
local_accumulative_precision
+
1.0
gt_flag
[
0
,
gt_id
]
=
1
matched_det_id
=
np
.
where
(
local_sigma_table
[
gt_id
,
:]
>
tr
)
# recg start
gt_str_cur
=
global_gt_str
[
idy
][
gt_id
]
pred_str_cur
=
global_pred_str
[
idy
][
matched_det_id
[
0
].
tolist
()[
0
]]
if
pred_str_cur
==
gt_str_cur
:
hit_str_num
+=
1
else
:
if
pred_str_cur
.
lower
()
==
gt_str_cur
.
lower
():
hit_str_num
+=
1
# recg end
det_flag
[
0
,
matched_det_id
]
=
1
return
local_accumulative_recall
,
local_accumulative_precision
,
global_accumulative_recall
,
global_accumulative_precision
,
gt_flag
,
det_flag
,
hit_str_num
def
one_to_many
(
local_sigma_table
,
local_tau_table
,
local_accumulative_recall
,
local_accumulative_precision
,
global_accumulative_recall
,
global_accumulative_precision
,
gt_flag
,
det_flag
,
idy
):
hit_str_num
=
0
for
gt_id
in
range
(
num_gt
):
# skip the following if the groundtruth was matched
if
gt_flag
[
0
,
gt_id
]
>
0
:
continue
non_zero_in_sigma
=
np
.
where
(
local_sigma_table
[
gt_id
,
:]
>
0
)
num_non_zero_in_sigma
=
non_zero_in_sigma
[
0
].
shape
[
0
]
if
num_non_zero_in_sigma
>=
k
:
####search for all detections that overlaps with this groundtruth
qualified_tau_candidates
=
np
.
where
((
local_tau_table
[
gt_id
,
:]
>=
tp
)
&
(
det_flag
[
0
,
:]
==
0
))
num_qualified_tau_candidates
=
qualified_tau_candidates
[
0
].
shape
[
0
]
if
num_qualified_tau_candidates
==
1
:
if
((
local_tau_table
[
gt_id
,
qualified_tau_candidates
]
>=
tp
)
and
(
local_sigma_table
[
gt_id
,
qualified_tau_candidates
]
>=
tr
)):
# became an one-to-one case
global_accumulative_recall
=
global_accumulative_recall
+
1.0
global_accumulative_precision
=
global_accumulative_precision
+
1.0
local_accumulative_recall
=
local_accumulative_recall
+
1.0
local_accumulative_precision
=
local_accumulative_precision
+
1.0
gt_flag
[
0
,
gt_id
]
=
1
det_flag
[
0
,
qualified_tau_candidates
]
=
1
# recg start
gt_str_cur
=
global_gt_str
[
idy
][
gt_id
]
pred_str_cur
=
global_pred_str
[
idy
][
qualified_tau_candidates
[
0
].
tolist
()[
0
]]
if
pred_str_cur
==
gt_str_cur
:
hit_str_num
+=
1
else
:
if
pred_str_cur
.
lower
()
==
gt_str_cur
.
lower
():
hit_str_num
+=
1
# recg end
elif
(
np
.
sum
(
local_sigma_table
[
gt_id
,
qualified_tau_candidates
])
>=
tr
):
gt_flag
[
0
,
gt_id
]
=
1
det_flag
[
0
,
qualified_tau_candidates
]
=
1
# recg start
gt_str_cur
=
global_gt_str
[
idy
][
gt_id
]
pred_str_cur
=
global_pred_str
[
idy
][
qualified_tau_candidates
[
0
].
tolist
()[
0
]]
if
pred_str_cur
==
gt_str_cur
:
hit_str_num
+=
1
else
:
if
pred_str_cur
.
lower
()
==
gt_str_cur
.
lower
():
hit_str_num
+=
1
# recg end
global_accumulative_recall
=
global_accumulative_recall
+
fsc_k
global_accumulative_precision
=
global_accumulative_precision
+
num_qualified_tau_candidates
*
fsc_k
local_accumulative_recall
=
local_accumulative_recall
+
fsc_k
local_accumulative_precision
=
local_accumulative_precision
+
num_qualified_tau_candidates
*
fsc_k
return
local_accumulative_recall
,
local_accumulative_precision
,
global_accumulative_recall
,
global_accumulative_precision
,
gt_flag
,
det_flag
,
hit_str_num
def
many_to_one
(
local_sigma_table
,
local_tau_table
,
local_accumulative_recall
,
local_accumulative_precision
,
global_accumulative_recall
,
global_accumulative_precision
,
gt_flag
,
det_flag
,
idy
):
hit_str_num
=
0
for
det_id
in
range
(
num_det
):
# skip the following if the detection was matched
if
det_flag
[
0
,
det_id
]
>
0
:
continue
non_zero_in_tau
=
np
.
where
(
local_tau_table
[:,
det_id
]
>
0
)
num_non_zero_in_tau
=
non_zero_in_tau
[
0
].
shape
[
0
]
if
num_non_zero_in_tau
>=
k
:
####search for all detections that overlaps with this groundtruth
qualified_sigma_candidates
=
np
.
where
((
local_sigma_table
[:,
det_id
]
>=
tp
)
&
(
gt_flag
[
0
,
:]
==
0
))
num_qualified_sigma_candidates
=
qualified_sigma_candidates
[
0
].
shape
[
0
]
if
num_qualified_sigma_candidates
==
1
:
if
((
local_tau_table
[
qualified_sigma_candidates
,
det_id
]
>=
tp
)
and
(
local_sigma_table
[
qualified_sigma_candidates
,
det_id
]
>=
tr
)):
# became an one-to-one case
global_accumulative_recall
=
global_accumulative_recall
+
1.0
global_accumulative_precision
=
global_accumulative_precision
+
1.0
local_accumulative_recall
=
local_accumulative_recall
+
1.0
local_accumulative_precision
=
local_accumulative_precision
+
1.0
gt_flag
[
0
,
qualified_sigma_candidates
]
=
1
det_flag
[
0
,
det_id
]
=
1
# recg start
pred_str_cur
=
global_pred_str
[
idy
][
det_id
]
gt_len
=
len
(
qualified_sigma_candidates
[
0
])
for
idx
in
range
(
gt_len
):
ele_gt_id
=
qualified_sigma_candidates
[
0
].
tolist
()[
idx
]
if
ele_gt_id
not
in
global_gt_str
[
idy
]:
continue
gt_str_cur
=
global_gt_str
[
idy
][
ele_gt_id
]
if
pred_str_cur
==
gt_str_cur
:
hit_str_num
+=
1
break
else
:
if
pred_str_cur
.
lower
()
==
gt_str_cur
.
lower
():
hit_str_num
+=
1
break
# recg end
elif
(
np
.
sum
(
local_tau_table
[
qualified_sigma_candidates
,
det_id
])
>=
tp
):
det_flag
[
0
,
det_id
]
=
1
gt_flag
[
0
,
qualified_sigma_candidates
]
=
1
# recg start
pred_str_cur
=
global_pred_str
[
idy
][
det_id
]
gt_len
=
len
(
qualified_sigma_candidates
[
0
])
for
idx
in
range
(
gt_len
):
ele_gt_id
=
qualified_sigma_candidates
[
0
].
tolist
()[
idx
]
if
ele_gt_id
not
in
global_gt_str
[
idy
]:
continue
gt_str_cur
=
global_gt_str
[
idy
][
ele_gt_id
]
if
pred_str_cur
==
gt_str_cur
:
hit_str_num
+=
1
break
else
:
if
pred_str_cur
.
lower
()
==
gt_str_cur
.
lower
():
hit_str_num
+=
1
break
# recg end
global_accumulative_recall
=
global_accumulative_recall
+
num_qualified_sigma_candidates
*
fsc_k
global_accumulative_precision
=
global_accumulative_precision
+
fsc_k
local_accumulative_recall
=
local_accumulative_recall
+
num_qualified_sigma_candidates
*
fsc_k
local_accumulative_precision
=
local_accumulative_precision
+
fsc_k
return
local_accumulative_recall
,
local_accumulative_precision
,
global_accumulative_recall
,
global_accumulative_precision
,
gt_flag
,
det_flag
,
hit_str_num
global_sigma
=
local_sigma_table
global_tau
=
local_tau_table
global_pred_str
=
local_pred_str
global_gt_str
=
local_gt_str
single_data
=
{}
for
idx
in
range
(
len
(
global_sigma
)):
local_sigma_table
=
global_sigma
[
idx
]
local_tau_table
=
global_tau
[
idx
]
num_gt
=
local_sigma_table
.
shape
[
0
]
num_det
=
local_sigma_table
.
shape
[
1
]
total_num_gt
=
total_num_gt
+
num_gt
total_num_det
=
total_num_det
+
num_det
local_accumulative_recall
=
0
local_accumulative_precision
=
0
gt_flag
=
np
.
zeros
((
1
,
num_gt
))
det_flag
=
np
.
zeros
((
1
,
num_det
))
#######first check for one-to-one case##########
local_accumulative_recall
,
local_accumulative_precision
,
global_accumulative_recall
,
global_accumulative_precision
,
\
gt_flag
,
det_flag
,
hit_str_num
=
one_to_one
(
local_sigma_table
,
local_tau_table
,
local_accumulative_recall
,
local_accumulative_precision
,
global_accumulative_recall
,
global_accumulative_precision
,
gt_flag
,
det_flag
,
idx
)
hit_str_count
+=
hit_str_num
#######then check for one-to-many case##########
local_accumulative_recall
,
local_accumulative_precision
,
global_accumulative_recall
,
global_accumulative_precision
,
\
gt_flag
,
det_flag
,
hit_str_num
=
one_to_many
(
local_sigma_table
,
local_tau_table
,
local_accumulative_recall
,
local_accumulative_precision
,
global_accumulative_recall
,
global_accumulative_precision
,
gt_flag
,
det_flag
,
idx
)
hit_str_count
+=
hit_str_num
#######then check for many-to-one case##########
local_accumulative_recall
,
local_accumulative_precision
,
global_accumulative_recall
,
global_accumulative_precision
,
\
gt_flag
,
det_flag
,
hit_str_num
=
many_to_one
(
local_sigma_table
,
local_tau_table
,
local_accumulative_recall
,
local_accumulative_precision
,
global_accumulative_recall
,
global_accumulative_precision
,
gt_flag
,
det_flag
,
idx
)
hit_str_count
+=
hit_str_num
# fid = open(fid_path, 'a+')
try
:
local_precision
=
local_accumulative_precision
/
num_det
except
ZeroDivisionError
:
local_precision
=
0
try
:
local_recall
=
local_accumulative_recall
/
num_gt
except
ZeroDivisionError
:
local_recall
=
0
try
:
local_f_score
=
2
*
local_precision
*
local_recall
/
(
local_precision
+
local_recall
)
except
ZeroDivisionError
:
local_f_score
=
0
single_data
[
'sigma'
]
=
global_sigma
single_data
[
'global_tau'
]
=
global_tau
single_data
[
'global_pred_str'
]
=
global_pred_str
single_data
[
'global_gt_str'
]
=
global_gt_str
single_data
[
"recall"
]
=
local_recall
single_data
[
'precision'
]
=
local_precision
single_data
[
'f_score'
]
=
local_f_score
return
single_data
...
...
@@ -435,10 +163,10 @@ def combine_results(all_data):
global_pred_str
=
[]
global_gt_str
=
[]
for
data
in
all_data
:
global_sigma
.
append
(
data
[
'sigma'
]
[
0
]
)
global_tau
.
append
(
data
[
'global_tau'
]
[
0
]
)
global_pred_str
.
append
(
data
[
'global_pred_str'
]
[
0
]
)
global_gt_str
.
append
(
data
[
'global_gt_str'
]
[
0
]
)
global_sigma
.
append
(
data
[
'sigma'
])
global_tau
.
append
(
data
[
'global_tau'
])
global_pred_str
.
append
(
data
[
'global_pred_str'
])
global_gt_str
.
append
(
data
[
'global_gt_str'
])
global_accumulative_recall
=
0
global_accumulative_precision
=
0
...
...
@@ -676,6 +404,8 @@ def combine_results(all_data):
local_accumulative_recall
,
local_accumulative_precision
,
global_accumulative_recall
,
global_accumulative_precision
,
gt_flag
,
det_flag
,
idx
)
hit_str_count
+=
hit_str_num
try
:
recall
=
global_accumulative_recall
/
total_num_gt
except
ZeroDivisionError
:
...
...
ppocr/utils/e2e_utils/extract_textpoint.py
浏览文件 @
a0d1f923
此差异已折叠。
点击以展开。
tools/infer/predict_e2e.py
浏览文件 @
a0d1f923
...
...
@@ -151,7 +151,7 @@ if __name__ == "__main__":
src_im
=
utility
.
draw_e2e_res
(
points
,
strs
,
image_file
)
img_name_pure
=
os
.
path
.
split
(
image_file
)[
-
1
]
img_path
=
os
.
path
.
join
(
draw_img_save
,
"e2e_res_{}
_pgnet
"
.
format
(
img_name_pure
))
"e2e_res_{}"
.
format
(
img_name_pure
))
cv2
.
imwrite
(
img_path
,
src_im
)
logger
.
info
(
"The visualized image saved in {}"
.
format
(
img_path
))
if
count
>
1
:
...
...
train_data/total_text/train/poly/2.txt
已删除
100644 → 0
浏览文件 @
0cd48c35
2.0,165.0,20.0,167.0,39.0,170.0,57.0,173.0,76.0,176.0,94.0,179.0,113.0,182.0,109.0,218.0,90.0,215.0,72.0,213.0,54.0,210.0,36.0,208.0,18.0,205.0,0.0,203.0 izza
2.0,411.0,30.0,412.0,58.0,414.0,87.0,416.0,115.0,418.0,143.0,420.0,172.0,422.0,172.0,476.0,143.0,474.0,114.0,472.0,86.0,471.0,57.0,469.0,28.0,467.0,0.0,466.0 ISA
train_data/total_text/train/rgb/2.jpg
已删除
100644 → 0
浏览文件 @
0cd48c35
40.9 KB
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录