Commit 718b8ca4 (unverified)

Merge pull request #2579 from JetHong/dy/add_eval_mode

Dy/add eval mode

Authored by MissPenguin on Apr 23, 2021; committed via GitHub on Apr 23, 2021.
Parents: fd51ba3d, a230932c

Showing 5 changed files with 237 additions and 17 deletions (+237, -17):
  configs/e2e/e2e_r50_vd_pg.yml       +7    -4
  ppocr/data/imaug/label_ops.py       +45   -1
  ppocr/data/pgnet_dataset.py         +5    -3
  ppocr/metrics/e2e_metric.py         +42   -8
  ppocr/utils/e2e_metric/Deteval.py   +138  -1
configs/e2e/e2e_r50_vd_pg.yml

@@ -60,8 +60,10 @@ PostProcess:
   name: PGPostProcess
   score_thresh: 0.5
   mode: fast   # fast or slow two ways
 
 Metric:
   name: E2EMetric
+  mode: A   # two ways for eval, A: label from txt, B: label from gt_mat
   gt_mat_dir: ./train_data/total_text/gt  # the dir of gt_mat
   character_dict_path: ppocr/utils/ic15_dict.txt
   main_indicator: f_score_e2e
@@ -70,13 +72,13 @@ Train:
   dataset:
     name: PGDataSet
     data_dir: ./train_data/total_text/train
-    label_file_list: [./train_data/total_text/train/]
+    label_file_list: [./train_data/total_text/train/total_text.txt]
     ratio_list: [1.0]
     transforms:
       - DecodeImage: # load image
           img_mode: BGR
           channel_first: False
-      - E2ELabelEncode:
+      - E2ELabelEncodeTrain:
       - PGProcessTrain:
           batch_size: 14  # same as loader: batch_size_per_card
           min_crop_size: 24
@@ -94,11 +96,12 @@ Eval:
   dataset:
     name: PGDataSet
     data_dir: ./train_data/total_text/test
-    label_file_list: [./train_data/total_text/test/]
+    label_file_list: [./train_data/total_text/test/total_text.txt]
     transforms:
       - DecodeImage: # load image
           img_mode: RGB
           channel_first: False
+      - E2ELabelEncodeTest:
       - E2EResizeForTest:
           max_side_len: 768
       - NormalizeImage:
@@ -108,7 +111,7 @@ Eval:
           order: 'hwc'
       - ToCHWImage:
       - KeepKeys:
-          keep_keys: [ 'image', 'shape', 'img_id']
+          keep_keys: [ 'image', 'shape', 'polys', 'texts', 'ignore_tags', 'img_id']
   loader:
     shuffle: False
     drop_last: False
...
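A note on the new Eval keep_keys: their order determines the positional layout of the batch that E2EMetric receives, which is why the metric hunk further down indexes batch[2] through batch[5]. The following is a minimal sketch of that mapping; the function and variable names are ours, not part of the commit.

def unpack_e2e_eval_batch(batch):
    # Illustrative only: positions follow the new keep_keys order
    # ['image', 'shape', 'polys', 'texts', 'ignore_tags', 'img_id'].
    images      = batch[0]  # 'image'
    shapes      = batch[1]  # 'shape'
    polys       = batch[2]  # 'polys'        - ground-truth polygons, used when Metric.mode == 'A'
    texts       = batch[3]  # 'texts'        - padded character indices, mode 'A'
    ignore_tags = batch[4]  # 'ignore_tags'  - mode 'A'
    img_ids     = batch[5]  # 'img_id'       - used when Metric.mode == 'B' to locate the gt .mat file
    return images, shapes, polys, texts, ignore_tags, img_ids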
ppocr/data/imaug/label_ops.py

@@ -187,7 +187,51 @@ class CTCLabelEncode(BaseRecLabelEncode):
         return dict_character
 
 
-class E2ELabelEncode(object):
+class E2ELabelEncodeTest(BaseRecLabelEncode):
+    def __init__(self,
+                 max_text_length,
+                 character_dict_path=None,
+                 character_type='EN',
+                 use_space_char=False,
+                 **kwargs):
+        super(E2ELabelEncodeTest, self).__init__(
+            max_text_length, character_dict_path, character_type,
+            use_space_char)
+
+    def __call__(self, data):
+        import json
+        padnum = len(self.dict)
+        label = data['label']
+        label = json.loads(label)
+        nBox = len(label)
+        boxes, txts, txt_tags = [], [], []
+        for bno in range(0, nBox):
+            box = label[bno]['points']
+            txt = label[bno]['transcription']
+            boxes.append(box)
+            txts.append(txt)
+            if txt in ['*', '###']:
+                txt_tags.append(True)
+            else:
+                txt_tags.append(False)
+        boxes = np.array(boxes, dtype=np.float32)
+        txt_tags = np.array(txt_tags, dtype=np.bool)
+        data['polys'] = boxes
+        data['ignore_tags'] = txt_tags
+        temp_texts = []
+        for text in txts:
+            text = text.lower()
+            text = self.encode(text)
+            if text is None:
+                return None
+            text = text + [padnum] * (self.max_text_len - len(text))  # use 36 to pad
+            temp_texts.append(text)
+        data['texts'] = np.array(temp_texts)
+        return data
+
+
+class E2ELabelEncodeTrain(object):
     def __init__(self, **kwargs):
         pass
...
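For context, the 'label' string that E2ELabelEncodeTest.__call__ parses is a JSON list of boxes, each with 'points' and 'transcription', where '*' or '###' marks a region to ignore. The standalone sketch below mirrors that parsing on a made-up label (it skips the character encoding and padding, which need the recognition dictionary); the sample boxes and strings are invented for illustration.

import json
import numpy as np

# Illustrative only: two boxes, the second marked as "don't care".
label = json.dumps([
    {"points": [[10, 10], [100, 10], [100, 40], [10, 40]], "transcription": "HELLO"},
    {"points": [[20, 60], [90, 60], [90, 90], [20, 90]], "transcription": "###"},
])

boxes, txts, txt_tags = [], [], []
for item in json.loads(label):
    boxes.append(item["points"])
    txts.append(item["transcription"])
    txt_tags.append(item["transcription"] in ["*", "###"])

polys = np.array(boxes, dtype=np.float32)     # -> data['polys']
ignore_tags = np.array(txt_tags, dtype=bool)  # -> data['ignore_tags']
print(polys.shape, txts, ignore_tags)         # (2, 4, 2) ['HELLO', '###'] [False  True]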
ppocr/data/pgnet_dataset.py

@@ -72,6 +72,7 @@ class PGDataSet(Dataset):
     def __getitem__(self, idx):
         file_idx = self.data_idx_order_list[idx]
         data_line = self.data_lines[file_idx]
+        img_id = 0
         try:
             data_line = data_line.decode('utf-8')
             substr = data_line.strip("\n").split(self.delimiter)
@@ -79,8 +80,9 @@ class PGDataSet(Dataset):
             label = substr[1]
             img_path = os.path.join(self.data_dir, file_name)
             if self.mode.lower() == 'eval':
-                img_id = int(data_line.split(".")[0][7:])
-            else:
-                img_id = 0
+                try:
+                    img_id = int(data_line.split(".")[0][7:])
+                except:
+                    img_id = 0
             data = {'img_path': img_path, 'label': label, 'img_id': img_id}
             if not os.path.exists(img_path):
...
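The reworked __getitem__ initializes img_id to 0 and parses it defensively in eval mode. A minimal sketch, assuming a Total-Text-style label line whose image path looks like rgb/img12.jpg (that naming is an assumption for illustration, not something stated in this commit):

# Illustrative only: how the new try/except derives img_id from a label line.
data_line = 'rgb/img12.jpg\t[{"transcription": "HELLO", "points": [[10, 10], [100, 10], [100, 40], [10, 40]]}]'

try:
    img_id = int(data_line.split(".")[0][7:])  # "rgb/img12" -> "12" -> 12
except Exception:
    img_id = 0  # any eval entry that does not fit the pattern now falls back to 0 instead of raising

print(img_id)  # 12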
ppocr/metrics/e2e_metric.py

@@ -18,16 +18,18 @@ from __future__ import print_function
 
 __all__ = ['E2EMetric']
 
-from ppocr.utils.e2e_metric.Deteval import get_socre, combine_results
+from ppocr.utils.e2e_metric.Deteval import get_socre_A, get_socre_B, combine_results
 from ppocr.utils.e2e_utils.extract_textpoint_slow import get_dict
 
 
 class E2EMetric(object):
     def __init__(self,
+                 mode,
                  gt_mat_dir,
                  character_dict_path,
                  main_indicator='f_score_e2e',
                  **kwargs):
+        self.mode = mode
         self.gt_mat_dir = gt_mat_dir
         self.label_list = get_dict(character_dict_path)
         self.max_index = len(self.label_list)
@@ -35,12 +37,44 @@ class E2EMetric(object):
         self.reset()
 
     def __call__(self, preds, batch, **kwargs):
-        img_id = batch[2][0]
-        e2e_info_list = [{
-            'points': det_polyon,
-            'texts': pred_str
-        } for det_polyon, pred_str in zip(preds['points'], preds['texts'])]
-        result = get_socre(self.gt_mat_dir, img_id, e2e_info_list)
-        self.results.append(result)
+        if self.mode == 'A':
+            gt_polyons_batch = batch[2]
+            temp_gt_strs_batch = batch[3][0]
+            ignore_tags_batch = batch[4]
+            gt_strs_batch = []
+
+            for temp_list in temp_gt_strs_batch:
+                t = ""
+                for index in temp_list:
+                    if index < self.max_index:
+                        t += self.label_list[index]
+                gt_strs_batch.append(t)
+
+            for pred, gt_polyons, gt_strs, ignore_tags in zip(
+                [preds], gt_polyons_batch, [gt_strs_batch], ignore_tags_batch):
+                # prepare gt
+                gt_info_list = [{
+                    'points': gt_polyon,
+                    'text': gt_str,
+                    'ignore': ignore_tag
+                } for gt_polyon, gt_str, ignore_tag in
+                                zip(gt_polyons, gt_strs, ignore_tags)]
+                # prepare det
+                e2e_info_list = [{
+                    'points': det_polyon,
+                    'texts': pred_str
+                } for det_polyon, pred_str in
+                                 zip(pred['points'], pred['texts'])]
+
+                result = get_socre_A(gt_info_list, e2e_info_list)
+                self.results.append(result)
+        else:
+            img_id = batch[5][0]
+            e2e_info_list = [{
+                'points': det_polyon,
+                'texts': pred_str
+            } for det_polyon, pred_str in zip(preds['points'], preds['texts'])]
+            result = get_socre_B(self.gt_mat_dir, img_id, e2e_info_list)
+            self.results.append(result)
 
     def get_metric(self):
...
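One detail worth spelling out in mode 'A': the texts in the batch are the padded index arrays produced by E2ELabelEncodeTest, which pads with padnum == len(self.dict), so the `index < self.max_index` check is what strips the padding before the ground-truth string is rebuilt. A standalone illustration with a stand-in dictionary (not ppocr/utils/ic15_dict.txt):

# Illustrative only: decoding a padded index sequence back to a string.
label_list = list("abcdefghijklmnopqrstuvwxyz0123456789")   # pretend dictionary
max_index = len(label_list)                                  # == the padding index

padded = [7, 4, 11, 11, 14, max_index, max_index, max_index] # "hello" plus padding
decoded = "".join(label_list[i] for i in padded if i < max_index)
print(decoded)  # hello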
ppocr/utils/e2e_metric/Deteval.py

@@ -17,7 +17,144 @@ import scipy.io as io
 from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
 
 
-def get_socre(gt_dir, img_id, pred_dict):
+def get_socre_A(gt_dir, pred_dict):
+    allInputs = 1
+
+    def input_reading_mod(pred_dict):
+        """This helper reads input from txt files"""
+        det = []
+        n = len(pred_dict)
+        for i in range(n):
+            points = pred_dict[i]['points']
+            text = pred_dict[i]['texts']
+            point = ",".join(map(str, points.reshape(-1, )))
+            det.append([point, text])
+        return det
+
+    def gt_reading_mod(gt_dict):
+        """This helper reads groundtruths from mat files"""
+        gt = []
+        n = len(gt_dict)
+        for i in range(n):
+            points = gt_dict[i]['points'].tolist()
+            h = len(points)
+            text = gt_dict[i]['text']
+            xx = [
+                np.array(['x:'], dtype='<U2'), 0,
+                np.array(['y:'], dtype='<U2'), 0,
+                np.array(['#'], dtype='<U1'),
+                np.array(['#'], dtype='<U1')
+            ]
+            t_x, t_y = [], []
+            for j in range(h):
+                t_x.append(points[j][0])
+                t_y.append(points[j][1])
+            xx[1] = np.array([t_x], dtype='int16')
+            xx[3] = np.array([t_y], dtype='int16')
+            if text != "":
+                xx[4] = np.array([text], dtype='U{}'.format(len(text)))
+                xx[5] = np.array(['c'], dtype='<U1')
+            gt.append(xx)
+        return gt
+
+    def detection_filtering(detections, groundtruths, threshold=0.5):
+        for gt_id, gt in enumerate(groundtruths):
+            if (gt[5] == '#') and (gt[1].shape[1] > 1):
+                gt_x = list(map(int, np.squeeze(gt[1])))
+                gt_y = list(map(int, np.squeeze(gt[3])))
+                for det_id, detection in enumerate(detections):
+                    detection_orig = detection
+                    detection = [float(x) for x in detection[0].split(',')]
+                    detection = list(map(int, detection))
+                    det_x = detection[0::2]
+                    det_y = detection[1::2]
+                    det_gt_iou = iod(det_x, det_y, gt_x, gt_y)
+                    if det_gt_iou > threshold:
+                        detections[det_id] = []
+
+                detections[:] = [item for item in detections if item != []]
+        return detections
+
+    def sigma_calculation(det_x, det_y, gt_x, gt_y):
+        """
+        sigma = inter_area / gt_area
+        """
+        return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
+                         area(gt_x, gt_y)), 2)
+
+    def tau_calculation(det_x, det_y, gt_x, gt_y):
+        if area(det_x, det_y) == 0.0:
+            return 0
+        return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
+                         area(det_x, det_y)), 2)
+
+    ##############################Initialization###################################
+    # global_sigma = []
+    # global_tau = []
+    # global_pred_str = []
+    # global_gt_str = []
+    ###############################################################################
+
+    for input_id in range(allInputs):
+        if (input_id != '.DS_Store') and (input_id != 'Pascal_result.txt') and (
+                input_id != 'Pascal_result_curved.txt') and (
+                    input_id != 'Pascal_result_non_curved.txt') and (
+                        input_id != 'Deteval_result.txt') and (
+                            input_id != 'Deteval_result_curved.txt') \
+                and (input_id != 'Deteval_result_non_curved.txt'):
+            detections = input_reading_mod(pred_dict)
+            groundtruths = gt_reading_mod(gt_dir)
+            detections = detection_filtering(
+                detections,
+                groundtruths)  # filters detections overlapping with DC area
+            dc_id = []
+            for i in range(len(groundtruths)):
+                if groundtruths[i][5] == '#':
+                    dc_id.append(i)
+            cnt = 0
+            for a in dc_id:
+                num = a - cnt
+                del groundtruths[num]
+                cnt += 1
+
+            local_sigma_table = np.zeros((len(groundtruths), len(detections)))
+            local_tau_table = np.zeros((len(groundtruths), len(detections)))
+            local_pred_str = {}
+            local_gt_str = {}
+
+            for gt_id, gt in enumerate(groundtruths):
+                if len(detections) > 0:
+                    for det_id, detection in enumerate(detections):
+                        detection_orig = detection
+                        detection = [float(x) for x in detection[0].split(',')]
+                        detection = list(map(int, detection))
+                        pred_seq_str = detection_orig[1].strip()
+                        det_x = detection[0::2]
+                        det_y = detection[1::2]
+                        gt_x = list(map(int, np.squeeze(gt[1])))
+                        gt_y = list(map(int, np.squeeze(gt[3])))
+                        gt_seq_str = str(gt[4].tolist()[0])
+
+                        local_sigma_table[gt_id, det_id] = sigma_calculation(
+                            det_x, det_y, gt_x, gt_y)
+                        local_tau_table[gt_id, det_id] = tau_calculation(
+                            det_x, det_y, gt_x, gt_y)
+                        local_pred_str[det_id] = pred_seq_str
+                        local_gt_str[gt_id] = gt_seq_str
+
+            global_sigma = local_sigma_table
+            global_tau = local_tau_table
+            global_pred_str = local_pred_str
+            global_gt_str = local_gt_str
+
+    single_data = {}
+    single_data['sigma'] = global_sigma
+    single_data['global_tau'] = global_tau
+    single_data['global_pred_str'] = global_pred_str
+    single_data['global_gt_str'] = global_gt_str
+    return single_data
+
+
+def get_socre_B(gt_dir, img_id, pred_dict):
     allInputs = 1
 
     def input_reading_mod(pred_dict):
...
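get_socre_A scores a single image by filling local sigma and tau tables, where sigma = intersection_area / gt_area and tau = intersection_area / det_area (see sigma_calculation and tau_calculation above). A self-contained illustration on two axis-aligned rectangles; the real code evaluates arbitrary polygons via ppocr.utils.e2e_metric.polygon_fast, and the boxes below are made up for the example:

# Illustrative only: sigma and tau for two axis-aligned boxes.
def rect_area(x0, y0, x1, y1):
    return max(0.0, x1 - x0) * max(0.0, y1 - y0)

def rect_intersection(a, b):
    x0, y0 = max(a[0], b[0]), max(a[1], b[1])
    x1, y1 = min(a[2], b[2]), min(a[3], b[3])
    return rect_area(x0, y0, x1, y1)

gt  = (0, 0, 100, 40)    # ground-truth box
det = (10, 0, 120, 40)   # detected box

inter = rect_intersection(gt, det)
sigma = round(inter / rect_area(*gt), 2)   # recall-like: how much of the gt is covered
tau   = round(inter / rect_area(*det), 2)  # precision-like: how much of the det is gt
print(sigma, tau)  # 0.9 0.82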