Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
a0d1f923
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a0d1f923
编写于
4月 09, 2021
作者:
J
Jethong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add different post process
上级
0cd48c35
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
516 addition
and
434 deletion
+516
-434
configs/e2e/e2e_r50_vd_pg.yml
configs/e2e/e2e_r50_vd_pg.yml
+2
-2
ppocr/data/imaug/label_ops.py
ppocr/data/imaug/label_ops.py
+9
-7
ppocr/data/pgnet_dataset.py
ppocr/data/pgnet_dataset.py
+12
-4
ppocr/metrics/e2e_metric.py
ppocr/metrics/e2e_metric.py
+12
-2
ppocr/postprocess/pg_postprocess.py
ppocr/postprocess/pg_postprocess.py
+93
-8
ppocr/utils/e2e_metric/Deteval.py
ppocr/utils/e2e_metric/Deteval.py
+16
-286
ppocr/utils/e2e_utils/extract_textpoint.py
ppocr/utils/e2e_utils/extract_textpoint.py
+371
-122
tools/infer/predict_e2e.py
tools/infer/predict_e2e.py
+1
-1
train_data/total_text/train/poly/2.txt
train_data/total_text/train/poly/2.txt
+0
-2
train_data/total_text/train/rgb/2.jpg
train_data/total_text/train/rgb/2.jpg
+0
-0
未找到文件。
configs/e2e/e2e_r50_vd_pg.yml
浏览文件 @
a0d1f923
...
...
@@ -11,7 +11,7 @@ Global:
# from static branch, load_static_weights must be set as True.
# 2. If you want to finetune the pretrained models we provide in the docs,
# you should set load_static_weights as False.
load_static_weights
:
Tru
e
load_static_weights
:
Fals
e
cal_metric_during_train
:
False
pretrained_model
:
checkpoints
:
...
...
@@ -94,7 +94,7 @@ Eval:
label_file_list
:
[
./train_data/total_text/test/
]
transforms
:
-
DecodeImage
:
# load image
img_mode
:
BGR
img_mode
:
RGB
channel_first
:
False
-
E2ELabelEncode
:
-
E2EResizeForTest
:
...
...
ppocr/data/imaug/label_ops.py
浏览文件 @
a0d1f923
...
...
@@ -200,16 +200,18 @@ class E2ELabelEncode(BaseRecLabelEncode):
self
.
pad_num
=
len
(
self
.
dict
)
# the length to pad
def
__call__
(
self
,
data
):
text_label_index_list
,
temp_text
=
[],
[]
texts
=
data
[
'strs'
]
temp_texts
=
[]
for
text
in
texts
:
text
=
text
.
lower
()
text
=
self
.
encode
(
text
)
if
text
is
None
:
return
None
text
=
text
+
[
self
.
pad_num
]
*
(
self
.
max_text_len
-
len
(
text
))
temp_texts
.
append
(
text
)
data
[
'strs'
]
=
np
.
array
(
temp_texts
)
temp_text
=
[]
for
c_
in
text
:
if
c_
in
self
.
dict
:
temp_text
.
append
(
self
.
dict
[
c_
])
temp_text
=
temp_text
+
[
self
.
pad_num
]
*
(
self
.
max_text_len
-
len
(
temp_text
))
text_label_index_list
.
append
(
temp_text
)
data
[
'strs'
]
=
np
.
array
(
text_label_index_list
)
return
data
...
...
ppocr/data/pgnet_dataset.py
浏览文件 @
a0d1f923
...
...
@@ -24,6 +24,7 @@ class PGDataSet(Dataset):
self
.
logger
=
logger
self
.
seed
=
seed
self
.
mode
=
mode
global_config
=
config
[
'Global'
]
dataset_config
=
config
[
mode
][
'dataset'
]
loader_config
=
config
[
mode
][
'loader'
]
...
...
@@ -62,10 +63,13 @@ class PGDataSet(Dataset):
with
open
(
poly_txt_path
)
as
f
:
for
line
in
f
.
readlines
():
poly_str
,
txt
=
line
.
strip
().
split
(
'
\t
'
)
poly
=
map
(
float
,
poly_str
.
split
(
','
))
poly
=
list
(
map
(
float
,
poly_str
.
split
(
','
)))
if
self
.
mode
.
lower
()
==
"eval"
:
while
len
(
poly
)
<
100
:
poly
.
append
(
-
1
)
text_polys
.
append
(
np
.
array
(
list
(
poly
)
,
dtype
=
np
.
float32
).
reshape
(
-
1
,
2
))
poly
,
dtype
=
np
.
float32
).
reshape
(
-
1
,
2
))
txts
.
append
(
txt
)
txt_tags
.
append
(
txt
==
'###'
)
...
...
@@ -135,8 +139,12 @@ class PGDataSet(Dataset):
try
:
if
self
.
data_format
==
'icdar'
:
im_path
=
os
.
path
.
join
(
data_path
,
'rgb'
,
data_line
)
poly_path
=
os
.
path
.
join
(
data_path
,
'poly'
,
data_line
.
split
(
'.'
)[
0
]
+
'.txt'
)
if
self
.
mode
.
lower
()
==
"eval"
:
poly_path
=
os
.
path
.
join
(
data_path
,
'poly_gt'
,
data_line
.
split
(
'.'
)[
0
]
+
'.txt'
)
else
:
poly_path
=
os
.
path
.
join
(
data_path
,
'poly'
,
data_line
.
split
(
'.'
)[
0
]
+
'.txt'
)
text_polys
,
text_tags
,
text_strs
=
self
.
extract_polys
(
poly_path
)
else
:
image_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
data_path
),
'image'
)
...
...
ppocr/metrics/e2e_metric.py
浏览文件 @
a0d1f923
...
...
@@ -33,10 +33,20 @@ class E2EMetric(object):
self
.
reset
()
def
__call__
(
self
,
preds
,
batch
,
**
kwargs
):
gt_polyons_batch
=
batch
[
2
]
temp_
gt_polyons_batch
=
batch
[
2
]
temp_gt_strs_batch
=
batch
[
3
]
ignore_tags_batch
=
batch
[
4
]
gt_polyons_batch
=
[]
gt_strs_batch
=
[]
temp_gt_polyons_batch
=
temp_gt_polyons_batch
[
0
].
tolist
()
for
temp_list
in
temp_gt_polyons_batch
:
t
=
[]
for
index
in
temp_list
:
if
index
[
0
]
!=
-
1
and
index
[
1
]
!=
-
1
:
t
.
append
(
index
)
gt_polyons_batch
.
append
(
t
)
temp_gt_strs_batch
=
temp_gt_strs_batch
[
0
].
tolist
()
for
temp_list
in
temp_gt_strs_batch
:
t
=
""
...
...
@@ -46,7 +56,7 @@ class E2EMetric(object):
gt_strs_batch
.
append
(
t
)
for
pred
,
gt_polyons
,
gt_strs
,
ignore_tags
in
zip
(
[
preds
],
gt_polyons_batch
,
[
gt_strs_batch
],
ignore_tags_batch
):
[
preds
],
[
gt_polyons_batch
]
,
[
gt_strs_batch
],
ignore_tags_batch
):
# prepare gt
gt_info_list
=
[{
'points'
:
gt_polyon
,
...
...
ppocr/postprocess/pg_postprocess.py
浏览文件 @
a0d1f923
...
...
@@ -23,7 +23,8 @@ __dir__ = os.path.dirname(__file__)
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
join
(
__dir__
,
'..'
))
from
ppocr.utils.e2e_utils.extract_textpoint
import
get_dict
,
generate_pivot_list
,
restore_poly
from
ppocr.utils.e2e_utils.extract_textpoint
import
*
from
ppocr.utils.e2e_utils.visual
import
*
import
paddle
...
...
@@ -37,6 +38,11 @@ class PGPostProcess(object):
self
.
valid_set
=
valid_set
self
.
score_thresh
=
score_thresh
# c++ la-nms is faster, but only support python 3.5
self
.
is_python35
=
False
if
sys
.
version_info
.
major
==
3
and
sys
.
version_info
.
minor
==
5
:
self
.
is_python35
=
True
def
__call__
(
self
,
outs_dict
,
shape_list
):
p_score
=
outs_dict
[
'f_score'
]
p_border
=
outs_dict
[
'f_border'
]
...
...
@@ -52,17 +58,96 @@ class PGPostProcess(object):
p_border
=
p_border
[
0
]
p_direction
=
p_direction
[
0
]
p_char
=
p_char
[
0
]
src_h
,
src_w
,
ratio_h
,
ratio_w
=
shape_list
[
0
]
instance_yxs_list
,
seq_strs
=
generate_pivot_list
(
is_curved
=
self
.
valid_set
==
"totaltext"
instance_yxs_list
=
generate_pivot_list
(
p_score
,
p_char
,
p_direction
,
self
.
Lexicon_Table
,
score_thresh
=
self
.
score_thresh
)
poly_list
,
keep_str_list
=
restore_poly
(
instance_yxs_list
,
seq_strs
,
p_border
,
ratio_w
,
ratio_h
,
src_w
,
src_h
,
self
.
valid_set
)
score_thresh
=
self
.
score_thresh
,
is_backbone
=
True
,
is_curved
=
is_curved
)
p_char
=
paddle
.
to_tensor
(
np
.
expand_dims
(
p_char
,
axis
=
0
))
char_seq_idx_set
=
[]
for
i
in
range
(
len
(
instance_yxs_list
)):
gather_info_lod
=
paddle
.
to_tensor
(
instance_yxs_list
[
i
])
f_char_map
=
paddle
.
transpose
(
p_char
,
[
0
,
2
,
3
,
1
])
feature_seq
=
paddle
.
gather_nd
(
f_char_map
,
gather_info_lod
)
feature_seq
=
np
.
expand_dims
(
feature_seq
.
numpy
(),
axis
=
0
)
feature_len
=
[
len
(
feature_seq
[
0
])]
featyre_seq
=
paddle
.
to_tensor
(
feature_seq
)
feature_len
=
np
.
array
([
feature_len
]).
astype
(
np
.
int64
)
length
=
paddle
.
to_tensor
(
feature_len
)
seq_pred
=
paddle
.
fluid
.
layers
.
ctc_greedy_decoder
(
input
=
featyre_seq
,
blank
=
36
,
input_length
=
length
)
seq_pred_str
=
seq_pred
[
0
].
numpy
().
tolist
()[
0
]
seq_len
=
seq_pred
[
1
].
numpy
()[
0
][
0
]
temp_t
=
[]
for
c
in
seq_pred_str
[:
seq_len
]:
temp_t
.
append
(
c
)
char_seq_idx_set
.
append
(
temp_t
)
seq_strs
=
[]
for
char_idx_set
in
char_seq_idx_set
:
pr_str
=
''
.
join
([
self
.
Lexicon_Table
[
pos
]
for
pos
in
char_idx_set
])
seq_strs
.
append
(
pr_str
)
poly_list
=
[]
keep_str_list
=
[]
all_point_list
=
[]
all_point_pair_list
=
[]
for
yx_center_line
,
keep_str
in
zip
(
instance_yxs_list
,
seq_strs
):
if
len
(
yx_center_line
)
==
1
:
yx_center_line
.
append
(
yx_center_line
[
-
1
])
offset_expand
=
1.0
if
self
.
valid_set
==
'totaltext'
:
offset_expand
=
1.2
point_pair_list
=
[]
for
batch_id
,
y
,
x
in
yx_center_line
:
offset
=
p_border
[:,
y
,
x
].
reshape
(
2
,
2
)
if
offset_expand
!=
1.0
:
offset_length
=
np
.
linalg
.
norm
(
offset
,
axis
=
1
,
keepdims
=
True
)
expand_length
=
np
.
clip
(
offset_length
*
(
offset_expand
-
1
),
a_min
=
0.5
,
a_max
=
3.0
)
offset_detal
=
offset
/
offset_length
*
expand_length
offset
=
offset
+
offset_detal
ori_yx
=
np
.
array
([
y
,
x
],
dtype
=
np
.
float32
)
point_pair
=
(
ori_yx
+
offset
)[:,
::
-
1
]
*
4.0
/
np
.
array
(
[
ratio_w
,
ratio_h
]).
reshape
(
-
1
,
2
)
point_pair_list
.
append
(
point_pair
)
all_point_list
.
append
([
int
(
round
(
x
*
4.0
/
ratio_w
)),
int
(
round
(
y
*
4.0
/
ratio_h
))
])
all_point_pair_list
.
append
(
point_pair
.
round
().
astype
(
np
.
int32
)
.
tolist
())
detected_poly
,
pair_length_info
=
point_pair2poly
(
point_pair_list
)
detected_poly
=
expand_poly_along_width
(
detected_poly
,
shrink_ratio_of_width
=
0.2
)
detected_poly
[:,
0
]
=
np
.
clip
(
detected_poly
[:,
0
],
a_min
=
0
,
a_max
=
src_w
)
detected_poly
[:,
1
]
=
np
.
clip
(
detected_poly
[:,
1
],
a_min
=
0
,
a_max
=
src_h
)
if
len
(
keep_str
)
<
2
:
continue
keep_str_list
.
append
(
keep_str
)
if
self
.
valid_set
==
'partvgg'
:
middle_point
=
len
(
detected_poly
)
//
2
detected_poly
=
detected_poly
[
[
0
,
middle_point
-
1
,
middle_point
,
-
1
],
:]
poly_list
.
append
(
detected_poly
)
elif
self
.
valid_set
==
'totaltext'
:
poly_list
.
append
(
detected_poly
)
else
:
print
(
'--> Not supported format.'
)
exit
(
-
1
)
data
=
{
'points'
:
poly_list
,
'strs'
:
keep_str_list
,
...
...
ppocr/utils/e2e_metric/Deteval.py
浏览文件 @
a0d1f923
...
...
@@ -35,7 +35,7 @@ def get_socre(gt_dict, pred_dict):
gt
=
[]
n
=
len
(
gt_dict
)
for
i
in
range
(
n
):
points
=
gt_dict
[
i
][
'points'
]
.
tolist
()
points
=
gt_dict
[
i
][
'points'
]
h
=
len
(
points
)
text
=
gt_dict
[
i
][
'text'
]
xx
=
[
...
...
@@ -51,7 +51,7 @@ def get_socre(gt_dict, pred_dict):
t_y
.
append
(
points
[
j
][
1
])
xx
[
1
]
=
np
.
array
([
t_x
],
dtype
=
'int16'
)
xx
[
3
]
=
np
.
array
([
t_y
],
dtype
=
'int16'
)
if
text
!=
""
:
if
text
!=
""
and
"#"
not
in
text
:
xx
[
4
]
=
np
.
array
([
text
],
dtype
=
'U{}'
.
format
(
len
(
text
)))
xx
[
5
]
=
np
.
array
([
'c'
],
dtype
=
'<U1'
)
gt
.
append
(
xx
)
...
...
@@ -89,17 +89,10 @@ def get_socre(gt_dict, pred_dict):
area
(
det_x
,
det_y
)),
2
)
##############################Initialization###################################
global_tp
=
0
global_fp
=
0
global_fn
=
0
global_sigma
=
[]
global_tau
=
[]
tr
=
0.7
tp
=
0.6
fsc_k
=
0.8
k
=
2
global_pred_str
=
[]
global_gt_str
=
[]
# global_sigma = []
# global_tau = []
# global_pred_str = []
# global_gt_str = []
###############################################################################
for
input_id
in
range
(
allInputs
):
...
...
@@ -147,281 +140,16 @@ def get_socre(gt_dict, pred_dict):
local_pred_str
[
det_id
]
=
pred_seq_str
local_gt_str
[
gt_id
]
=
gt_seq_str
global_sigma
.
append
(
local_sigma_table
)
global_tau
.
append
(
local_tau_table
)
global_pred_str
.
append
(
local_pred_str
)
global_gt_str
.
append
(
local_gt_str
)
global_accumulative_recall
=
0
global_accumulative_precision
=
0
total_num_gt
=
0
total_num_det
=
0
hit_str_count
=
0
hit_count
=
0
def
one_to_one
(
local_sigma_table
,
local_tau_table
,
local_accumulative_recall
,
local_accumulative_precision
,
global_accumulative_recall
,
global_accumulative_precision
,
gt_flag
,
det_flag
,
idy
):
hit_str_num
=
0
for
gt_id
in
range
(
num_gt
):
gt_matching_qualified_sigma_candidates
=
np
.
where
(
local_sigma_table
[
gt_id
,
:]
>
tr
)
gt_matching_num_qualified_sigma_candidates
=
gt_matching_qualified_sigma_candidates
[
0
].
shape
[
0
]
gt_matching_qualified_tau_candidates
=
np
.
where
(
local_tau_table
[
gt_id
,
:]
>
tp
)
gt_matching_num_qualified_tau_candidates
=
gt_matching_qualified_tau_candidates
[
0
].
shape
[
0
]
det_matching_qualified_sigma_candidates
=
np
.
where
(
local_sigma_table
[:,
gt_matching_qualified_sigma_candidates
[
0
]]
>
tr
)
det_matching_num_qualified_sigma_candidates
=
det_matching_qualified_sigma_candidates
[
0
].
shape
[
0
]
det_matching_qualified_tau_candidates
=
np
.
where
(
local_tau_table
[:,
gt_matching_qualified_tau_candidates
[
0
]]
>
tp
)
det_matching_num_qualified_tau_candidates
=
det_matching_qualified_tau_candidates
[
0
].
shape
[
0
]
if
(
gt_matching_num_qualified_sigma_candidates
==
1
)
and
(
gt_matching_num_qualified_tau_candidates
==
1
)
and
\
(
det_matching_num_qualified_sigma_candidates
==
1
)
and
(
det_matching_num_qualified_tau_candidates
==
1
):
global_accumulative_recall
=
global_accumulative_recall
+
1.0
global_accumulative_precision
=
global_accumulative_precision
+
1.0
local_accumulative_recall
=
local_accumulative_recall
+
1.0
local_accumulative_precision
=
local_accumulative_precision
+
1.0
gt_flag
[
0
,
gt_id
]
=
1
matched_det_id
=
np
.
where
(
local_sigma_table
[
gt_id
,
:]
>
tr
)
# recg start
gt_str_cur
=
global_gt_str
[
idy
][
gt_id
]
pred_str_cur
=
global_pred_str
[
idy
][
matched_det_id
[
0
].
tolist
()[
0
]]
if
pred_str_cur
==
gt_str_cur
:
hit_str_num
+=
1
else
:
if
pred_str_cur
.
lower
()
==
gt_str_cur
.
lower
():
hit_str_num
+=
1
# recg end
det_flag
[
0
,
matched_det_id
]
=
1
return
local_accumulative_recall
,
local_accumulative_precision
,
global_accumulative_recall
,
global_accumulative_precision
,
gt_flag
,
det_flag
,
hit_str_num
def
one_to_many
(
local_sigma_table
,
local_tau_table
,
local_accumulative_recall
,
local_accumulative_precision
,
global_accumulative_recall
,
global_accumulative_precision
,
gt_flag
,
det_flag
,
idy
):
hit_str_num
=
0
for
gt_id
in
range
(
num_gt
):
# skip the following if the groundtruth was matched
if
gt_flag
[
0
,
gt_id
]
>
0
:
continue
non_zero_in_sigma
=
np
.
where
(
local_sigma_table
[
gt_id
,
:]
>
0
)
num_non_zero_in_sigma
=
non_zero_in_sigma
[
0
].
shape
[
0
]
if
num_non_zero_in_sigma
>=
k
:
####search for all detections that overlaps with this groundtruth
qualified_tau_candidates
=
np
.
where
((
local_tau_table
[
gt_id
,
:]
>=
tp
)
&
(
det_flag
[
0
,
:]
==
0
))
num_qualified_tau_candidates
=
qualified_tau_candidates
[
0
].
shape
[
0
]
if
num_qualified_tau_candidates
==
1
:
if
((
local_tau_table
[
gt_id
,
qualified_tau_candidates
]
>=
tp
)
and
(
local_sigma_table
[
gt_id
,
qualified_tau_candidates
]
>=
tr
)):
# became an one-to-one case
global_accumulative_recall
=
global_accumulative_recall
+
1.0
global_accumulative_precision
=
global_accumulative_precision
+
1.0
local_accumulative_recall
=
local_accumulative_recall
+
1.0
local_accumulative_precision
=
local_accumulative_precision
+
1.0
gt_flag
[
0
,
gt_id
]
=
1
det_flag
[
0
,
qualified_tau_candidates
]
=
1
# recg start
gt_str_cur
=
global_gt_str
[
idy
][
gt_id
]
pred_str_cur
=
global_pred_str
[
idy
][
qualified_tau_candidates
[
0
].
tolist
()[
0
]]
if
pred_str_cur
==
gt_str_cur
:
hit_str_num
+=
1
else
:
if
pred_str_cur
.
lower
()
==
gt_str_cur
.
lower
():
hit_str_num
+=
1
# recg end
elif
(
np
.
sum
(
local_sigma_table
[
gt_id
,
qualified_tau_candidates
])
>=
tr
):
gt_flag
[
0
,
gt_id
]
=
1
det_flag
[
0
,
qualified_tau_candidates
]
=
1
# recg start
gt_str_cur
=
global_gt_str
[
idy
][
gt_id
]
pred_str_cur
=
global_pred_str
[
idy
][
qualified_tau_candidates
[
0
].
tolist
()[
0
]]
if
pred_str_cur
==
gt_str_cur
:
hit_str_num
+=
1
else
:
if
pred_str_cur
.
lower
()
==
gt_str_cur
.
lower
():
hit_str_num
+=
1
# recg end
global_accumulative_recall
=
global_accumulative_recall
+
fsc_k
global_accumulative_precision
=
global_accumulative_precision
+
num_qualified_tau_candidates
*
fsc_k
local_accumulative_recall
=
local_accumulative_recall
+
fsc_k
local_accumulative_precision
=
local_accumulative_precision
+
num_qualified_tau_candidates
*
fsc_k
return
local_accumulative_recall
,
local_accumulative_precision
,
global_accumulative_recall
,
global_accumulative_precision
,
gt_flag
,
det_flag
,
hit_str_num
def
many_to_one
(
local_sigma_table
,
local_tau_table
,
local_accumulative_recall
,
local_accumulative_precision
,
global_accumulative_recall
,
global_accumulative_precision
,
gt_flag
,
det_flag
,
idy
):
hit_str_num
=
0
for
det_id
in
range
(
num_det
):
# skip the following if the detection was matched
if
det_flag
[
0
,
det_id
]
>
0
:
continue
non_zero_in_tau
=
np
.
where
(
local_tau_table
[:,
det_id
]
>
0
)
num_non_zero_in_tau
=
non_zero_in_tau
[
0
].
shape
[
0
]
if
num_non_zero_in_tau
>=
k
:
####search for all detections that overlaps with this groundtruth
qualified_sigma_candidates
=
np
.
where
((
local_sigma_table
[:,
det_id
]
>=
tp
)
&
(
gt_flag
[
0
,
:]
==
0
))
num_qualified_sigma_candidates
=
qualified_sigma_candidates
[
0
].
shape
[
0
]
if
num_qualified_sigma_candidates
==
1
:
if
((
local_tau_table
[
qualified_sigma_candidates
,
det_id
]
>=
tp
)
and
(
local_sigma_table
[
qualified_sigma_candidates
,
det_id
]
>=
tr
)):
# became an one-to-one case
global_accumulative_recall
=
global_accumulative_recall
+
1.0
global_accumulative_precision
=
global_accumulative_precision
+
1.0
local_accumulative_recall
=
local_accumulative_recall
+
1.0
local_accumulative_precision
=
local_accumulative_precision
+
1.0
gt_flag
[
0
,
qualified_sigma_candidates
]
=
1
det_flag
[
0
,
det_id
]
=
1
# recg start
pred_str_cur
=
global_pred_str
[
idy
][
det_id
]
gt_len
=
len
(
qualified_sigma_candidates
[
0
])
for
idx
in
range
(
gt_len
):
ele_gt_id
=
qualified_sigma_candidates
[
0
].
tolist
()[
idx
]
if
ele_gt_id
not
in
global_gt_str
[
idy
]:
continue
gt_str_cur
=
global_gt_str
[
idy
][
ele_gt_id
]
if
pred_str_cur
==
gt_str_cur
:
hit_str_num
+=
1
break
else
:
if
pred_str_cur
.
lower
()
==
gt_str_cur
.
lower
():
hit_str_num
+=
1
break
# recg end
elif
(
np
.
sum
(
local_tau_table
[
qualified_sigma_candidates
,
det_id
])
>=
tp
):
det_flag
[
0
,
det_id
]
=
1
gt_flag
[
0
,
qualified_sigma_candidates
]
=
1
# recg start
pred_str_cur
=
global_pred_str
[
idy
][
det_id
]
gt_len
=
len
(
qualified_sigma_candidates
[
0
])
for
idx
in
range
(
gt_len
):
ele_gt_id
=
qualified_sigma_candidates
[
0
].
tolist
()[
idx
]
if
ele_gt_id
not
in
global_gt_str
[
idy
]:
continue
gt_str_cur
=
global_gt_str
[
idy
][
ele_gt_id
]
if
pred_str_cur
==
gt_str_cur
:
hit_str_num
+=
1
break
else
:
if
pred_str_cur
.
lower
()
==
gt_str_cur
.
lower
():
hit_str_num
+=
1
break
# recg end
global_accumulative_recall
=
global_accumulative_recall
+
num_qualified_sigma_candidates
*
fsc_k
global_accumulative_precision
=
global_accumulative_precision
+
fsc_k
local_accumulative_recall
=
local_accumulative_recall
+
num_qualified_sigma_candidates
*
fsc_k
local_accumulative_precision
=
local_accumulative_precision
+
fsc_k
return
local_accumulative_recall
,
local_accumulative_precision
,
global_accumulative_recall
,
global_accumulative_precision
,
gt_flag
,
det_flag
,
hit_str_num
global_sigma
=
local_sigma_table
global_tau
=
local_tau_table
global_pred_str
=
local_pred_str
global_gt_str
=
local_gt_str
single_data
=
{}
for
idx
in
range
(
len
(
global_sigma
)):
local_sigma_table
=
global_sigma
[
idx
]
local_tau_table
=
global_tau
[
idx
]
num_gt
=
local_sigma_table
.
shape
[
0
]
num_det
=
local_sigma_table
.
shape
[
1
]
total_num_gt
=
total_num_gt
+
num_gt
total_num_det
=
total_num_det
+
num_det
local_accumulative_recall
=
0
local_accumulative_precision
=
0
gt_flag
=
np
.
zeros
((
1
,
num_gt
))
det_flag
=
np
.
zeros
((
1
,
num_det
))
#######first check for one-to-one case##########
local_accumulative_recall
,
local_accumulative_precision
,
global_accumulative_recall
,
global_accumulative_precision
,
\
gt_flag
,
det_flag
,
hit_str_num
=
one_to_one
(
local_sigma_table
,
local_tau_table
,
local_accumulative_recall
,
local_accumulative_precision
,
global_accumulative_recall
,
global_accumulative_precision
,
gt_flag
,
det_flag
,
idx
)
hit_str_count
+=
hit_str_num
#######then check for one-to-many case##########
local_accumulative_recall
,
local_accumulative_precision
,
global_accumulative_recall
,
global_accumulative_precision
,
\
gt_flag
,
det_flag
,
hit_str_num
=
one_to_many
(
local_sigma_table
,
local_tau_table
,
local_accumulative_recall
,
local_accumulative_precision
,
global_accumulative_recall
,
global_accumulative_precision
,
gt_flag
,
det_flag
,
idx
)
hit_str_count
+=
hit_str_num
#######then check for many-to-one case##########
local_accumulative_recall
,
local_accumulative_precision
,
global_accumulative_recall
,
global_accumulative_precision
,
\
gt_flag
,
det_flag
,
hit_str_num
=
many_to_one
(
local_sigma_table
,
local_tau_table
,
local_accumulative_recall
,
local_accumulative_precision
,
global_accumulative_recall
,
global_accumulative_precision
,
gt_flag
,
det_flag
,
idx
)
hit_str_count
+=
hit_str_num
# fid = open(fid_path, 'a+')
try
:
local_precision
=
local_accumulative_precision
/
num_det
except
ZeroDivisionError
:
local_precision
=
0
try
:
local_recall
=
local_accumulative_recall
/
num_gt
except
ZeroDivisionError
:
local_recall
=
0
try
:
local_f_score
=
2
*
local_precision
*
local_recall
/
(
local_precision
+
local_recall
)
except
ZeroDivisionError
:
local_f_score
=
0
single_data
[
'sigma'
]
=
global_sigma
single_data
[
'global_tau'
]
=
global_tau
single_data
[
'global_pred_str'
]
=
global_pred_str
single_data
[
'global_gt_str'
]
=
global_gt_str
single_data
[
"recall"
]
=
local_recall
single_data
[
'precision'
]
=
local_precision
single_data
[
'f_score'
]
=
local_f_score
return
single_data
...
...
@@ -435,10 +163,10 @@ def combine_results(all_data):
global_pred_str
=
[]
global_gt_str
=
[]
for
data
in
all_data
:
global_sigma
.
append
(
data
[
'sigma'
]
[
0
]
)
global_tau
.
append
(
data
[
'global_tau'
]
[
0
]
)
global_pred_str
.
append
(
data
[
'global_pred_str'
]
[
0
]
)
global_gt_str
.
append
(
data
[
'global_gt_str'
]
[
0
]
)
global_sigma
.
append
(
data
[
'sigma'
])
global_tau
.
append
(
data
[
'global_tau'
])
global_pred_str
.
append
(
data
[
'global_pred_str'
])
global_gt_str
.
append
(
data
[
'global_gt_str'
])
global_accumulative_recall
=
0
global_accumulative_precision
=
0
...
...
@@ -676,6 +404,8 @@ def combine_results(all_data):
local_accumulative_recall
,
local_accumulative_precision
,
global_accumulative_recall
,
global_accumulative_precision
,
gt_flag
,
det_flag
,
idx
)
hit_str_count
+=
hit_str_num
try
:
recall
=
global_accumulative_recall
/
total_num_gt
except
ZeroDivisionError
:
...
...
ppocr/utils/e2e_utils/extract_textpoint.py
浏览文件 @
a0d1f923
...
...
@@ -17,9 +17,11 @@ from __future__ import division
from
__future__
import
print_function
import
cv2
import
math
import
numpy
as
np
from
itertools
import
groupby
from
cv2.ximgproc
import
thinning
as
thin
from
skimage.morphology._skeletonize
import
thin
def
get_dict
(
character_dict_path
):
...
...
@@ -33,39 +35,87 @@ def get_dict(character_dict_path):
return
dict_character
def
instance_ctc_greedy_decoder
(
gather_info
,
logits_map
,
pts_num
=
4
):
def
softmax
(
logits
):
"""
logits: N x d
"""
max_value
=
np
.
max
(
logits
,
axis
=
1
,
keepdims
=
True
)
exp
=
np
.
exp
(
logits
-
max_value
)
exp_sum
=
np
.
sum
(
exp
,
axis
=
1
,
keepdims
=
True
)
dist
=
exp
/
exp_sum
return
dist
def
get_keep_pos_idxs
(
labels
,
remove_blank
=
None
):
"""
Remove duplicate and get pos idxs of keep items.
The value of keep_blank should be [None, 95].
"""
duplicate_len_list
=
[]
keep_pos_idx_list
=
[]
keep_char_idx_list
=
[]
for
k
,
v_
in
groupby
(
labels
):
current_len
=
len
(
list
(
v_
))
if
k
!=
remove_blank
:
current_idx
=
int
(
sum
(
duplicate_len_list
)
+
current_len
//
2
)
keep_pos_idx_list
.
append
(
current_idx
)
keep_char_idx_list
.
append
(
k
)
duplicate_len_list
.
append
(
current_len
)
return
keep_char_idx_list
,
keep_pos_idx_list
def
remove_blank
(
labels
,
blank
=
0
):
new_labels
=
[
x
for
x
in
labels
if
x
!=
blank
]
return
new_labels
def
insert_blank
(
labels
,
blank
=
0
):
new_labels
=
[
blank
]
for
l
in
labels
:
new_labels
+=
[
l
,
blank
]
return
new_labels
def
ctc_greedy_decoder
(
probs_seq
,
blank
=
95
,
keep_blank_in_idxs
=
True
):
"""
CTC greedy (best path) decoder.
"""
raw_str
=
np
.
argmax
(
np
.
array
(
probs_seq
),
axis
=
1
)
remove_blank_in_pos
=
None
if
keep_blank_in_idxs
else
blank
dedup_str
,
keep_idx_list
=
get_keep_pos_idxs
(
raw_str
,
remove_blank
=
remove_blank_in_pos
)
dst_str
=
remove_blank
(
dedup_str
,
blank
=
blank
)
return
dst_str
,
keep_idx_list
def
instance_ctc_greedy_decoder
(
gather_info
,
logits_map
,
keep_blank_in_idxs
=
True
):
"""
gather_info: [[x, y], [x, y] ...]
logits_map: H x W X (n_chars + 1)
"""
_
,
_
,
C
=
logits_map
.
shape
ys
,
xs
=
zip
(
*
gather_info
)
logits_seq
=
logits_map
[
list
(
ys
),
list
(
xs
)]
probs_seq
=
logits_seq
labels
=
np
.
argmax
(
probs_seq
,
axis
=
1
)
dst_str
=
[
k
for
k
,
v_
in
groupby
(
labels
)
if
k
!=
C
-
1
]
detal
=
len
(
gather_info
)
//
(
pts_num
-
1
)
keep_idx_list
=
[
0
]
+
[
detal
*
(
i
+
1
)
for
i
in
range
(
pts_num
-
2
)]
+
[
-
1
]
logits_seq
=
logits_map
[
list
(
ys
),
list
(
xs
)]
# n x 96
probs_seq
=
softmax
(
logits_seq
)
dst_str
,
keep_idx_list
=
ctc_greedy_decoder
(
probs_seq
,
blank
=
C
-
1
,
keep_blank_in_idxs
=
keep_blank_in_idxs
)
keep_gather_list
=
[
gather_info
[
idx
]
for
idx
in
keep_idx_list
]
return
dst_str
,
keep_gather_list
def
ctc_decoder_for_image
(
gather_info_list
,
logits_map
,
Lexicon_Table
,
pts_num
=
6
):
def
ctc_decoder_for_image
(
gather_info_list
,
logits_map
,
keep_blank_in_idxs
=
True
):
"""
CTC decoder using multiple processes.
"""
decoder_str
=
[]
decoder_xys
=
[]
decoder_results
=
[]
for
gather_info
in
gather_info_list
:
if
len
(
gather_info
)
<
pts_num
:
continue
dst_str
,
xys_list
=
instance_ctc_greedy_decoder
(
gather_info
,
logits_map
,
pts_num
=
pts_num
)
dst_str_readable
=
''
.
join
([
Lexicon_Table
[
idx
]
for
idx
in
dst_str
])
if
len
(
dst_str_readable
)
<
2
:
continue
decoder_str
.
append
(
dst_str_readable
)
decoder_xys
.
append
(
xys_list
)
return
decoder_str
,
decoder_xys
res
=
instance_ctc_greedy_decoder
(
gather_info
,
logits_map
,
keep_blank_in_idxs
=
keep_blank_in_idxs
)
decoder_results
.
append
(
res
)
return
decoder_results
def
sort_with_direction
(
pos_list
,
f_direction
):
...
...
@@ -107,6 +157,58 @@ def sort_with_direction(pos_list, f_direction):
return
sorted_point
,
np
.
array
(
sorted_direction
)
def
add_id
(
pos_list
,
image_id
=
0
):
"""
Add id for gather feature, for inference.
"""
new_list
=
[]
for
item
in
pos_list
:
new_list
.
append
((
image_id
,
item
[
0
],
item
[
1
]))
return
new_list
def
sort_and_expand_with_direction
(
pos_list
,
f_direction
):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
h
,
w
,
_
=
f_direction
.
shape
sorted_list
,
point_direction
=
sort_with_direction
(
pos_list
,
f_direction
)
# expand along
point_num
=
len
(
sorted_list
)
sub_direction_len
=
max
(
point_num
//
3
,
2
)
left_direction
=
point_direction
[:
sub_direction_len
,
:]
right_dirction
=
point_direction
[
point_num
-
sub_direction_len
:,
:]
left_average_direction
=
-
np
.
mean
(
left_direction
,
axis
=
0
,
keepdims
=
True
)
left_average_len
=
np
.
linalg
.
norm
(
left_average_direction
)
left_start
=
np
.
array
(
sorted_list
[
0
])
left_step
=
left_average_direction
/
(
left_average_len
+
1e-6
)
right_average_direction
=
np
.
mean
(
right_dirction
,
axis
=
0
,
keepdims
=
True
)
right_average_len
=
np
.
linalg
.
norm
(
right_average_direction
)
right_step
=
right_average_direction
/
(
right_average_len
+
1e-6
)
right_start
=
np
.
array
(
sorted_list
[
-
1
])
append_num
=
max
(
int
((
left_average_len
+
right_average_len
)
/
2.0
*
0.15
),
1
)
left_list
=
[]
right_list
=
[]
for
i
in
range
(
append_num
):
ly
,
lx
=
np
.
round
(
left_start
+
left_step
*
(
i
+
1
)).
flatten
().
astype
(
'int32'
).
tolist
()
if
ly
<
h
and
lx
<
w
and
(
ly
,
lx
)
not
in
left_list
:
left_list
.
append
((
ly
,
lx
))
ry
,
rx
=
np
.
round
(
right_start
+
right_step
*
(
i
+
1
)).
flatten
().
astype
(
'int32'
).
tolist
()
if
ry
<
h
and
rx
<
w
and
(
ry
,
rx
)
not
in
right_list
:
right_list
.
append
((
ry
,
rx
))
all_list
=
left_list
[::
-
1
]
+
sorted_list
+
right_list
return
all_list
def
sort_and_expand_with_direction_v2
(
pos_list
,
f_direction
,
binary_tcl_map
):
"""
f_direction: h x w x 2
...
...
@@ -116,6 +218,7 @@ def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
h
,
w
,
_
=
f_direction
.
shape
sorted_list
,
point_direction
=
sort_with_direction
(
pos_list
,
f_direction
)
# expand along
point_num
=
len
(
sorted_list
)
sub_direction_len
=
max
(
point_num
//
3
,
2
)
left_direction
=
point_direction
[:
sub_direction_len
,
:]
...
...
@@ -159,108 +262,258 @@ def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
return
all_list
def
point_pair2poly
(
point_pair_list
):
"""
Transfer vertical point_pairs into poly point in clockwise.
"""
point_num
=
len
(
point_pair_list
)
*
2
point_list
=
[
0
]
*
point_num
for
idx
,
point_pair
in
enumerate
(
point_pair_list
):
point_list
[
idx
]
=
point_pair
[
0
]
point_list
[
point_num
-
1
-
idx
]
=
point_pair
[
1
]
return
np
.
array
(
point_list
).
reshape
(
-
1
,
2
)
def
shrink_quad_along_width
(
quad
,
begin_width_ratio
=
0.
,
end_width_ratio
=
1.
):
ratio_pair
=
np
.
array
(
[[
begin_width_ratio
],
[
end_width_ratio
]],
dtype
=
np
.
float32
)
p0_1
=
quad
[
0
]
+
(
quad
[
1
]
-
quad
[
0
])
*
ratio_pair
p3_2
=
quad
[
3
]
+
(
quad
[
2
]
-
quad
[
3
])
*
ratio_pair
return
np
.
array
([
p0_1
[
0
],
p0_1
[
1
],
p3_2
[
1
],
p3_2
[
0
]])
def
expand_poly_along_width
(
poly
,
shrink_ratio_of_width
=
0.3
):
"""
expand poly along width.
"""
point_num
=
poly
.
shape
[
0
]
left_quad
=
np
.
array
(
[
poly
[
0
],
poly
[
1
],
poly
[
-
2
],
poly
[
-
1
]],
dtype
=
np
.
float32
)
left_ratio
=
-
shrink_ratio_of_width
*
np
.
linalg
.
norm
(
left_quad
[
0
]
-
left_quad
[
3
])
/
\
(
np
.
linalg
.
norm
(
left_quad
[
0
]
-
left_quad
[
1
])
+
1e-6
)
left_quad_expand
=
shrink_quad_along_width
(
left_quad
,
left_ratio
,
1.0
)
right_quad
=
np
.
array
(
[
poly
[
point_num
//
2
-
2
],
poly
[
point_num
//
2
-
1
],
poly
[
point_num
//
2
],
poly
[
point_num
//
2
+
1
]
],
dtype
=
np
.
float32
)
right_ratio
=
1.0
+
shrink_ratio_of_width
*
np
.
linalg
.
norm
(
right_quad
[
0
]
-
right_quad
[
3
])
/
\
(
np
.
linalg
.
norm
(
right_quad
[
0
]
-
right_quad
[
1
])
+
1e-6
)
right_quad_expand
=
shrink_quad_along_width
(
right_quad
,
0.0
,
right_ratio
)
poly
[
0
]
=
left_quad_expand
[
0
]
poly
[
-
1
]
=
left_quad_expand
[
-
1
]
poly
[
point_num
//
2
-
1
]
=
right_quad_expand
[
1
]
poly
[
point_num
//
2
]
=
right_quad_expand
[
2
]
return
poly
def
restore_poly
(
instance_yxs_list
,
seq_strs
,
p_border
,
ratio_w
,
ratio_h
,
src_w
,
src_h
,
valid_set
):
poly_list
=
[]
keep_str_list
=
[]
for
yx_center_line
,
keep_str
in
zip
(
instance_yxs_list
,
seq_strs
):
if
len
(
keep_str
)
<
2
:
print
(
'--> too short, {}'
.
format
(
keep_str
))
continue
offset_expand
=
1.0
if
valid_set
==
'totaltext'
:
offset_expand
=
1.2
point_pair_list
=
[]
for
y
,
x
in
yx_center_line
:
offset
=
p_border
[:,
y
,
x
].
reshape
(
2
,
2
)
*
offset_expand
ori_yx
=
np
.
array
([
y
,
x
],
dtype
=
np
.
float32
)
point_pair
=
(
ori_yx
+
offset
)[:,
::
-
1
]
*
4.0
/
np
.
array
(
[
ratio_w
,
ratio_h
]).
reshape
(
-
1
,
2
)
point_pair_list
.
append
(
point_pair
)
detected_poly
=
point_pair2poly
(
point_pair_list
)
detected_poly
=
expand_poly_along_width
(
detected_poly
,
shrink_ratio_of_width
=
0.2
)
detected_poly
[:,
0
]
=
np
.
clip
(
detected_poly
[:,
0
],
a_min
=
0
,
a_max
=
src_w
)
detected_poly
[:,
1
]
=
np
.
clip
(
detected_poly
[:,
1
],
a_min
=
0
,
a_max
=
src_h
)
keep_str_list
.
append
(
keep_str
)
if
valid_set
==
'partvgg'
:
middle_point
=
len
(
detected_poly
)
//
2
detected_poly
=
detected_poly
[
[
0
,
middle_point
-
1
,
middle_point
,
-
1
],
:]
poly_list
.
append
(
detected_poly
)
elif
valid_set
==
'totaltext'
:
poly_list
.
append
(
detected_poly
)
def
generate_pivot_list_curved
(
p_score
,
p_char_maps
,
f_direction
,
score_thresh
=
0.5
,
is_expand
=
True
,
is_backbone
=
False
,
image_id
=
0
):
"""
return center point and end point of TCL instance; filter with the char maps;
"""
p_score
=
p_score
[
0
]
f_direction
=
f_direction
.
transpose
(
1
,
2
,
0
)
p_tcl_map
=
(
p_score
>
score_thresh
)
*
1.0
skeleton_map
=
thin
(
p_tcl_map
)
instance_count
,
instance_label_map
=
cv2
.
connectedComponents
(
skeleton_map
.
astype
(
np
.
uint8
),
connectivity
=
8
)
# get TCL Instance
all_pos_yxs
=
[]
center_pos_yxs
=
[]
end_points_yxs
=
[]
instance_center_pos_yxs
=
[]
if
instance_count
>
0
:
for
instance_id
in
range
(
1
,
instance_count
):
pos_list
=
[]
ys
,
xs
=
np
.
where
(
instance_label_map
==
instance_id
)
pos_list
=
list
(
zip
(
ys
,
xs
))
### FIX-ME, eliminate outlier
if
len
(
pos_list
)
<
3
:
continue
if
is_expand
:
pos_list_sorted
=
sort_and_expand_with_direction_v2
(
pos_list
,
f_direction
,
p_tcl_map
)
else
:
pos_list_sorted
,
_
=
sort_with_direction
(
pos_list
,
f_direction
)
all_pos_yxs
.
append
(
pos_list_sorted
)
# use decoder to filter backgroud points.
p_char_maps
=
p_char_maps
.
transpose
([
1
,
2
,
0
])
decode_res
=
ctc_decoder_for_image
(
all_pos_yxs
,
logits_map
=
p_char_maps
,
keep_blank_in_idxs
=
True
)
for
decoded_str
,
keep_yxs_list
in
decode_res
:
if
is_backbone
:
keep_yxs_list_with_id
=
add_id
(
keep_yxs_list
,
image_id
=
image_id
)
instance_center_pos_yxs
.
append
(
keep_yxs_list_with_id
)
else
:
end_points_yxs
.
extend
((
keep_yxs_list
[
0
],
keep_yxs_list
[
-
1
]))
center_pos_yxs
.
extend
(
keep_yxs_list
)
if
is_backbone
:
return
instance_center_pos_yxs
else
:
return
center_pos_yxs
,
end_points_yxs
def
generate_pivot_list_horizontal
(
p_score
,
p_char_maps
,
f_direction
,
score_thresh
=
0.5
,
is_backbone
=
False
,
image_id
=
0
):
"""
return center point and end point of TCL instance; filter with the char maps;
"""
p_score
=
p_score
[
0
]
f_direction
=
f_direction
.
transpose
(
1
,
2
,
0
)
p_tcl_map_bi
=
(
p_score
>
score_thresh
)
*
1.0
instance_count
,
instance_label_map
=
cv2
.
connectedComponents
(
p_tcl_map_bi
.
astype
(
np
.
uint8
),
connectivity
=
8
)
# get TCL Instance
all_pos_yxs
=
[]
center_pos_yxs
=
[]
end_points_yxs
=
[]
instance_center_pos_yxs
=
[]
if
instance_count
>
0
:
for
instance_id
in
range
(
1
,
instance_count
):
pos_list
=
[]
ys
,
xs
=
np
.
where
(
instance_label_map
==
instance_id
)
pos_list
=
list
(
zip
(
ys
,
xs
))
### FIX-ME, eliminate outlier
if
len
(
pos_list
)
<
5
:
continue
# add rule here
main_direction
=
extract_main_direction
(
pos_list
,
f_direction
)
# y x
reference_directin
=
np
.
array
([
0
,
1
]).
reshape
([
-
1
,
2
])
# y x
is_h_angle
=
abs
(
np
.
sum
(
main_direction
*
reference_directin
))
<
math
.
cos
(
math
.
pi
/
180
*
70
)
point_yxs
=
np
.
array
(
pos_list
)
max_y
,
max_x
=
np
.
max
(
point_yxs
,
axis
=
0
)
min_y
,
min_x
=
np
.
min
(
point_yxs
,
axis
=
0
)
is_h_len
=
(
max_y
-
min_y
)
<
1.5
*
(
max_x
-
min_x
)
pos_list_final
=
[]
if
is_h_len
:
xs
=
np
.
unique
(
xs
)
for
x
in
xs
:
ys
=
instance_label_map
[:,
x
].
copy
().
reshape
((
-
1
,
))
y
=
int
(
np
.
where
(
ys
==
instance_id
)[
0
].
mean
())
pos_list_final
.
append
((
y
,
x
))
else
:
ys
=
np
.
unique
(
ys
)
for
y
in
ys
:
xs
=
instance_label_map
[
y
,
:].
copy
().
reshape
((
-
1
,
))
x
=
int
(
np
.
where
(
xs
==
instance_id
)[
0
].
mean
())
pos_list_final
.
append
((
y
,
x
))
pos_list_sorted
,
_
=
sort_with_direction
(
pos_list_final
,
f_direction
)
all_pos_yxs
.
append
(
pos_list_sorted
)
# use decoder to filter backgroud points.
p_char_maps
=
p_char_maps
.
transpose
([
1
,
2
,
0
])
decode_res
=
ctc_decoder_for_image
(
all_pos_yxs
,
logits_map
=
p_char_maps
,
keep_blank_in_idxs
=
True
)
for
decoded_str
,
keep_yxs_list
in
decode_res
:
if
is_backbone
:
keep_yxs_list_with_id
=
add_id
(
keep_yxs_list
,
image_id
=
image_id
)
instance_center_pos_yxs
.
append
(
keep_yxs_list_with_id
)
else
:
print
(
'--> Not supported format.'
)
exit
(
-
1
)
return
poly_list
,
keep_str_list
end_points_yxs
.
extend
((
keep_yxs_list
[
0
],
keep_yxs_list
[
-
1
]))
center_pos_yxs
.
extend
(
keep_yxs_list
)
if
is_backbone
:
return
instance_center_pos_yxs
else
:
return
center_pos_yxs
,
end_points_yxs
def
generate_pivot_list
(
p_score
,
p_char_maps
,
f_direction
,
Lexicon_Table
,
score_thresh
=
0.5
):
score_thresh
=
0.5
,
is_backbone
=
False
,
is_curved
=
True
,
image_id
=
0
):
"""
Warp all the function together.
"""
if
is_curved
:
return
generate_pivot_list_curved
(
p_score
,
p_char_maps
,
f_direction
,
score_thresh
=
score_thresh
,
is_expand
=
True
,
is_backbone
=
is_backbone
,
image_id
=
image_id
)
else
:
return
generate_pivot_list_horizontal
(
p_score
,
p_char_maps
,
f_direction
,
score_thresh
=
score_thresh
,
is_backbone
=
is_backbone
,
image_id
=
image_id
)
# for refine module
def
extract_main_direction
(
pos_list
,
f_direction
):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
pos_list
=
np
.
array
(
pos_list
)
point_direction
=
f_direction
[
pos_list
[:,
0
],
pos_list
[:,
1
]]
point_direction
=
point_direction
[:,
::
-
1
]
# x, y -> y, x
average_direction
=
np
.
mean
(
point_direction
,
axis
=
0
,
keepdims
=
True
)
average_direction
=
average_direction
/
(
np
.
linalg
.
norm
(
average_direction
)
+
1e-6
)
return
average_direction
def
sort_by_direction_with_image_id_deprecated
(
pos_list
,
f_direction
):
"""
f_direction: h x w x 2
pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
"""
pos_list_full
=
np
.
array
(
pos_list
).
reshape
(
-
1
,
3
)
pos_list
=
pos_list_full
[:,
1
:]
point_direction
=
f_direction
[
pos_list
[:,
0
],
pos_list
[:,
1
]]
# x, y
point_direction
=
point_direction
[:,
::
-
1
]
# x, y -> y, x
average_direction
=
np
.
mean
(
point_direction
,
axis
=
0
,
keepdims
=
True
)
pos_proj_leng
=
np
.
sum
(
pos_list
*
average_direction
,
axis
=
1
)
sorted_list
=
pos_list_full
[
np
.
argsort
(
pos_proj_leng
)].
tolist
()
return
sorted_list
def
sort_by_direction_with_image_id
(
pos_list
,
f_direction
):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
def
sort_part_with_direction
(
pos_list_full
,
point_direction
):
pos_list_full
=
np
.
array
(
pos_list_full
).
reshape
(
-
1
,
3
)
pos_list
=
pos_list_full
[:,
1
:]
point_direction
=
np
.
array
(
point_direction
).
reshape
(
-
1
,
2
)
average_direction
=
np
.
mean
(
point_direction
,
axis
=
0
,
keepdims
=
True
)
pos_proj_leng
=
np
.
sum
(
pos_list
*
average_direction
,
axis
=
1
)
sorted_list
=
pos_list_full
[
np
.
argsort
(
pos_proj_leng
)].
tolist
()
sorted_direction
=
point_direction
[
np
.
argsort
(
pos_proj_leng
)].
tolist
()
return
sorted_list
,
sorted_direction
pos_list
=
np
.
array
(
pos_list
).
reshape
(
-
1
,
3
)
point_direction
=
f_direction
[
pos_list
[:,
1
],
pos_list
[:,
2
]]
# x, y
point_direction
=
point_direction
[:,
::
-
1
]
# x, y -> y, x
sorted_point
,
sorted_direction
=
sort_part_with_direction
(
pos_list
,
point_direction
)
point_num
=
len
(
sorted_point
)
if
point_num
>=
16
:
middle_num
=
point_num
//
2
first_part_point
=
sorted_point
[:
middle_num
]
first_point_direction
=
sorted_direction
[:
middle_num
]
sorted_fist_part_point
,
sorted_fist_part_direction
=
sort_part_with_direction
(
first_part_point
,
first_point_direction
)
last_part_point
=
sorted_point
[
middle_num
:]
last_point_direction
=
sorted_direction
[
middle_num
:]
sorted_last_part_point
,
sorted_last_part_direction
=
sort_part_with_direction
(
last_part_point
,
last_point_direction
)
sorted_point
=
sorted_fist_part_point
+
sorted_last_part_point
sorted_direction
=
sorted_fist_part_direction
+
sorted_last_part_direction
return
sorted_point
def
generate_pivot_list_tt_inference
(
p_score
,
p_char_maps
,
f_direction
,
score_thresh
=
0.5
,
is_backbone
=
False
,
is_curved
=
True
,
image_id
=
0
):
"""
return center point and end point of TCL instance; filter with the char maps;
"""
p_score
=
p_score
[
0
]
f_direction
=
f_direction
.
transpose
(
1
,
2
,
0
)
ret
,
p_tcl_map
=
cv2
.
threshold
(
p_score
,
score_thresh
,
255
,
cv2
.
THRESH_BINARY
)
skeleton_map
=
thin
(
p_tcl_map
.
astype
(
'uint8'
))
p_tcl_map
=
(
p_score
>
score_thresh
)
*
1.0
skeleton_map
=
thin
(
p_tcl_map
)
instance_count
,
instance_label_map
=
cv2
.
connectedComponents
(
skeleton_map
,
connectivity
=
8
)
skeleton_map
.
astype
(
np
.
uint8
)
,
connectivity
=
8
)
# get TCL Instance
all_pos_yxs
=
[]
...
...
@@ -269,15 +522,11 @@ def generate_pivot_list(p_score,
pos_list
=
[]
ys
,
xs
=
np
.
where
(
instance_label_map
==
instance_id
)
pos_list
=
list
(
zip
(
ys
,
xs
))
### FIX-ME, eliminate outlier
if
len
(
pos_list
)
<
3
:
continue
pos_list_sorted
=
sort_and_expand_with_direction_v2
(
pos_list
,
f_direction
,
p_tcl_map
)
all_pos_yxs
.
append
(
pos_list_sorted
)
p_char_maps
=
p_char_maps
.
transpose
([
1
,
2
,
0
])
decoded_str
,
keep_yxs_list
=
ctc_decoder_for_image
(
all_pos_yxs
,
logits_map
=
p_char_maps
,
Lexicon_Table
=
Lexicon_Table
)
return
keep_yxs_list
,
decoded_str
pos_list_sorted_with_id
=
add_id
(
pos_list_sorted
,
image_id
=
image_id
)
all_pos_yxs
.
append
(
pos_list_sorted_with_id
)
return
all_pos_yxs
tools/infer/predict_e2e.py
浏览文件 @
a0d1f923
...
...
@@ -151,7 +151,7 @@ if __name__ == "__main__":
src_im
=
utility
.
draw_e2e_res
(
points
,
strs
,
image_file
)
img_name_pure
=
os
.
path
.
split
(
image_file
)[
-
1
]
img_path
=
os
.
path
.
join
(
draw_img_save
,
"e2e_res_{}
_pgnet
"
.
format
(
img_name_pure
))
"e2e_res_{}"
.
format
(
img_name_pure
))
cv2
.
imwrite
(
img_path
,
src_im
)
logger
.
info
(
"The visualized image saved in {}"
.
format
(
img_path
))
if
count
>
1
:
...
...
train_data/total_text/train/poly/2.txt
已删除
100644 → 0
浏览文件 @
0cd48c35
2.0,165.0,20.0,167.0,39.0,170.0,57.0,173.0,76.0,176.0,94.0,179.0,113.0,182.0,109.0,218.0,90.0,215.0,72.0,213.0,54.0,210.0,36.0,208.0,18.0,205.0,0.0,203.0 izza
2.0,411.0,30.0,412.0,58.0,414.0,87.0,416.0,115.0,418.0,143.0,420.0,172.0,422.0,172.0,476.0,143.0,474.0,114.0,472.0,86.0,471.0,57.0,469.0,28.0,467.0,0.0,466.0 ISA
train_data/total_text/train/rgb/2.jpg
已删除
100644 → 0
浏览文件 @
0cd48c35
40.9 KB
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录