Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
713ceb4e
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
713ceb4e
编写于
4月 16, 2021
作者:
M
MissPenguin
提交者:
GitHub
4月 16, 2021
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #2498 from JetHong/dy/fix_data_format
Dy/fix data format cherry pick
上级
c5f33b00
7607b570
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
66 addition
and
131 deletion
+66
-131
configs/e2e/e2e_r50_vd_pg.yml
configs/e2e/e2e_r50_vd_pg.yml
+6
-6
ppocr/data/imaug/label_ops.py
ppocr/data/imaug/label_ops.py
+23
-21
ppocr/data/imaug/pg_process.py
ppocr/data/imaug/pg_process.py
+4
-4
ppocr/data/pgnet_dataset.py
ppocr/data/pgnet_dataset.py
+24
-92
ppocr/metrics/e2e_metric.py
ppocr/metrics/e2e_metric.py
+3
-3
ppocr/utils/e2e_metric/Deteval.py
ppocr/utils/e2e_metric/Deteval.py
+1
-1
ppocr/utils/e2e_utils/extract_textpoint_fast.py
ppocr/utils/e2e_utils/extract_textpoint_fast.py
+1
-0
ppocr/utils/e2e_utils/pgnet_pp_utils.py
ppocr/utils/e2e_utils/pgnet_pp_utils.py
+2
-2
tools/infer/predict_e2e.py
tools/infer/predict_e2e.py
+1
-1
tools/infer_e2e.py
tools/infer_e2e.py
+1
-1
未找到文件。
configs/e2e/e2e_r50_vd_pg.yml
浏览文件 @
713ceb4e
...
...
@@ -62,20 +62,21 @@ PostProcess:
mode
:
fast
# fast or slow two ways
Metric
:
name
:
E2EMetric
gt_mat_dir
:
# the dir of gt_mat
gt_mat_dir
:
./train_data/total_text/gt
# the dir of gt_mat
character_dict_path
:
ppocr/utils/ic15_dict.txt
main_indicator
:
f_score_e2e
Train
:
dataset
:
name
:
PGDataSet
label_file_list
:
[
.././train_data/total_text/train/
]
data_dir
:
./train_data/total_text/train
label_file_list
:
[
./train_data/total_text/train/
]
ratio_list
:
[
1.0
]
data_format
:
icdar
#two data format: icdar/textnet
transforms
:
-
DecodeImage
:
# load image
img_mode
:
BGR
channel_first
:
False
-
E2ELabelEncode
:
-
PGProcessTrain
:
batch_size
:
14
# same as loader: batch_size_per_card
min_crop_size
:
24
...
...
@@ -92,13 +93,12 @@ Train:
Eval
:
dataset
:
name
:
PGDataSet
data_dir
:
./train_data/
data_dir
:
./train_data/
total_text/test
label_file_list
:
[
./train_data/total_text/test/
]
transforms
:
-
DecodeImage
:
# load image
img_mode
:
RGB
channel_first
:
False
-
E2ELabelEncode
:
-
E2EResizeForTest
:
max_side_len
:
768
-
NormalizeImage
:
...
...
@@ -108,7 +108,7 @@ Eval:
order
:
'
hwc'
-
ToCHWImage
:
-
KeepKeys
:
keep_keys
:
[
'
image'
,
'
shape'
,
'
polys'
,
'
strs'
,
'
tags'
,
'
img_id'
]
keep_keys
:
[
'
image'
,
'
shape'
,
'
img_id'
]
loader
:
shuffle
:
False
drop_last
:
False
...
...
ppocr/data/imaug/label_ops.py
浏览文件 @
713ceb4e
...
...
@@ -187,29 +187,31 @@ class CTCLabelEncode(BaseRecLabelEncode):
return
dict_character
class
E2ELabelEncode
(
BaseRecLabelEncode
):
def
__init__
(
self
,
max_text_length
,
character_dict_path
=
None
,
character_type
=
'EN'
,
use_space_char
=
False
,
**
kwargs
):
super
(
E2ELabelEncode
,
self
).
__init__
(
max_text_length
,
character_dict_path
,
character_type
,
use_space_char
)
self
.
pad_num
=
len
(
self
.
dict
)
# the length to pad
class
E2ELabelEncode
(
object
):
def
__init__
(
self
,
**
kwargs
):
pass
def
__call__
(
self
,
data
):
texts
=
data
[
'strs'
]
temp_texts
=
[]
for
text
in
texts
:
text
=
text
.
lower
()
text
=
self
.
encode
(
text
)
if
text
is
None
:
return
None
text
=
text
+
[
self
.
pad_num
]
*
(
self
.
max_text_len
-
len
(
text
))
temp_texts
.
append
(
text
)
data
[
'strs'
]
=
np
.
array
(
temp_texts
)
import
json
label
=
data
[
'label'
]
label
=
json
.
loads
(
label
)
nBox
=
len
(
label
)
boxes
,
txts
,
txt_tags
=
[],
[],
[]
for
bno
in
range
(
0
,
nBox
):
box
=
label
[
bno
][
'points'
]
txt
=
label
[
bno
][
'transcription'
]
boxes
.
append
(
box
)
txts
.
append
(
txt
)
if
txt
in
[
'*'
,
'###'
]:
txt_tags
.
append
(
True
)
else
:
txt_tags
.
append
(
False
)
boxes
=
np
.
array
(
boxes
,
dtype
=
np
.
float32
)
txt_tags
=
np
.
array
(
txt_tags
,
dtype
=
np
.
bool
)
data
[
'polys'
]
=
boxes
data
[
'texts'
]
=
txts
data
[
'ignore_tags'
]
=
txt_tags
return
data
...
...
ppocr/data/imaug/pg_process.py
浏览文件 @
713ceb4e
...
...
@@ -88,7 +88,7 @@ class PGProcessTrain(object):
return
min_area_quad
def
check_and_validate_polys
(
self
,
polys
,
tags
,
xxx_todo_changem
e
):
def
check_and_validate_polys
(
self
,
polys
,
tags
,
im_siz
e
):
"""
check so that the text poly is in the same direction,
and also filter some invalid polygons
...
...
@@ -96,7 +96,7 @@ class PGProcessTrain(object):
:param tags:
:return:
"""
(
h
,
w
)
=
xxx_todo_changem
e
(
h
,
w
)
=
im_siz
e
if
polys
.
shape
[
0
]
==
0
:
return
polys
,
np
.
array
([]),
np
.
array
([])
polys
[:,
:,
0
]
=
np
.
clip
(
polys
[:,
:,
0
],
0
,
w
-
1
)
...
...
@@ -750,8 +750,8 @@ class PGProcessTrain(object):
input_size
=
512
im
=
data
[
'image'
]
text_polys
=
data
[
'polys'
]
text_tags
=
data
[
'tags'
]
text_strs
=
data
[
'
str
s'
]
text_tags
=
data
[
'
ignore_
tags'
]
text_strs
=
data
[
'
text
s'
]
h
,
w
,
_
=
im
.
shape
text_polys
,
text_tags
,
hv_tags
=
self
.
check_and_validate_polys
(
text_polys
,
text_tags
,
(
h
,
w
))
...
...
ppocr/data/pgnet_dataset.py
浏览文件 @
713ceb4e
...
...
@@ -29,20 +29,20 @@ class PGDataSet(Dataset):
dataset_config
=
config
[
mode
][
'dataset'
]
loader_config
=
config
[
mode
][
'loader'
]
self
.
delimiter
=
dataset_config
.
get
(
'delimiter'
,
'
\t
'
)
label_file_list
=
dataset_config
.
pop
(
'label_file_list'
)
data_source_num
=
len
(
label_file_list
)
ratio_list
=
dataset_config
.
get
(
"ratio_list"
,
[
1.0
])
if
isinstance
(
ratio_list
,
(
float
,
int
)):
ratio_list
=
[
float
(
ratio_list
)]
*
int
(
data_source_num
)
self
.
data_format
=
dataset_config
.
get
(
'data_format'
,
'icdar'
)
assert
len
(
ratio_list
)
==
data_source_num
,
"The length of ratio_list should be the same as the file_list."
self
.
data_dir
=
dataset_config
[
'data_dir'
]
self
.
do_shuffle
=
loader_config
[
'shuffle'
]
logger
.
info
(
"Initialize indexs of datasets:%s"
%
label_file_list
)
self
.
data_lines
=
self
.
get_image_info_list
(
label_file_list
,
ratio_list
,
self
.
data_format
)
self
.
data_lines
=
self
.
get_image_info_list
(
label_file_list
,
ratio_list
)
self
.
data_idx_order_list
=
list
(
range
(
len
(
self
.
data_lines
)))
if
mode
.
lower
()
==
"train"
:
self
.
shuffle_data_random
()
...
...
@@ -55,108 +55,40 @@ class PGDataSet(Dataset):
random
.
shuffle
(
self
.
data_lines
)
return
def
extract_polys
(
self
,
poly_txt_path
):
"""
Read text_polys, txt_tags, txts from give txt file.
"""
text_polys
,
txt_tags
,
txts
=
[],
[],
[]
with
open
(
poly_txt_path
)
as
f
:
for
line
in
f
.
readlines
():
poly_str
,
txt
=
line
.
strip
().
split
(
'
\t
'
)
poly
=
list
(
map
(
float
,
poly_str
.
split
(
','
)))
text_polys
.
append
(
np
.
array
(
poly
,
dtype
=
np
.
float32
).
reshape
(
-
1
,
2
))
txts
.
append
(
txt
)
txt_tags
.
append
(
txt
==
'###'
)
return
np
.
array
(
list
(
map
(
np
.
array
,
text_polys
))),
\
np
.
array
(
txt_tags
,
dtype
=
np
.
bool
),
txts
def
extract_info_textnet
(
self
,
im_fn
,
img_dir
=
''
):
"""
Extract information from line in textnet format.
"""
info_list
=
im_fn
.
split
(
'
\t
'
)
img_path
=
''
for
ext
in
[
'jpg'
,
'bmp'
,
'png'
,
'jpeg'
,
'rgb'
,
'tif'
,
'tiff'
,
'gif'
,
'JPG'
]:
if
os
.
path
.
exists
(
os
.
path
.
join
(
img_dir
,
info_list
[
0
]
+
"."
+
ext
)):
img_path
=
os
.
path
.
join
(
img_dir
,
info_list
[
0
]
+
"."
+
ext
)
break
if
img_path
==
''
:
print
(
'Image {0} NOT found in {1}, and it will be ignored.'
.
format
(
info_list
[
0
],
img_dir
))
nBox
=
(
len
(
info_list
)
-
1
)
//
9
wordBBs
,
txts
,
txt_tags
=
[],
[],
[]
for
n
in
range
(
0
,
nBox
):
wordBB
=
list
(
map
(
float
,
info_list
[
n
*
9
+
1
:(
n
+
1
)
*
9
]))
txt
=
info_list
[(
n
+
1
)
*
9
]
wordBBs
.
append
([[
wordBB
[
0
],
wordBB
[
1
]],
[
wordBB
[
2
],
wordBB
[
3
]],
[
wordBB
[
4
],
wordBB
[
5
]],
[
wordBB
[
6
],
wordBB
[
7
]]])
txts
.
append
(
txt
)
if
txt
==
'###'
:
txt_tags
.
append
(
True
)
else
:
txt_tags
.
append
(
False
)
return
img_path
,
np
.
array
(
wordBBs
,
dtype
=
np
.
float32
),
txt_tags
,
txts
def
get_image_info_list
(
self
,
file_list
,
ratio_list
,
data_format
=
'textnet'
):
def
get_image_info_list
(
self
,
file_list
,
ratio_list
):
if
isinstance
(
file_list
,
str
):
file_list
=
[
file_list
]
data_lines
=
[]
for
idx
,
data_source
in
enumerate
(
file_list
):
image_files
=
[]
if
data_format
==
'icdar'
:
image_files
=
[(
data_source
,
x
)
for
x
in
os
.
listdir
(
os
.
path
.
join
(
data_source
,
'rgb'
))
if
x
.
split
(
'.'
)[
-
1
]
in
[
'jpg'
,
'bmp'
,
'png'
,
'jpeg'
,
'rgb'
,
'tif'
,
'tiff'
,
'gif'
,
'JPG'
]]
elif
data_format
==
'textnet'
:
with
open
(
data_source
)
as
f
:
image_files
=
[(
data_source
,
x
.
strip
())
for
x
in
f
.
readlines
()]
else
:
print
(
"Unrecognized data format..."
)
exit
(
-
1
)
random
.
seed
(
self
.
seed
)
image_files
=
random
.
sample
(
image_files
,
round
(
len
(
image_files
)
*
ratio_list
[
idx
]))
data_lines
.
extend
(
image_files
)
for
idx
,
file
in
enumerate
(
file_list
):
with
open
(
file
,
"rb"
)
as
f
:
lines
=
f
.
readlines
()
if
self
.
mode
==
"train"
or
ratio_list
[
idx
]
<
1.0
:
random
.
seed
(
self
.
seed
)
lines
=
random
.
sample
(
lines
,
round
(
len
(
lines
)
*
ratio_list
[
idx
]))
data_lines
.
extend
(
lines
)
return
data_lines
def
__getitem__
(
self
,
idx
):
file_idx
=
self
.
data_idx_order_list
[
idx
]
data_
path
,
data_
line
=
self
.
data_lines
[
file_idx
]
data_line
=
self
.
data_lines
[
file_idx
]
try
:
if
self
.
data_format
==
'icdar'
:
im_path
=
os
.
path
.
join
(
data_path
,
'rgb'
,
data_line
)
poly_path
=
os
.
path
.
join
(
data_path
,
'poly'
,
data_line
.
split
(
'.'
)[
0
]
+
'.txt'
)
text_polys
,
text_tags
,
text_strs
=
self
.
extract_polys
(
poly_path
)
data_line
=
data_line
.
decode
(
'utf-8'
)
substr
=
data_line
.
strip
(
"
\n
"
).
split
(
self
.
delimiter
)
file_name
=
substr
[
0
]
label
=
substr
[
1
]
img_path
=
os
.
path
.
join
(
self
.
data_dir
,
file_name
)
if
self
.
mode
.
lower
()
==
'eval'
:
img_id
=
int
(
data_line
.
split
(
"."
)[
0
][
7
:])
else
:
image_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
data_path
),
'image'
)
im_path
,
text_polys
,
text_tags
,
text_strs
=
self
.
extract_info_textnet
(
data_line
,
image_dir
)
img_id
=
int
(
data_line
.
split
(
"."
)[
0
][
3
:])
data
=
{
'img_path'
:
im_path
,
'polys'
:
text_polys
,
'tags'
:
text_tags
,
'strs'
:
text_strs
,
'img_id'
:
img_id
}
img_id
=
0
data
=
{
'img_path'
:
img_path
,
'label'
:
label
,
'img_id'
:
img_id
}
if
not
os
.
path
.
exists
(
img_path
):
raise
Exception
(
"{} does not exist!"
.
format
(
img_path
))
with
open
(
data
[
'img_path'
],
'rb'
)
as
f
:
img
=
f
.
read
()
data
[
'image'
]
=
img
outs
=
transform
(
data
,
self
.
ops
)
except
Exception
as
e
:
self
.
logger
.
error
(
"When parsing line {}, error happened with msg: {}"
.
format
(
...
...
ppocr/metrics/e2e_metric.py
浏览文件 @
713ceb4e
...
...
@@ -35,11 +35,11 @@ class E2EMetric(object):
self
.
reset
()
def
__call__
(
self
,
preds
,
batch
,
**
kwargs
):
img_id
=
batch
[
5
][
0
]
img_id
=
batch
[
2
][
0
]
e2e_info_list
=
[{
'points'
:
det_polyon
,
'text'
:
pred_str
}
for
det_polyon
,
pred_str
in
zip
(
preds
[
'points'
],
preds
[
'
str
s'
])]
'text
s
'
:
pred_str
}
for
det_polyon
,
pred_str
in
zip
(
preds
[
'points'
],
preds
[
'
text
s'
])]
result
=
get_socre
(
self
.
gt_mat_dir
,
img_id
,
e2e_info_list
)
self
.
results
.
append
(
result
)
...
...
ppocr/utils/e2e_metric/Deteval.py
浏览文件 @
713ceb4e
...
...
@@ -26,7 +26,7 @@ def get_socre(gt_dir, img_id, pred_dict):
n
=
len
(
pred_dict
)
for
i
in
range
(
n
):
points
=
pred_dict
[
i
][
'points'
]
text
=
pred_dict
[
i
][
'text'
]
text
=
pred_dict
[
i
][
'text
s
'
]
point
=
","
.
join
(
map
(
str
,
points
.
reshape
(
-
1
,
)))
det
.
append
([
point
,
text
])
return
det
...
...
ppocr/utils/e2e_utils/extract_textpoint_fast.py
浏览文件 @
713ceb4e
...
...
@@ -21,6 +21,7 @@ import math
import
numpy
as
np
from
itertools
import
groupby
from
cv2.ximgproc
import
thinning
as
thin
from
skimage.morphology._skeletonize
import
thin
...
...
ppocr/utils/e2e_utils/pgnet_pp_utils.py
浏览文件 @
713ceb4e
...
...
@@ -64,7 +64,7 @@ class PGNet_PostProcess(object):
src_w
,
src_h
,
self
.
valid_set
)
data
=
{
'points'
:
poly_list
,
'
str
s'
:
keep_str_list
,
'
text
s'
:
keep_str_list
,
}
return
data
...
...
@@ -176,6 +176,6 @@ class PGNet_PostProcess(object):
exit
(
-
1
)
data
=
{
'points'
:
poly_list
,
'
str
s'
:
keep_str_list
,
'
text
s'
:
keep_str_list
,
}
return
data
tools/infer/predict_e2e.py
浏览文件 @
713ceb4e
...
...
@@ -122,7 +122,7 @@ class TextE2E(object):
else
:
raise
NotImplementedError
post_result
=
self
.
postprocess_op
(
preds
,
shape_list
)
points
,
strs
=
post_result
[
'points'
],
post_result
[
'
str
s'
]
points
,
strs
=
post_result
[
'points'
],
post_result
[
'
text
s'
]
dt_boxes
=
self
.
filter_tag_det_res_only_clip
(
points
,
ori_im
.
shape
)
elapse
=
time
.
time
()
-
starttime
return
dt_boxes
,
strs
,
elapse
...
...
tools/infer_e2e.py
浏览文件 @
713ceb4e
...
...
@@ -103,7 +103,7 @@ def main():
images
=
paddle
.
to_tensor
(
images
)
preds
=
model
(
images
)
post_result
=
post_process_class
(
preds
,
shape_list
)
points
,
strs
=
post_result
[
'points'
],
post_result
[
'
str
s'
]
points
,
strs
=
post_result
[
'points'
],
post_result
[
'
text
s'
]
# write resule
dt_boxes_json
=
[]
for
poly
,
str
in
zip
(
points
,
strs
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录