weixin_41840029 / PaddleOCR (forked from PaddlePaddle / PaddleOCR)
Commit a0d1f923

add different post process

Author: Jethong
Date: Apr 09, 2021
Parent: 0cd48c35

Showing 10 changed files with 516 additions and 434 deletions (+516 / -434)
Changed files:

  configs/e2e/e2e_r50_vd_pg.yml               +2    -2
  ppocr/data/imaug/label_ops.py               +9    -7
  ppocr/data/pgnet_dataset.py                 +12   -4
  ppocr/metrics/e2e_metric.py                 +12   -2
  ppocr/postprocess/pg_postprocess.py         +93   -8
  ppocr/utils/e2e_metric/Deteval.py           +16   -286
  ppocr/utils/e2e_utils/extract_textpoint.py  +371  -122
  tools/infer/predict_e2e.py                  +1    -1
  train_data/total_text/train/poly/2.txt      +0    -2
  train_data/total_text/train/rgb/2.jpg       +0    -0
configs/e2e/e2e_r50_vd_pg.yml

@@ -11,7 +11,7 @@ Global:
   #    from static branch, load_static_weights must be set as True.
   # 2. If you want to finetune the pretrained models we provide in the docs,
   #    you should set load_static_weights as False.
-  load_static_weights: True
+  load_static_weights: False
   cal_metric_during_train: False
   pretrained_model:
   checkpoints:

@@ -94,7 +94,7 @@ Eval:
     label_file_list: [./train_data/total_text/test/]
     transforms:
       - DecodeImage: # load image
-          img_mode: BGR
+          img_mode: RGB
           channel_first: False
       - E2ELabelEncode:
       - E2EResizeForTest:
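For reference, the Eval pipeline now decodes images in RGB order instead of BGR. A minimal standalone sketch of what an img_mode switch like this typically implies for an OpenCV-based loader (decode_image here is an illustrative helper, not the repository's DecodeImage op):

    import cv2
    import numpy as np

    def decode_image(img_bytes, img_mode="RGB", channel_first=False):
        # cv2.imdecode returns an HWC array in BGR channel order.
        img = cv2.imdecode(
            np.frombuffer(img_bytes, dtype=np.uint8), cv2.IMREAD_COLOR)
        if img_mode == "RGB":
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if channel_first:
            img = img.transpose((2, 0, 1))  # HWC -> CHW
        return img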
ppocr/data/imaug/label_ops.py

@@ -200,16 +200,18 @@ class E2ELabelEncode(BaseRecLabelEncode):
         self.pad_num = len(self.dict)  # the length to pad
 
     def __call__(self, data):
+        text_label_index_list, temp_text = [], []
         texts = data['strs']
-        temp_texts = []
         for text in texts:
             text = text.lower()
-            text = self.encode(text)
-            if text is None:
-                return None
-            text = text + [self.pad_num] * (self.max_text_len - len(text))
-            temp_texts.append(text)
-        data['strs'] = np.array(temp_texts)
+            temp_text = []
+            for c_ in text:
+                if c_ in self.dict:
+                    temp_text.append(self.dict[c_])
+            temp_text = temp_text + [self.pad_num] * (self.max_text_len -
+                                                      len(temp_text))
+            text_label_index_list.append(temp_text)
+        data['strs'] = np.array(text_label_index_list)
         return data
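The rewritten E2ELabelEncode.__call__ encodes each transcript character by character through self.dict, silently drops characters outside the dictionary, and pads every sequence to max_text_len with pad_num = len(dict). A minimal standalone sketch of the same scheme (encode_strs is an illustrative name, not a repository function):

    import numpy as np

    def encode_strs(texts, char_dict, max_text_len=50):
        # Map chars to indices, drop unknown chars, pad with len(char_dict).
        pad_num = len(char_dict)
        text_label_index_list = []
        for text in texts:
            temp_text = [char_dict[c] for c in text.lower() if c in char_dict]
            temp_text = temp_text + [pad_num] * (max_text_len - len(temp_text))
            text_label_index_list.append(temp_text)
        return np.array(text_label_index_list)

    char_dict = {c: i for i, c in enumerate("0123456789abcdefghijklmnopqrstuvwxyz")}
    labels = encode_strs(["PIZZA", "isa#"], char_dict)  # '#' is dropped
    print(labels.shape)  # (2, 50)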
ppocr/data/pgnet_dataset.py

@@ -24,6 +24,7 @@ class PGDataSet(Dataset):
         self.logger = logger
         self.seed = seed
+        self.mode = mode
         global_config = config['Global']
         dataset_config = config[mode]['dataset']
         loader_config = config[mode]['loader']

@@ -62,10 +63,13 @@ class PGDataSet(Dataset):
         with open(poly_txt_path) as f:
             for line in f.readlines():
                 poly_str, txt = line.strip().split('\t')
-                poly = map(float, poly_str.split(','))
+                poly = list(map(float, poly_str.split(',')))
+                if self.mode.lower() == "eval":
+                    while len(poly) < 100:
+                        poly.append(-1)
                 text_polys.append(
-                    np.array(list(poly), dtype=np.float32).reshape(-1, 2))
+                    np.array(poly, dtype=np.float32).reshape(-1, 2))
                 txts.append(txt)
                 txt_tags.append(txt == '###')

@@ -135,6 +139,10 @@ class PGDataSet(Dataset):
         try:
             if self.data_format == 'icdar':
                 im_path = os.path.join(data_path, 'rgb', data_line)
-                poly_path = os.path.join(data_path, 'poly',
-                                         data_line.split('.')[0] + '.txt')
+                if self.mode.lower() == "eval":
+                    poly_path = os.path.join(data_path, 'poly_gt',
+                                             data_line.split('.')[0] + '.txt')
+                else:
+                    poly_path = os.path.join(data_path, 'poly',
+                                             data_line.split('.')[0] + '.txt')
                 text_polys, text_tags, text_strs = self.extract_polys(poly_path)
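In eval mode the loader now pads every polygon's coordinate list to 100 floats (50 points) with -1 so that variable-length ground truths can be stacked into one batch tensor; the metric strips the pad values back out (see e2e_metric.py below). A standalone sketch of the parsing step, assuming the same tab-separated label format:

    import numpy as np

    def load_poly_line(line, mode="eval", max_floats=100):
        # 'x1,y1,x2,y2,...<TAB>transcript'; pad coords with -1 in eval mode.
        poly_str, txt = line.strip().split('\t')
        poly = list(map(float, poly_str.split(',')))
        if mode.lower() == "eval":
            while len(poly) < max_floats:
                poly.append(-1)
        return np.array(poly, dtype=np.float32).reshape(-1, 2), txt

    pts, txt = load_poly_line("2.0,165.0,20.0,167.0,39.0,170.0\tizza")
    print(pts.shape)              # (50, 2)
    print((pts[3:] == -1).all())  # True: rows 3..49 are padding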
ppocr/metrics/e2e_metric.py

@@ -33,10 +33,20 @@ class E2EMetric(object):
         self.reset()
 
     def __call__(self, preds, batch, **kwargs):
-        gt_polyons_batch = batch[2]
+        temp_gt_polyons_batch = batch[2]
         temp_gt_strs_batch = batch[3]
         ignore_tags_batch = batch[4]
+        gt_polyons_batch = []
         gt_strs_batch = []
+        temp_gt_polyons_batch = temp_gt_polyons_batch[0].tolist()
+        for temp_list in temp_gt_polyons_batch:
+            t = []
+            for index in temp_list:
+                if index[0] != -1 and index[1] != -1:
+                    t.append(index)
+            gt_polyons_batch.append(t)
+
         temp_gt_strs_batch = temp_gt_strs_batch[0].tolist()
         for temp_list in temp_gt_strs_batch:
             t = ""

@@ -46,7 +56,7 @@ class E2EMetric(object):
             gt_strs_batch.append(t)
 
         for pred, gt_polyons, gt_strs, ignore_tags in zip(
-                [preds], gt_polyons_batch, [gt_strs_batch], ignore_tags_batch):
+                [preds], [gt_polyons_batch], [gt_strs_batch], ignore_tags_batch):
             # prepare gt
             gt_info_list = [{
                 'points': gt_polyon,
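The metric reverses the dataset's padding: any (x, y) == (-1, -1) row is treated as filler and removed before matching. A minimal sketch of that filtering, on plain lists:

    def strip_pad_points(gt_polyons_batch):
        # Drop the (-1, -1) rows appended by the eval loader for batching.
        cleaned = []
        for poly in gt_polyons_batch:
            cleaned.append([pt for pt in poly if pt[0] != -1 and pt[1] != -1])
        return cleaned

    padded = [[[2.0, 165.0], [20.0, 167.0], [-1.0, -1.0], [-1.0, -1.0]]]
    print(strip_pad_points(padded))  # [[[2.0, 165.0], [20.0, 167.0]]]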
ppocr/postprocess/pg_postprocess.py

@@ -23,7 +23,8 @@ __dir__ = os.path.dirname(__file__)
 sys.path.append(__dir__)
 sys.path.append(os.path.join(__dir__, '..'))
 
-from ppocr.utils.e2e_utils.extract_textpoint import get_dict, generate_pivot_list, restore_poly
+from ppocr.utils.e2e_utils.extract_textpoint import *
+from ppocr.utils.e2e_utils.visual import *
 import paddle

@@ -37,6 +38,11 @@ class PGPostProcess(object):
         self.valid_set = valid_set
         self.score_thresh = score_thresh
+        # c++ la-nms is faster, but only support python 3.5
+        self.is_python35 = False
+        if sys.version_info.major == 3 and sys.version_info.minor == 5:
+            self.is_python35 = True
 
     def __call__(self, outs_dict, shape_list):
         p_score = outs_dict['f_score']
         p_border = outs_dict['f_border']

@@ -52,17 +58,96 @@ class PGPostProcess(object):
         p_border = p_border[0]
         p_direction = p_direction[0]
         p_char = p_char[0]
         src_h, src_w, ratio_h, ratio_w = shape_list[0]
-        instance_yxs_list, seq_strs = generate_pivot_list(
-            p_score,
-            p_char,
-            p_direction,
-            self.Lexicon_Table,
-            score_thresh=self.score_thresh)
-        poly_list, keep_str_list = restore_poly(instance_yxs_list, seq_strs,
-                                                p_border, ratio_w, ratio_h,
-                                                src_w, src_h, self.valid_set)
+        is_curved = self.valid_set == "totaltext"
+        instance_yxs_list = generate_pivot_list(
+            p_score,
+            p_char,
+            p_direction,
+            score_thresh=self.score_thresh,
+            is_backbone=True,
+            is_curved=is_curved)
+        p_char = paddle.to_tensor(np.expand_dims(p_char, axis=0))
+        char_seq_idx_set = []
+        for i in range(len(instance_yxs_list)):
+            gather_info_lod = paddle.to_tensor(instance_yxs_list[i])
+            f_char_map = paddle.transpose(p_char, [0, 2, 3, 1])
+            feature_seq = paddle.gather_nd(f_char_map, gather_info_lod)
+            feature_seq = np.expand_dims(feature_seq.numpy(), axis=0)
+            feature_len = [len(feature_seq[0])]
+            featyre_seq = paddle.to_tensor(feature_seq)
+            feature_len = np.array([feature_len]).astype(np.int64)
+            length = paddle.to_tensor(feature_len)
+            seq_pred = paddle.fluid.layers.ctc_greedy_decoder(
+                input=featyre_seq, blank=36, input_length=length)
+            seq_pred_str = seq_pred[0].numpy().tolist()[0]
+            seq_len = seq_pred[1].numpy()[0][0]
+            temp_t = []
+            for c in seq_pred_str[:seq_len]:
+                temp_t.append(c)
+            char_seq_idx_set.append(temp_t)
+        seq_strs = []
+        for char_idx_set in char_seq_idx_set:
+            pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set])
+            seq_strs.append(pr_str)
+        poly_list = []
+        keep_str_list = []
+        all_point_list = []
+        all_point_pair_list = []
+        for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
+            if len(yx_center_line) == 1:
+                yx_center_line.append(yx_center_line[-1])
+            offset_expand = 1.0
+            if self.valid_set == 'totaltext':
+                offset_expand = 1.2
+            point_pair_list = []
+            for batch_id, y, x in yx_center_line:
+                offset = p_border[:, y, x].reshape(2, 2)
+                if offset_expand != 1.0:
+                    offset_length = np.linalg.norm(
+                        offset, axis=1, keepdims=True)
+                    expand_length = np.clip(
+                        offset_length * (offset_expand - 1),
+                        a_min=0.5,
+                        a_max=3.0)
+                    offset_detal = offset / offset_length * expand_length
+                    offset = offset + offset_detal
+                ori_yx = np.array([y, x], dtype=np.float32)
+                point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
+                    [ratio_w, ratio_h]).reshape(-1, 2)
+                point_pair_list.append(point_pair)
+                all_point_list.append([
+                    int(round(x * 4.0 / ratio_w)),
+                    int(round(y * 4.0 / ratio_h))
+                ])
+                all_point_pair_list.append(point_pair.round().astype(np.int32)
+                                           .tolist())
+            detected_poly, pair_length_info = point_pair2poly(point_pair_list)
+            detected_poly = expand_poly_along_width(
+                detected_poly, shrink_ratio_of_width=0.2)
+            detected_poly[:, 0] = np.clip(
+                detected_poly[:, 0], a_min=0, a_max=src_w)
+            detected_poly[:, 1] = np.clip(
+                detected_poly[:, 1], a_min=0, a_max=src_h)
+            if len(keep_str) < 2:
+                continue
+            keep_str_list.append(keep_str)
+            if self.valid_set == 'partvgg':
+                middle_point = len(detected_poly) // 2
+                detected_poly = detected_poly[
+                    [0, middle_point - 1, middle_point, -1], :]
+                poly_list.append(detected_poly)
+            elif self.valid_set == 'totaltext':
+                poly_list.append(detected_poly)
+            else:
+                print('--> Not supported format.')
+                exit(-1)
         data = {
             'points': poly_list,
             'strs': keep_str_list,
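The geometric core of the new inline post-process is the border map: at each centerline pixel it stores two (dy, dx) offsets pointing to the top and bottom text boundary. Adding them to the pixel position, flipping (y, x) to (x, y), undoing the 4x feature-map stride, and dividing by the test-time resize ratios yields a boundary point pair in source-image coordinates. A self-contained sketch of just that step (border_offsets_to_point_pair is an illustrative name):

    import numpy as np

    def border_offsets_to_point_pair(p_border, y, x, ratio_w, ratio_h, stride=4.0):
        offset = p_border[:, y, x].reshape(2, 2)   # two (dy, dx) offsets
        ori_yx = np.array([y, x], dtype=np.float32)
        # (y, x) -> (x, y), undo the 1/4 stride and the resize ratios
        return (ori_yx + offset)[:, ::-1] * stride / np.array(
            [ratio_w, ratio_h]).reshape(-1, 2)

    p_border = np.zeros((4, 8, 8), dtype=np.float32)
    p_border[:, 3, 5] = [-2.0, 0.0, 2.0, 0.0]      # boundary 2 px above and below
    print(border_offsets_to_point_pair(p_border, 3, 5, ratio_w=1.0, ratio_h=1.0))
    # [[20.  4.]
    #  [20. 20.]]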
ppocr/utils/e2e_metric/Deteval.py

@@ -35,7 +35,7 @@ def get_socre(gt_dict, pred_dict):
     gt = []
     n = len(gt_dict)
     for i in range(n):
-        points = gt_dict[i]['points'].tolist()
+        points = gt_dict[i]['points']
         h = len(points)
         text = gt_dict[i]['text']
         xx = [

@@ -51,7 +51,7 @@ def get_socre(gt_dict, pred_dict):
             t_y.append(points[j][1])
         xx[1] = np.array([t_x], dtype='int16')
         xx[3] = np.array([t_y], dtype='int16')
-        if text != "":
+        if text != "" and "#" not in text:
             xx[4] = np.array([text], dtype='U{}'.format(len(text)))
             xx[5] = np.array(['c'], dtype='<U1')
         gt.append(xx)
@@ -89,17 +89,10 @@ def get_socre(gt_dict, pred_dict):
                    area(det_x, det_y)), 2)
 
     ##############################Initialization###################################
-    global_tp = 0
-    global_fp = 0
-    global_fn = 0
-    global_sigma = []
-    global_tau = []
-    tr = 0.7
-    tp = 0.6
-    fsc_k = 0.8
-    k = 2
-    global_pred_str = []
-    global_gt_str = []
+    # global_sigma = []
+    # global_tau = []
+    # global_pred_str = []
+    # global_gt_str = []
     ###############################################################################
 
     for input_id in range(allInputs):
@@ -147,281 +140,16 @@ def get_socre(gt_dict, pred_dict):
                     local_pred_str[det_id] = pred_seq_str
                     local_gt_str[gt_id] = gt_seq_str
 
-        global_sigma.append(local_sigma_table)
-        global_tau.append(local_tau_table)
-        global_pred_str.append(local_pred_str)
-        global_gt_str.append(local_gt_str)
-
-    global_accumulative_recall = 0
-    global_accumulative_precision = 0
-    total_num_gt = 0
-    total_num_det = 0
-    hit_str_count = 0
-    hit_count = 0
-
-    def one_to_one(local_sigma_table, local_tau_table,
-                   local_accumulative_recall, local_accumulative_precision,
-                   global_accumulative_recall, global_accumulative_precision,
-                   gt_flag, det_flag, idy):
-        hit_str_num = 0
-        for gt_id in range(num_gt):
-            gt_matching_qualified_sigma_candidates = np.where(
-                local_sigma_table[gt_id, :] > tr)
-            gt_matching_num_qualified_sigma_candidates = gt_matching_qualified_sigma_candidates[0].shape[0]
-            gt_matching_qualified_tau_candidates = np.where(
-                local_tau_table[gt_id, :] > tp)
-            gt_matching_num_qualified_tau_candidates = gt_matching_qualified_tau_candidates[0].shape[0]
-            det_matching_qualified_sigma_candidates = np.where(
-                local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]] > tr)
-            det_matching_num_qualified_sigma_candidates = det_matching_qualified_sigma_candidates[0].shape[0]
-            det_matching_qualified_tau_candidates = np.where(
-                local_tau_table[:, gt_matching_qualified_tau_candidates[0]] > tp)
-            det_matching_num_qualified_tau_candidates = det_matching_qualified_tau_candidates[0].shape[0]
-            if (gt_matching_num_qualified_sigma_candidates == 1) and (
-                    gt_matching_num_qualified_tau_candidates == 1) and \
-                    (det_matching_num_qualified_sigma_candidates == 1) and (
-                    det_matching_num_qualified_tau_candidates == 1):
-                global_accumulative_recall = global_accumulative_recall + 1.0
-                global_accumulative_precision = global_accumulative_precision + 1.0
-                local_accumulative_recall = local_accumulative_recall + 1.0
-                local_accumulative_precision = local_accumulative_precision + 1.0
-                gt_flag[0, gt_id] = 1
-                matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
-                # recg start
-                gt_str_cur = global_gt_str[idy][gt_id]
-                pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[0]]
-                if pred_str_cur == gt_str_cur:
-                    hit_str_num += 1
-                else:
-                    if pred_str_cur.lower() == gt_str_cur.lower():
-                        hit_str_num += 1
-                # recg end
-                det_flag[0, matched_det_id] = 1
-        return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
-
-    def one_to_many(local_sigma_table, local_tau_table,
-                    local_accumulative_recall, local_accumulative_precision,
-                    global_accumulative_recall, global_accumulative_precision,
-                    gt_flag, det_flag, idy):
-        hit_str_num = 0
-        for gt_id in range(num_gt):
-            # skip the following if the groundtruth was matched
-            if gt_flag[0, gt_id] > 0:
-                continue
-            non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0)
-            num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0]
-            if num_non_zero_in_sigma >= k:
-                ####search for all detections that overlaps with this groundtruth
-                qualified_tau_candidates = np.where((
-                    local_tau_table[gt_id, :] >= tp) & (det_flag[0, :] == 0))
-                num_qualified_tau_candidates = qualified_tau_candidates[0].shape[0]
-                if num_qualified_tau_candidates == 1:
-                    if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp) and
-                            (local_sigma_table[gt_id, qualified_tau_candidates] >= tr)):
-                        # became an one-to-one case
-                        global_accumulative_recall = global_accumulative_recall + 1.0
-                        global_accumulative_precision = global_accumulative_precision + 1.0
-                        local_accumulative_recall = local_accumulative_recall + 1.0
-                        local_accumulative_precision = local_accumulative_precision + 1.0
-                        gt_flag[0, gt_id] = 1
-                        det_flag[0, qualified_tau_candidates] = 1
-                        # recg start
-                        gt_str_cur = global_gt_str[idy][gt_id]
-                        pred_str_cur = global_pred_str[idy][qualified_tau_candidates[0].tolist()[0]]
-                        if pred_str_cur == gt_str_cur:
-                            hit_str_num += 1
-                        else:
-                            if pred_str_cur.lower() == gt_str_cur.lower():
-                                hit_str_num += 1
-                        # recg end
-                elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates]) >= tr):
-                    gt_flag[0, gt_id] = 1
-                    det_flag[0, qualified_tau_candidates] = 1
-                    # recg start
-                    gt_str_cur = global_gt_str[idy][gt_id]
-                    pred_str_cur = global_pred_str[idy][qualified_tau_candidates[0].tolist()[0]]
-                    if pred_str_cur == gt_str_cur:
-                        hit_str_num += 1
-                    else:
-                        if pred_str_cur.lower() == gt_str_cur.lower():
-                            hit_str_num += 1
-                    # recg end
-                    global_accumulative_recall = global_accumulative_recall + fsc_k
-                    global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k
-                    local_accumulative_recall = local_accumulative_recall + fsc_k
-                    local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k
-        return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
-
-    def many_to_one(local_sigma_table, local_tau_table,
-                    local_accumulative_recall, local_accumulative_precision,
-                    global_accumulative_recall, global_accumulative_precision,
-                    gt_flag, det_flag, idy):
-        hit_str_num = 0
-        for det_id in range(num_det):
-            # skip the following if the detection was matched
-            if det_flag[0, det_id] > 0:
-                continue
-            non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0)
-            num_non_zero_in_tau = non_zero_in_tau[0].shape[0]
-            if num_non_zero_in_tau >= k:
-                ####search for all detections that overlaps with this groundtruth
-                qualified_sigma_candidates = np.where((
-                    local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0))
-                num_qualified_sigma_candidates = qualified_sigma_candidates[0].shape[0]
-                if num_qualified_sigma_candidates == 1:
-                    if ((local_tau_table[qualified_sigma_candidates, det_id] >= tp) and
-                            (local_sigma_table[qualified_sigma_candidates, det_id] >= tr)):
-                        # became an one-to-one case
-                        global_accumulative_recall = global_accumulative_recall + 1.0
-                        global_accumulative_precision = global_accumulative_precision + 1.0
-                        local_accumulative_recall = local_accumulative_recall + 1.0
-                        local_accumulative_precision = local_accumulative_precision + 1.0
-                        gt_flag[0, qualified_sigma_candidates] = 1
-                        det_flag[0, det_id] = 1
-                        # recg start
-                        pred_str_cur = global_pred_str[idy][det_id]
-                        gt_len = len(qualified_sigma_candidates[0])
-                        for idx in range(gt_len):
-                            ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
-                            if ele_gt_id not in global_gt_str[idy]:
-                                continue
-                            gt_str_cur = global_gt_str[idy][ele_gt_id]
-                            if pred_str_cur == gt_str_cur:
-                                hit_str_num += 1
-                                break
-                            else:
-                                if pred_str_cur.lower() == gt_str_cur.lower():
-                                    hit_str_num += 1
-                                break
-                        # recg end
-                elif (np.sum(local_tau_table[qualified_sigma_candidates, det_id]) >= tp):
-                    det_flag[0, det_id] = 1
-                    gt_flag[0, qualified_sigma_candidates] = 1
-                    # recg start
-                    pred_str_cur = global_pred_str[idy][det_id]
-                    gt_len = len(qualified_sigma_candidates[0])
-                    for idx in range(gt_len):
-                        ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
-                        if ele_gt_id not in global_gt_str[idy]:
-                            continue
-                        gt_str_cur = global_gt_str[idy][ele_gt_id]
-                        if pred_str_cur == gt_str_cur:
-                            hit_str_num += 1
-                            break
-                        else:
-                            if pred_str_cur.lower() == gt_str_cur.lower():
-                                hit_str_num += 1
-                                break
-                    # recg end
-                    global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k
-                    global_accumulative_precision = global_accumulative_precision + fsc_k
-                    local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k
-                    local_accumulative_precision = local_accumulative_precision + fsc_k
-        return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
-
-    for idx in range(len(global_sigma)):
-        local_sigma_table = global_sigma[idx]
-        local_tau_table = global_tau[idx]
-        num_gt = local_sigma_table.shape[0]
-        num_det = local_sigma_table.shape[1]
-        total_num_gt = total_num_gt + num_gt
-        total_num_det = total_num_det + num_det
-        local_accumulative_recall = 0
-        local_accumulative_precision = 0
-        gt_flag = np.zeros((1, num_gt))
-        det_flag = np.zeros((1, num_det))
-
-        #######first check for one-to-one case##########
-        local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
-            gt_flag, det_flag, hit_str_num = one_to_one(
-                local_sigma_table, local_tau_table, local_accumulative_recall,
-                local_accumulative_precision, global_accumulative_recall,
-                global_accumulative_precision, gt_flag, det_flag, idx)
-        hit_str_count += hit_str_num
-
-        #######then check for one-to-many case##########
-        local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
-            gt_flag, det_flag, hit_str_num = one_to_many(
-                local_sigma_table, local_tau_table, local_accumulative_recall,
-                local_accumulative_precision, global_accumulative_recall,
-                global_accumulative_precision, gt_flag, det_flag, idx)
-        hit_str_count += hit_str_num
-
-        #######then check for many-to-one case##########
-        local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
-            gt_flag, det_flag, hit_str_num = many_to_one(
-                local_sigma_table, local_tau_table, local_accumulative_recall,
-                local_accumulative_precision, global_accumulative_recall,
-                global_accumulative_precision, gt_flag, det_flag, idx)
-        hit_str_count += hit_str_num
-
-        # fid = open(fid_path, 'a+')
-        try:
-            local_precision = local_accumulative_precision / num_det
-        except ZeroDivisionError:
-            local_precision = 0
-        try:
-            local_recall = local_accumulative_recall / num_gt
-        except ZeroDivisionError:
-            local_recall = 0
-        try:
-            local_f_score = 2 * local_precision * local_recall / (
-                local_precision + local_recall)
-        except ZeroDivisionError:
-            local_f_score = 0
+        global_sigma = local_sigma_table
+        global_tau = local_tau_table
+        global_pred_str = local_pred_str
+        global_gt_str = local_gt_str
 
     single_data = {}
     single_data['sigma'] = global_sigma
     single_data['global_tau'] = global_tau
     single_data['global_pred_str'] = global_pred_str
     single_data['global_gt_str'] = global_gt_str
-    single_data["recall"] = local_recall
-    single_data['precision'] = local_precision
-    single_data['f_score'] = local_f_score
     return single_data
@@ -435,10 +163,10 @@ def combine_results(all_data):
     global_pred_str = []
     global_gt_str = []
     for data in all_data:
-        global_sigma.append(data['sigma'][0])
-        global_tau.append(data['global_tau'][0])
-        global_pred_str.append(data['global_pred_str'][0])
-        global_gt_str.append(data['global_gt_str'][0])
+        global_sigma.append(data['sigma'])
+        global_tau.append(data['global_tau'])
+        global_pred_str.append(data['global_pred_str'])
+        global_gt_str.append(data['global_gt_str'])
 
     global_accumulative_recall = 0
     global_accumulative_precision = 0
@@ -676,6 +404,8 @@ def combine_results(all_data):
             local_accumulative_recall, local_accumulative_precision,
             global_accumulative_recall, global_accumulative_precision,
             gt_flag, det_flag, idx)
+        hit_str_count += hit_str_num
+
     try:
         recall = global_accumulative_recall / total_num_gt
     except ZeroDivisionError:
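For orientation: the sigma table holds recall overlaps (intersection over gt area) and the tau table precision overlaps (intersection over detection area) for every gt/det pair, and the matching stages that now live only in combine_results walk these tables with thresholds tr = 0.7 and tp = 0.6. A standalone sketch of the one-to-one stage's acceptance test, under those definitions:

    import numpy as np

    def one_to_one_matches(sigma, tau, tr=0.7, tp=0.6):
        # A pair matches when each side is the other's single qualified candidate.
        matches = []
        num_gt, _ = sigma.shape
        for gt_id in range(num_gt):
            gt_sigma_cand = np.where(sigma[gt_id, :] > tr)[0]
            gt_tau_cand = np.where(tau[gt_id, :] > tp)[0]
            if len(gt_sigma_cand) != 1 or len(gt_tau_cand) != 1:
                continue
            det_sigma_cand = np.where(sigma[:, gt_sigma_cand[0]] > tr)[0]
            det_tau_cand = np.where(tau[:, gt_tau_cand[0]] > tp)[0]
            if len(det_sigma_cand) == 1 and len(det_tau_cand) == 1:
                matches.append((gt_id, int(gt_sigma_cand[0])))
        return matches

    sigma = np.array([[0.9, 0.1], [0.0, 0.8]])
    tau = np.array([[0.8, 0.0], [0.1, 0.7]])
    print(one_to_one_matches(sigma, tau))  # [(0, 0), (1, 1)]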
ppocr/utils/e2e_utils/extract_textpoint.py

@@ -17,9 +17,11 @@ from __future__ import division
 from __future__ import print_function
 
 import cv2
+import math
 import numpy as np
 from itertools import groupby
-from cv2.ximgproc import thinning as thin
+from skimage.morphology._skeletonize import thin
 
 
 def get_dict(character_dict_path):
@@ -33,39 +35,87 @@ def get_dict(character_dict_path):
     return dict_character
 
 
-def instance_ctc_greedy_decoder(gather_info, logits_map, pts_num=4):
-    _, _, C = logits_map.shape
-    ys, xs = zip(*gather_info)
-    logits_seq = logits_map[list(ys), list(xs)]
-    probs_seq = logits_seq
-    labels = np.argmax(probs_seq, axis=1)
-    dst_str = [k for k, v_ in groupby(labels) if k != C - 1]
-    detal = len(gather_info) // (pts_num - 1)
-    keep_idx_list = [0] + [detal * (i + 1) for i in range(pts_num - 2)] + [-1]
-    keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
-    return dst_str, keep_gather_list
-
-
-def ctc_decoder_for_image(gather_info_list,
-                          logits_map,
-                          Lexicon_Table,
-                          pts_num=6):
-    """
-    CTC decoder using multiple processes.
-    """
-    decoder_str = []
-    decoder_xys = []
-    for gather_info in gather_info_list:
-        if len(gather_info) < pts_num:
-            continue
-        dst_str, xys_list = instance_ctc_greedy_decoder(
-            gather_info, logits_map, pts_num=pts_num)
-        dst_str_readable = ''.join([Lexicon_Table[idx] for idx in dst_str])
-        if len(dst_str_readable) < 2:
-            continue
-        decoder_str.append(dst_str_readable)
-        decoder_xys.append(xys_list)
-    return decoder_str, decoder_xys
+def softmax(logits):
+    """
+    logits: N x d
+    """
+    max_value = np.max(logits, axis=1, keepdims=True)
+    exp = np.exp(logits - max_value)
+    exp_sum = np.sum(exp, axis=1, keepdims=True)
+    dist = exp / exp_sum
+    return dist
+
+
+def get_keep_pos_idxs(labels, remove_blank=None):
+    """
+    Remove duplicate and get pos idxs of keep items.
+    The value of keep_blank should be [None, 95].
+    """
+    duplicate_len_list = []
+    keep_pos_idx_list = []
+    keep_char_idx_list = []
+    for k, v_ in groupby(labels):
+        current_len = len(list(v_))
+        if k != remove_blank:
+            current_idx = int(sum(duplicate_len_list) + current_len // 2)
+            keep_pos_idx_list.append(current_idx)
+            keep_char_idx_list.append(k)
+        duplicate_len_list.append(current_len)
+    return keep_char_idx_list, keep_pos_idx_list
+
+
+def remove_blank(labels, blank=0):
+    new_labels = [x for x in labels if x != blank]
+    return new_labels
+
+
+def insert_blank(labels, blank=0):
+    new_labels = [blank]
+    for l in labels:
+        new_labels += [l, blank]
+    return new_labels
+
+
+def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
+    """
+    CTC greedy (best path) decoder.
+    """
+    raw_str = np.argmax(np.array(probs_seq), axis=1)
+    remove_blank_in_pos = None if keep_blank_in_idxs else blank
+    dedup_str, keep_idx_list = get_keep_pos_idxs(
+        raw_str, remove_blank=remove_blank_in_pos)
+    dst_str = remove_blank(dedup_str, blank=blank)
+    return dst_str, keep_idx_list
+
+
+def instance_ctc_greedy_decoder(gather_info, logits_map,
+                                keep_blank_in_idxs=True):
+    """
+    gather_info: [[x, y], [x, y] ...]
+    logits_map: H x W X (n_chars + 1)
+    """
+    _, _, C = logits_map.shape
+    ys, xs = zip(*gather_info)
+    logits_seq = logits_map[list(ys), list(xs)]  # n x 96
+    probs_seq = softmax(logits_seq)
+    dst_str, keep_idx_list = ctc_greedy_decoder(
+        probs_seq, blank=C - 1, keep_blank_in_idxs=keep_blank_in_idxs)
+    keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
+    return dst_str, keep_gather_list
+
+
+def ctc_decoder_for_image(gather_info_list,
+                          logits_map,
+                          keep_blank_in_idxs=True):
+    """
+    CTC decoder using multiple processes.
+    """
+    decoder_results = []
+    for gather_info in gather_info_list:
+        res = instance_ctc_greedy_decoder(
+            gather_info, logits_map, keep_blank_in_idxs=keep_blank_in_idxs)
+        decoder_results.append(res)
+    return decoder_results
 
 
 def sort_with_direction(pos_list, f_direction):
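Both the removed and the added decoder are CTC greedy (best-path) decoders: take the argmax class at every centerline position, collapse consecutive repeats, and drop the blank class (the last channel, C - 1). A minimal self-contained version of that shared core:

    import numpy as np
    from itertools import groupby

    def ctc_greedy_decode(logits, blank):
        # logits: T x C scores per position; returns collapsed non-blank labels.
        raw = np.argmax(logits, axis=1)
        return [int(k) for k, _ in groupby(raw) if k != blank]

    lexicon = ["a", "b"]            # toy 3-class problem, index 2 is the blank
    logits = np.array([
        [5., 0., 0.],               # a
        [5., 0., 0.],               # a (repeat, collapsed)
        [0., 0., 5.],               # blank separates repeated characters
        [5., 0., 0.],               # a
        [0., 5., 0.],               # b
    ])
    decoded = ctc_greedy_decode(logits, blank=2)
    print(''.join(lexicon[i] for i in decoded))  # aab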
@@ -107,6 +157,58 @@ def sort_with_direction(pos_list, f_direction):
     return sorted_point, np.array(sorted_direction)
 
 
+def add_id(pos_list, image_id=0):
+    """
+    Add id for gather feature, for inference.
+    """
+    new_list = []
+    for item in pos_list:
+        new_list.append((image_id, item[0], item[1]))
+    return new_list
+
+
+def sort_and_expand_with_direction(pos_list, f_direction):
+    """
+    f_direction: h x w x 2
+    pos_list: [[y, x], [y, x], [y, x] ...]
+    """
+    h, w, _ = f_direction.shape
+    sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
+
+    # expand along
+    point_num = len(sorted_list)
+    sub_direction_len = max(point_num // 3, 2)
+    left_direction = point_direction[:sub_direction_len, :]
+    right_dirction = point_direction[point_num - sub_direction_len:, :]
+
+    left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
+    left_average_len = np.linalg.norm(left_average_direction)
+    left_start = np.array(sorted_list[0])
+    left_step = left_average_direction / (left_average_len + 1e-6)
+
+    right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
+    right_average_len = np.linalg.norm(right_average_direction)
+    right_step = right_average_direction / (right_average_len + 1e-6)
+    right_start = np.array(sorted_list[-1])
+
+    append_num = max(
+        int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
+    left_list = []
+    right_list = []
+    for i in range(append_num):
+        ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
+            'int32').tolist()
+        if ly < h and lx < w and (ly, lx) not in left_list:
+            left_list.append((ly, lx))
+        ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
+            'int32').tolist()
+        if ry < h and rx < w and (ry, rx) not in right_list:
+            right_list.append((ry, rx))
+
+    all_list = left_list[::-1] + sorted_list + right_list
+    return all_list
+
+
 def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
     """
     f_direction: h x w x 2

@@ -116,6 +218,7 @@ def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
     h, w, _ = f_direction.shape
     sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
 
+    # expand along
     point_num = len(sorted_list)
     sub_direction_len = max(point_num // 3, 2)
     left_direction = point_direction[:sub_direction_len, :]
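All of the sorting helpers in this file order centerline pixels the same way: average the per-pixel direction vectors, project each (y, x) position onto that mean direction, and sort by the projection length. A small sketch of that projection trick (sort_points_by_direction is an illustrative name):

    import numpy as np

    def sort_points_by_direction(pos_list, point_directions):
        # Order points by their scalar projection onto the mean direction.
        pos = np.array(pos_list, dtype=np.float32)      # (n, 2) as (y, x)
        avg_dir = np.mean(np.array(point_directions), axis=0, keepdims=True)
        proj = np.sum(pos * avg_dir, axis=1)
        return pos[np.argsort(proj)].tolist()

    points = [(4, 9), (4, 1), (4, 5)]
    dirs = [(0.0, 1.0)] * 3                             # text flows left to right
    print(sort_points_by_direction(points, dirs))
    # [[4.0, 1.0], [4.0, 5.0], [4.0, 9.0]]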
@@ -159,125 +262,271 @@ def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
     return all_list
 
 
-def point_pair2poly(point_pair_list):
-    """
-    Transfer vertical point_pairs into poly point in clockwise.
-    """
-    point_num = len(point_pair_list) * 2
-    point_list = [0] * point_num
-    for idx, point_pair in enumerate(point_pair_list):
-        point_list[idx] = point_pair[0]
-        point_list[point_num - 1 - idx] = point_pair[1]
-    return np.array(point_list).reshape(-1, 2)
-
-
-def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
-    ratio_pair = np.array(
-        [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
-    p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
-    p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
-    return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
-
-
-def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
-    """
-    expand poly along width.
-    """
-    point_num = poly.shape[0]
-    left_quad = np.array(
-        [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
-    left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
-                 (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
-    left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
-    right_quad = np.array(
-        [
-            poly[point_num // 2 - 2], poly[point_num // 2 - 1],
-            poly[point_num // 2], poly[point_num // 2 + 1]
-        ],
-        dtype=np.float32)
-    right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
-                  (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
-    right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
-    poly[0] = left_quad_expand[0]
-    poly[-1] = left_quad_expand[-1]
-    poly[point_num // 2 - 1] = right_quad_expand[1]
-    poly[point_num // 2] = right_quad_expand[2]
-    return poly
-
-
-def restore_poly(instance_yxs_list, seq_strs, p_border, ratio_w, ratio_h,
-                 src_w, src_h, valid_set):
-    poly_list = []
-    keep_str_list = []
-    for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
-        if len(keep_str) < 2:
-            print('--> too short, {}'.format(keep_str))
-            continue
-
-        offset_expand = 1.0
-        if valid_set == 'totaltext':
-            offset_expand = 1.2
-
-        point_pair_list = []
-        for y, x in yx_center_line:
-            offset = p_border[:, y, x].reshape(2, 2) * offset_expand
-            ori_yx = np.array([y, x], dtype=np.float32)
-            point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
-                [ratio_w, ratio_h]).reshape(-1, 2)
-            point_pair_list.append(point_pair)
-
-        detected_poly = point_pair2poly(point_pair_list)
-        detected_poly = expand_poly_along_width(
-            detected_poly, shrink_ratio_of_width=0.2)
-        detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
-        detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
-
-        keep_str_list.append(keep_str)
-        if valid_set == 'partvgg':
-            middle_point = len(detected_poly) // 2
-            detected_poly = detected_poly[
-                [0, middle_point - 1, middle_point, -1], :]
-            poly_list.append(detected_poly)
-        elif valid_set == 'totaltext':
-            poly_list.append(detected_poly)
-        else:
-            print('--> Not supported format.')
-            exit(-1)
-    return poly_list, keep_str_list
-
-
-def generate_pivot_list(p_score,
-                        p_char_maps,
-                        f_direction,
-                        Lexicon_Table,
-                        score_thresh=0.5):
-    """
-    return center point and end point of TCL instance; filter with the char maps;
-    """
-    p_score = p_score[0]
-    f_direction = f_direction.transpose(1, 2, 0)
-    ret, p_tcl_map = cv2.threshold(p_score, score_thresh, 255,
-                                   cv2.THRESH_BINARY)
-    skeleton_map = thin(p_tcl_map.astype('uint8'))
-    instance_count, instance_label_map = cv2.connectedComponents(
-        skeleton_map, connectivity=8)
-
-    # get TCL Instance
-    all_pos_yxs = []
-    if instance_count > 0:
-        for instance_id in range(1, instance_count):
-            pos_list = []
-            ys, xs = np.where(instance_label_map == instance_id)
-            pos_list = list(zip(ys, xs))
-            if len(pos_list) < 3:
-                continue
-            pos_list_sorted = sort_and_expand_with_direction_v2(
-                pos_list, f_direction, p_tcl_map)
-            all_pos_yxs.append(pos_list_sorted)
-
-    # use decoder to filter backgroud points.
-    p_char_maps = p_char_maps.transpose([1, 2, 0])
-    decoded_str, keep_yxs_list = ctc_decoder_for_image(
-        all_pos_yxs, logits_map=p_char_maps, Lexicon_Table=Lexicon_Table)
-    return keep_yxs_list, decoded_str
+def generate_pivot_list_curved(p_score,
+                               p_char_maps,
+                               f_direction,
+                               score_thresh=0.5,
+                               is_expand=True,
+                               is_backbone=False,
+                               image_id=0):
+    """
+    return center point and end point of TCL instance; filter with the char maps;
+    """
+    p_score = p_score[0]
+    f_direction = f_direction.transpose(1, 2, 0)
+    p_tcl_map = (p_score > score_thresh) * 1.0
+    skeleton_map = thin(p_tcl_map)
+    instance_count, instance_label_map = cv2.connectedComponents(
+        skeleton_map.astype(np.uint8), connectivity=8)
+
+    # get TCL Instance
+    all_pos_yxs = []
+    center_pos_yxs = []
+    end_points_yxs = []
+    instance_center_pos_yxs = []
+    if instance_count > 0:
+        for instance_id in range(1, instance_count):
+            pos_list = []
+            ys, xs = np.where(instance_label_map == instance_id)
+            pos_list = list(zip(ys, xs))
+            ### FIX-ME, eliminate outlier
+            if len(pos_list) < 3:
+                continue
+            if is_expand:
+                pos_list_sorted = sort_and_expand_with_direction_v2(
+                    pos_list, f_direction, p_tcl_map)
+            else:
+                pos_list_sorted, _ = sort_with_direction(pos_list, f_direction)
+            all_pos_yxs.append(pos_list_sorted)
+
+    # use decoder to filter backgroud points.
+    p_char_maps = p_char_maps.transpose([1, 2, 0])
+    decode_res = ctc_decoder_for_image(
+        all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
+    for decoded_str, keep_yxs_list in decode_res:
+        if is_backbone:
+            keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
+            instance_center_pos_yxs.append(keep_yxs_list_with_id)
+        else:
+            end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
+            center_pos_yxs.extend(keep_yxs_list)
+
+    if is_backbone:
+        return instance_center_pos_yxs
+    else:
+        return center_pos_yxs, end_points_yxs
+
+
+def generate_pivot_list_horizontal(p_score,
+                                   p_char_maps,
+                                   f_direction,
+                                   score_thresh=0.5,
+                                   is_backbone=False,
+                                   image_id=0):
+    """
+    return center point and end point of TCL instance; filter with the char maps;
+    """
+    p_score = p_score[0]
+    f_direction = f_direction.transpose(1, 2, 0)
+    p_tcl_map_bi = (p_score > score_thresh) * 1.0
+    instance_count, instance_label_map = cv2.connectedComponents(
+        p_tcl_map_bi.astype(np.uint8), connectivity=8)
+
+    # get TCL Instance
+    all_pos_yxs = []
+    center_pos_yxs = []
+    end_points_yxs = []
+    instance_center_pos_yxs = []
+    if instance_count > 0:
+        for instance_id in range(1, instance_count):
+            pos_list = []
+            ys, xs = np.where(instance_label_map == instance_id)
+            pos_list = list(zip(ys, xs))
+            ### FIX-ME, eliminate outlier
+            if len(pos_list) < 5:
+                continue
+
+            # add rule here
+            main_direction = extract_main_direction(pos_list,
+                                                    f_direction)  # y x
+            reference_directin = np.array([0, 1]).reshape([-1, 2])  # y x
+            is_h_angle = abs(np.sum(
+                main_direction * reference_directin)) < math.cos(math.pi / 180 *
+                                                                 70)
+
+            point_yxs = np.array(pos_list)
+            max_y, max_x = np.max(point_yxs, axis=0)
+            min_y, min_x = np.min(point_yxs, axis=0)
+            is_h_len = (max_y - min_y) < 1.5 * (max_x - min_x)
+
+            pos_list_final = []
+            if is_h_len:
+                xs = np.unique(xs)
+                for x in xs:
+                    ys = instance_label_map[:, x].copy().reshape((-1, ))
+                    y = int(np.where(ys == instance_id)[0].mean())
+                    pos_list_final.append((y, x))
+            else:
+                ys = np.unique(ys)
+                for y in ys:
+                    xs = instance_label_map[y, :].copy().reshape((-1, ))
+                    x = int(np.where(xs == instance_id)[0].mean())
+                    pos_list_final.append((y, x))
+
+            pos_list_sorted, _ = sort_with_direction(pos_list_final,
+                                                     f_direction)
+            all_pos_yxs.append(pos_list_sorted)
+
+    # use decoder to filter backgroud points.
+    p_char_maps = p_char_maps.transpose([1, 2, 0])
+    decode_res = ctc_decoder_for_image(
+        all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
+    for decoded_str, keep_yxs_list in decode_res:
+        if is_backbone:
+            keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
+            instance_center_pos_yxs.append(keep_yxs_list_with_id)
+        else:
+            end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
+            center_pos_yxs.extend(keep_yxs_list)
+
+    if is_backbone:
+        return instance_center_pos_yxs
+    else:
+        return center_pos_yxs, end_points_yxs
+
+
+def generate_pivot_list(p_score,
+                        p_char_maps,
+                        f_direction,
+                        score_thresh=0.5,
+                        is_backbone=False,
+                        is_curved=True,
+                        image_id=0):
+    """
+    Warp all the function together.
+    """
+    if is_curved:
+        return generate_pivot_list_curved(
+            p_score,
+            p_char_maps,
+            f_direction,
+            score_thresh=score_thresh,
+            is_expand=True,
+            is_backbone=is_backbone,
+            image_id=image_id)
+    else:
+        return generate_pivot_list_horizontal(
+            p_score,
+            p_char_maps,
+            f_direction,
+            score_thresh=score_thresh,
+            is_backbone=is_backbone,
+            image_id=image_id)
+
+
+# for refine module
+def extract_main_direction(pos_list, f_direction):
+    """
+    f_direction: h x w x 2
+    pos_list: [[y, x], [y, x], [y, x] ...]
+    """
+    pos_list = np.array(pos_list)
+    point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
+    point_direction = point_direction[:, ::-1]  # x, y -> y, x
+    average_direction = np.mean(point_direction, axis=0, keepdims=True)
+    average_direction = average_direction / (
+        np.linalg.norm(average_direction) + 1e-6)
+    return average_direction
+
+
+def sort_by_direction_with_image_id_deprecated(pos_list, f_direction):
+    """
+    f_direction: h x w x 2
+    pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
+    """
+    pos_list_full = np.array(pos_list).reshape(-1, 3)
+    pos_list = pos_list_full[:, 1:]
+    point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]  # x, y
+    point_direction = point_direction[:, ::-1]  # x, y -> y, x
+    average_direction = np.mean(point_direction, axis=0, keepdims=True)
+    pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
+    sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
+    return sorted_list
+
+
+def sort_by_direction_with_image_id(pos_list, f_direction):
+    """
+    f_direction: h x w x 2
+    pos_list: [[y, x], [y, x], [y, x] ...]
+    """
+
+    def sort_part_with_direction(pos_list_full, point_direction):
+        pos_list_full = np.array(pos_list_full).reshape(-1, 3)
+        pos_list = pos_list_full[:, 1:]
+        point_direction = np.array(point_direction).reshape(-1, 2)
+        average_direction = np.mean(point_direction, axis=0, keepdims=True)
+        pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
+        sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
+        sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
+        return sorted_list, sorted_direction
+
+    pos_list = np.array(pos_list).reshape(-1, 3)
+    point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]]  # x, y
+    point_direction = point_direction[:, ::-1]  # x, y -> y, x
+    sorted_point, sorted_direction = sort_part_with_direction(pos_list,
+                                                              point_direction)
+
+    point_num = len(sorted_point)
+    if point_num >= 16:
+        middle_num = point_num // 2
+        first_part_point = sorted_point[:middle_num]
+        first_point_direction = sorted_direction[:middle_num]
+        sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
+            first_part_point, first_point_direction)
+
+        last_part_point = sorted_point[middle_num:]
+        last_point_direction = sorted_direction[middle_num:]
+        sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
+            last_part_point, last_point_direction)
+        sorted_point = sorted_fist_part_point + sorted_last_part_point
+        sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
+
+    return sorted_point
+
+
+def generate_pivot_list_tt_inference(p_score,
+                                     p_char_maps,
+                                     f_direction,
+                                     score_thresh=0.5,
+                                     is_backbone=False,
+                                     is_curved=True,
+                                     image_id=0):
+    """
+    return center point and end point of TCL instance; filter with the char maps;
+    """
+    p_score = p_score[0]
+    f_direction = f_direction.transpose(1, 2, 0)
+    p_tcl_map = (p_score > score_thresh) * 1.0
+    skeleton_map = thin(p_tcl_map)
+    instance_count, instance_label_map = cv2.connectedComponents(
+        skeleton_map.astype(np.uint8), connectivity=8)
+
+    # get TCL Instance
+    all_pos_yxs = []
+    if instance_count > 0:
+        for instance_id in range(1, instance_count):
+            pos_list = []
+            ys, xs = np.where(instance_label_map == instance_id)
+            pos_list = list(zip(ys, xs))
+            ### FIX-ME, eliminate outlier
+            if len(pos_list) < 3:
+                continue
+            pos_list_sorted = sort_and_expand_with_direction_v2(
+                pos_list, f_direction, p_tcl_map)
+            pos_list_sorted_with_id = add_id(pos_list_sorted, image_id=image_id)
+            all_pos_yxs.append(pos_list_sorted_with_id)
+    return all_pos_yxs
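The pivot-list generators above share one front end: binarize the text-center-line score map, thin it to a one-pixel skeleton, and split the skeleton into per-text instances with connected components. A runnable sketch of that step, assuming OpenCV and scikit-image are installed (tcl_instances is an illustrative name):

    import cv2
    import numpy as np
    from skimage.morphology import thin

    def tcl_instances(p_score, score_thresh=0.5, min_pixels=3):
        p_tcl_map = (p_score > score_thresh) * 1.0
        skeleton = thin(p_tcl_map.astype(np.uint8))
        count, labels = cv2.connectedComponents(
            skeleton.astype(np.uint8), connectivity=8)
        instances = []
        for instance_id in range(1, count):             # label 0 is background
            ys, xs = np.where(labels == instance_id)
            if len(ys) >= min_pixels:                   # eliminate outliers
                instances.append(list(zip(ys.tolist(), xs.tolist())))
        return instances

    score = np.zeros((10, 10), dtype=np.float32)
    score[4:6, 1:9] = 0.9                               # one horizontal stroke
    print(len(tcl_instances(score)))  # 1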
tools/infer/predict_e2e.py

@@ -151,7 +151,7 @@ if __name__ == "__main__":
             src_im = utility.draw_e2e_res(points, strs, image_file)
             img_name_pure = os.path.split(image_file)[-1]
             img_path = os.path.join(draw_img_save,
-                                    "e2e_res_{}_pgnet".format(img_name_pure))
+                                    "e2e_res_{}".format(img_name_pure))
             cv2.imwrite(img_path, src_im)
             logger.info("The visualized image saved in {}".format(img_path))
     if count > 1:
train_data/total_text/train/poly/2.txt
deleted (file mode 100644 → 0)

-2.0,165.0,20.0,167.0,39.0,170.0,57.0,173.0,76.0,176.0,94.0,179.0,113.0,182.0,109.0,218.0,90.0,215.0,72.0,213.0,54.0,210.0,36.0,208.0,18.0,205.0,0.0,203.0 izza
-2.0,411.0,30.0,412.0,58.0,414.0,87.0,416.0,115.0,418.0,143.0,420.0,172.0,422.0,172.0,476.0,143.0,474.0,114.0,472.0,86.0,471.0,57.0,469.0,28.0,467.0,0.0,466.0 ISA
train_data/total_text/train/rgb/2.jpg
deleted (file mode 100644 → 0, binary image, 40.9 KB)