Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
50bcec46
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
50bcec46
编写于
4月 13, 2021
作者:
J
Jethong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix data input format
上级
579b66b0
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
58 addition
and
125 deletion
+58
-125
configs/e2e/e2e_r50_vd_pg.yml
configs/e2e/e2e_r50_vd_pg.yml
+3
-2
ppocr/data/imaug/label_ops.py
ppocr/data/imaug/label_ops.py
+23
-21
ppocr/data/imaug/pg_process.py
ppocr/data/imaug/pg_process.py
+4
-4
ppocr/data/pgnet_dataset.py
ppocr/data/pgnet_dataset.py
+21
-92
ppocr/metrics/e2e_metric.py
ppocr/metrics/e2e_metric.py
+3
-3
ppocr/utils/e2e_metric/Deteval.py
ppocr/utils/e2e_metric/Deteval.py
+1
-1
ppocr/utils/e2e_utils/extract_textpoint_fast.py
ppocr/utils/e2e_utils/extract_textpoint_fast.py
+1
-0
ppocr/utils/e2e_utils/pgnet_pp_utils.py
ppocr/utils/e2e_utils/pgnet_pp_utils.py
+2
-2
未找到文件。
configs/e2e/e2e_r50_vd_pg.yml
浏览文件 @
50bcec46
...
@@ -69,6 +69,7 @@ Metric:
...
@@ -69,6 +69,7 @@ Metric:
Train
:
Train
:
dataset
:
dataset
:
name
:
PGDataSet
name
:
PGDataSet
data_dir
:
./train_data/
label_file_list
:
[
.././train_data/total_text/train/
]
label_file_list
:
[
.././train_data/total_text/train/
]
ratio_list
:
[
1.0
]
ratio_list
:
[
1.0
]
data_format
:
icdar
#two data format: icdar/textnet
data_format
:
icdar
#two data format: icdar/textnet
...
@@ -76,6 +77,7 @@ Train:
...
@@ -76,6 +77,7 @@ Train:
-
DecodeImage
:
# load image
-
DecodeImage
:
# load image
img_mode
:
BGR
img_mode
:
BGR
channel_first
:
False
channel_first
:
False
-
E2ELabelEncode
:
-
PGProcessTrain
:
-
PGProcessTrain
:
batch_size
:
14
# same as loader: batch_size_per_card
batch_size
:
14
# same as loader: batch_size_per_card
min_crop_size
:
24
min_crop_size
:
24
...
@@ -98,7 +100,6 @@ Eval:
...
@@ -98,7 +100,6 @@ Eval:
-
DecodeImage
:
# load image
-
DecodeImage
:
# load image
img_mode
:
RGB
img_mode
:
RGB
channel_first
:
False
channel_first
:
False
-
E2ELabelEncode
:
-
E2EResizeForTest
:
-
E2EResizeForTest
:
max_side_len
:
768
max_side_len
:
768
-
NormalizeImage
:
-
NormalizeImage
:
...
@@ -108,7 +109,7 @@ Eval:
...
@@ -108,7 +109,7 @@ Eval:
order
:
'
hwc'
order
:
'
hwc'
-
ToCHWImage
:
-
ToCHWImage
:
-
KeepKeys
:
-
KeepKeys
:
keep_keys
:
[
'
image'
,
'
shape'
,
'
polys'
,
'
strs'
,
'
tags'
,
'
img_id'
]
keep_keys
:
[
'
image'
,
'
shape'
,
'
img_id'
]
loader
:
loader
:
shuffle
:
False
shuffle
:
False
drop_last
:
False
drop_last
:
False
...
...
ppocr/data/imaug/label_ops.py
浏览文件 @
50bcec46
...
@@ -187,29 +187,31 @@ class CTCLabelEncode(BaseRecLabelEncode):
...
@@ -187,29 +187,31 @@ class CTCLabelEncode(BaseRecLabelEncode):
return
dict_character
return
dict_character
class
E2ELabelEncode
(
BaseRecLabelEncode
):
class
E2ELabelEncode
(
object
):
def
__init__
(
self
,
def
__init__
(
self
,
**
kwargs
):
max_text_length
,
pass
character_dict_path
=
None
,
character_type
=
'EN'
,
use_space_char
=
False
,
**
kwargs
):
super
(
E2ELabelEncode
,
self
).
__init__
(
max_text_length
,
character_dict_path
,
character_type
,
use_space_char
)
self
.
pad_num
=
len
(
self
.
dict
)
# the length to pad
def
__call__
(
self
,
data
):
def
__call__
(
self
,
data
):
texts
=
data
[
'strs'
]
import
json
temp_texts
=
[]
label
=
data
[
'label'
]
for
text
in
texts
:
label
=
json
.
loads
(
label
)
text
=
text
.
lower
()
nBox
=
len
(
label
)
text
=
self
.
encode
(
text
)
boxes
,
txts
,
txt_tags
=
[],
[],
[]
if
text
is
None
:
for
bno
in
range
(
0
,
nBox
):
return
None
box
=
label
[
bno
][
'points'
]
text
=
text
+
[
self
.
pad_num
]
*
(
self
.
max_text_len
-
len
(
text
))
txt
=
label
[
bno
][
'transcription'
]
temp_texts
.
append
(
text
)
boxes
.
append
(
box
)
data
[
'strs'
]
=
np
.
array
(
temp_texts
)
txts
.
append
(
txt
)
if
txt
in
[
'*'
,
'###'
]:
txt_tags
.
append
(
True
)
else
:
txt_tags
.
append
(
False
)
boxes
=
np
.
array
(
boxes
,
dtype
=
np
.
float32
)
txt_tags
=
np
.
array
(
txt_tags
,
dtype
=
np
.
bool
)
data
[
'polys'
]
=
boxes
data
[
'texts'
]
=
txts
data
[
'ignore_tags'
]
=
txt_tags
return
data
return
data
...
...
ppocr/data/imaug/pg_process.py
浏览文件 @
50bcec46
...
@@ -88,7 +88,7 @@ class PGProcessTrain(object):
...
@@ -88,7 +88,7 @@ class PGProcessTrain(object):
return
min_area_quad
return
min_area_quad
def
check_and_validate_polys
(
self
,
polys
,
tags
,
xxx_todo_changem
e
):
def
check_and_validate_polys
(
self
,
polys
,
tags
,
im_siz
e
):
"""
"""
check so that the text poly is in the same direction,
check so that the text poly is in the same direction,
and also filter some invalid polygons
and also filter some invalid polygons
...
@@ -96,7 +96,7 @@ class PGProcessTrain(object):
...
@@ -96,7 +96,7 @@ class PGProcessTrain(object):
:param tags:
:param tags:
:return:
:return:
"""
"""
(
h
,
w
)
=
xxx_todo_changem
e
(
h
,
w
)
=
im_siz
e
if
polys
.
shape
[
0
]
==
0
:
if
polys
.
shape
[
0
]
==
0
:
return
polys
,
np
.
array
([]),
np
.
array
([])
return
polys
,
np
.
array
([]),
np
.
array
([])
polys
[:,
:,
0
]
=
np
.
clip
(
polys
[:,
:,
0
],
0
,
w
-
1
)
polys
[:,
:,
0
]
=
np
.
clip
(
polys
[:,
:,
0
],
0
,
w
-
1
)
...
@@ -750,8 +750,8 @@ class PGProcessTrain(object):
...
@@ -750,8 +750,8 @@ class PGProcessTrain(object):
input_size
=
512
input_size
=
512
im
=
data
[
'image'
]
im
=
data
[
'image'
]
text_polys
=
data
[
'polys'
]
text_polys
=
data
[
'polys'
]
text_tags
=
data
[
'tags'
]
text_tags
=
data
[
'
ignore_
tags'
]
text_strs
=
data
[
'
str
s'
]
text_strs
=
data
[
'
text
s'
]
h
,
w
,
_
=
im
.
shape
h
,
w
,
_
=
im
.
shape
text_polys
,
text_tags
,
hv_tags
=
self
.
check_and_validate_polys
(
text_polys
,
text_tags
,
hv_tags
=
self
.
check_and_validate_polys
(
text_polys
,
text_tags
,
(
h
,
w
))
text_polys
,
text_tags
,
(
h
,
w
))
...
...
ppocr/data/pgnet_dataset.py
浏览文件 @
50bcec46
...
@@ -29,20 +29,20 @@ class PGDataSet(Dataset):
...
@@ -29,20 +29,20 @@ class PGDataSet(Dataset):
dataset_config
=
config
[
mode
][
'dataset'
]
dataset_config
=
config
[
mode
][
'dataset'
]
loader_config
=
config
[
mode
][
'loader'
]
loader_config
=
config
[
mode
][
'loader'
]
self
.
delimiter
=
dataset_config
.
get
(
'delimiter'
,
'
\t
'
)
label_file_list
=
dataset_config
.
pop
(
'label_file_list'
)
label_file_list
=
dataset_config
.
pop
(
'label_file_list'
)
data_source_num
=
len
(
label_file_list
)
data_source_num
=
len
(
label_file_list
)
ratio_list
=
dataset_config
.
get
(
"ratio_list"
,
[
1.0
])
ratio_list
=
dataset_config
.
get
(
"ratio_list"
,
[
1.0
])
if
isinstance
(
ratio_list
,
(
float
,
int
)):
if
isinstance
(
ratio_list
,
(
float
,
int
)):
ratio_list
=
[
float
(
ratio_list
)]
*
int
(
data_source_num
)
ratio_list
=
[
float
(
ratio_list
)]
*
int
(
data_source_num
)
self
.
data_format
=
dataset_config
.
get
(
'data_format'
,
'icdar'
)
assert
len
(
assert
len
(
ratio_list
ratio_list
)
==
data_source_num
,
"The length of ratio_list should be the same as the file_list."
)
==
data_source_num
,
"The length of ratio_list should be the same as the file_list."
self
.
data_dir
=
dataset_config
[
'data_dir'
]
self
.
do_shuffle
=
loader_config
[
'shuffle'
]
self
.
do_shuffle
=
loader_config
[
'shuffle'
]
logger
.
info
(
"Initialize indexs of datasets:%s"
%
label_file_list
)
logger
.
info
(
"Initialize indexs of datasets:%s"
%
label_file_list
)
self
.
data_lines
=
self
.
get_image_info_list
(
label_file_list
,
ratio_list
,
self
.
data_lines
=
self
.
get_image_info_list
(
label_file_list
,
ratio_list
)
self
.
data_format
)
self
.
data_idx_order_list
=
list
(
range
(
len
(
self
.
data_lines
)))
self
.
data_idx_order_list
=
list
(
range
(
len
(
self
.
data_lines
)))
if
mode
.
lower
()
==
"train"
:
if
mode
.
lower
()
==
"train"
:
self
.
shuffle_data_random
()
self
.
shuffle_data_random
()
...
@@ -55,108 +55,37 @@ class PGDataSet(Dataset):
...
@@ -55,108 +55,37 @@ class PGDataSet(Dataset):
random
.
shuffle
(
self
.
data_lines
)
random
.
shuffle
(
self
.
data_lines
)
return
return
def
extract_polys
(
self
,
poly_txt_path
):
def
get_image_info_list
(
self
,
file_list
,
ratio_list
):
"""
Read text_polys, txt_tags, txts from give txt file.
"""
text_polys
,
txt_tags
,
txts
=
[],
[],
[]
with
open
(
poly_txt_path
)
as
f
:
for
line
in
f
.
readlines
():
poly_str
,
txt
=
line
.
strip
().
split
(
'
\t
'
)
poly
=
list
(
map
(
float
,
poly_str
.
split
(
','
)))
text_polys
.
append
(
np
.
array
(
poly
,
dtype
=
np
.
float32
).
reshape
(
-
1
,
2
))
txts
.
append
(
txt
)
txt_tags
.
append
(
txt
==
'###'
)
return
np
.
array
(
list
(
map
(
np
.
array
,
text_polys
))),
\
np
.
array
(
txt_tags
,
dtype
=
np
.
bool
),
txts
def
extract_info_textnet
(
self
,
im_fn
,
img_dir
=
''
):
"""
Extract information from line in textnet format.
"""
info_list
=
im_fn
.
split
(
'
\t
'
)
img_path
=
''
for
ext
in
[
'jpg'
,
'bmp'
,
'png'
,
'jpeg'
,
'rgb'
,
'tif'
,
'tiff'
,
'gif'
,
'JPG'
]:
if
os
.
path
.
exists
(
os
.
path
.
join
(
img_dir
,
info_list
[
0
]
+
"."
+
ext
)):
img_path
=
os
.
path
.
join
(
img_dir
,
info_list
[
0
]
+
"."
+
ext
)
break
if
img_path
==
''
:
print
(
'Image {0} NOT found in {1}, and it will be ignored.'
.
format
(
info_list
[
0
],
img_dir
))
nBox
=
(
len
(
info_list
)
-
1
)
//
9
wordBBs
,
txts
,
txt_tags
=
[],
[],
[]
for
n
in
range
(
0
,
nBox
):
wordBB
=
list
(
map
(
float
,
info_list
[
n
*
9
+
1
:(
n
+
1
)
*
9
]))
txt
=
info_list
[(
n
+
1
)
*
9
]
wordBBs
.
append
([[
wordBB
[
0
],
wordBB
[
1
]],
[
wordBB
[
2
],
wordBB
[
3
]],
[
wordBB
[
4
],
wordBB
[
5
]],
[
wordBB
[
6
],
wordBB
[
7
]]])
txts
.
append
(
txt
)
if
txt
==
'###'
:
txt_tags
.
append
(
True
)
else
:
txt_tags
.
append
(
False
)
return
img_path
,
np
.
array
(
wordBBs
,
dtype
=
np
.
float32
),
txt_tags
,
txts
def
get_image_info_list
(
self
,
file_list
,
ratio_list
,
data_format
=
'textnet'
):
if
isinstance
(
file_list
,
str
):
if
isinstance
(
file_list
,
str
):
file_list
=
[
file_list
]
file_list
=
[
file_list
]
data_lines
=
[]
data_lines
=
[]
for
idx
,
data_source
in
enumerate
(
file_list
):
for
idx
,
file
in
enumerate
(
file_list
):
image_files
=
[]
with
open
(
file
,
"rb"
)
as
f
:
if
data_format
==
'icdar'
:
lines
=
f
.
readlines
()
image_files
=
[(
data_source
,
x
)
for
x
in
if
self
.
mode
==
"train"
or
ratio_list
[
idx
]
<
1.0
:
os
.
listdir
(
os
.
path
.
join
(
data_source
,
'rgb'
))
random
.
seed
(
self
.
seed
)
if
x
.
split
(
'.'
)[
-
1
]
in
[
lines
=
random
.
sample
(
lines
,
'jpg'
,
'bmp'
,
'png'
,
'jpeg'
,
'rgb'
,
'tif'
,
round
(
len
(
lines
)
*
ratio_list
[
idx
]))
'tiff'
,
'gif'
,
'JPG'
data_lines
.
extend
(
lines
)
]]
elif
data_format
==
'textnet'
:
with
open
(
data_source
)
as
f
:
image_files
=
[(
data_source
,
x
.
strip
())
for
x
in
f
.
readlines
()]
else
:
print
(
"Unrecognized data format..."
)
exit
(
-
1
)
random
.
seed
(
self
.
seed
)
image_files
=
random
.
sample
(
image_files
,
round
(
len
(
image_files
)
*
ratio_list
[
idx
]))
data_lines
.
extend
(
image_files
)
return
data_lines
return
data_lines
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
file_idx
=
self
.
data_idx_order_list
[
idx
]
file_idx
=
self
.
data_idx_order_list
[
idx
]
data_
path
,
data_
line
=
self
.
data_lines
[
file_idx
]
data_line
=
self
.
data_lines
[
file_idx
]
try
:
try
:
if
self
.
data_format
==
'icdar'
:
data_line
=
data_line
.
decode
(
'utf-8'
)
im_path
=
os
.
path
.
join
(
data_path
,
'rgb'
,
data_line
)
substr
=
data_line
.
strip
(
"
\n
"
).
split
(
self
.
delimiter
)
poly_path
=
os
.
path
.
join
(
data_path
,
'poly'
,
file_name
=
substr
[
0
]
data_line
.
split
(
'.'
)[
0
]
+
'.txt'
)
label
=
substr
[
1
]
text_polys
,
text_tags
,
text_strs
=
self
.
extract_polys
(
poly_path
)
img_path
=
os
.
path
.
join
(
self
.
data_dir
,
file_name
)
else
:
image_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
data_path
),
'image'
)
im_path
,
text_polys
,
text_tags
,
text_strs
=
self
.
extract_info_textnet
(
data_line
,
image_dir
)
img_id
=
int
(
data_line
.
split
(
"."
)[
0
][
3
:])
img_id
=
int
(
data_line
.
split
(
"."
)[
0
][
3
:])
data
=
{
'img_path'
:
img_path
,
'label'
:
label
,
'img_id'
:
img_id
}
data
=
{
if
not
os
.
path
.
exists
(
img_path
):
'img_path'
:
im_path
,
raise
Exception
(
"{} does not exist!"
.
format
(
img_path
))
'polys'
:
text_polys
,
'tags'
:
text_tags
,
'strs'
:
text_strs
,
'img_id'
:
img_id
}
with
open
(
data
[
'img_path'
],
'rb'
)
as
f
:
with
open
(
data
[
'img_path'
],
'rb'
)
as
f
:
img
=
f
.
read
()
img
=
f
.
read
()
data
[
'image'
]
=
img
data
[
'image'
]
=
img
outs
=
transform
(
data
,
self
.
ops
)
outs
=
transform
(
data
,
self
.
ops
)
except
Exception
as
e
:
except
Exception
as
e
:
self
.
logger
.
error
(
self
.
logger
.
error
(
"When parsing line {}, error happened with msg: {}"
.
format
(
"When parsing line {}, error happened with msg: {}"
.
format
(
...
...
ppocr/metrics/e2e_metric.py
浏览文件 @
50bcec46
...
@@ -35,11 +35,11 @@ class E2EMetric(object):
...
@@ -35,11 +35,11 @@ class E2EMetric(object):
self
.
reset
()
self
.
reset
()
def
__call__
(
self
,
preds
,
batch
,
**
kwargs
):
def
__call__
(
self
,
preds
,
batch
,
**
kwargs
):
img_id
=
batch
[
5
][
0
]
img_id
=
batch
[
2
][
0
]
e2e_info_list
=
[{
e2e_info_list
=
[{
'points'
:
det_polyon
,
'points'
:
det_polyon
,
'text'
:
pred_str
'text
s
'
:
pred_str
}
for
det_polyon
,
pred_str
in
zip
(
preds
[
'points'
],
preds
[
'
str
s'
])]
}
for
det_polyon
,
pred_str
in
zip
(
preds
[
'points'
],
preds
[
'
text
s'
])]
result
=
get_socre
(
self
.
gt_mat_dir
,
img_id
,
e2e_info_list
)
result
=
get_socre
(
self
.
gt_mat_dir
,
img_id
,
e2e_info_list
)
self
.
results
.
append
(
result
)
self
.
results
.
append
(
result
)
...
...
ppocr/utils/e2e_metric/Deteval.py
浏览文件 @
50bcec46
...
@@ -26,7 +26,7 @@ def get_socre(gt_dir, img_id, pred_dict):
...
@@ -26,7 +26,7 @@ def get_socre(gt_dir, img_id, pred_dict):
n
=
len
(
pred_dict
)
n
=
len
(
pred_dict
)
for
i
in
range
(
n
):
for
i
in
range
(
n
):
points
=
pred_dict
[
i
][
'points'
]
points
=
pred_dict
[
i
][
'points'
]
text
=
pred_dict
[
i
][
'text'
]
text
=
pred_dict
[
i
][
'text
s
'
]
point
=
","
.
join
(
map
(
str
,
points
.
reshape
(
-
1
,
)))
point
=
","
.
join
(
map
(
str
,
points
.
reshape
(
-
1
,
)))
det
.
append
([
point
,
text
])
det
.
append
([
point
,
text
])
return
det
return
det
...
...
ppocr/utils/e2e_utils/extract_textpoint_fast.py
浏览文件 @
50bcec46
...
@@ -21,6 +21,7 @@ import math
...
@@ -21,6 +21,7 @@ import math
import
numpy
as
np
import
numpy
as
np
from
itertools
import
groupby
from
itertools
import
groupby
from
cv2.ximgproc
import
thinning
as
thin
from
skimage.morphology._skeletonize
import
thin
from
skimage.morphology._skeletonize
import
thin
...
...
ppocr/utils/e2e_utils/pgnet_pp_utils.py
浏览文件 @
50bcec46
...
@@ -64,7 +64,7 @@ class PGNet_PostProcess(object):
...
@@ -64,7 +64,7 @@ class PGNet_PostProcess(object):
src_w
,
src_h
,
self
.
valid_set
)
src_w
,
src_h
,
self
.
valid_set
)
data
=
{
data
=
{
'points'
:
poly_list
,
'points'
:
poly_list
,
'
str
s'
:
keep_str_list
,
'
text
s'
:
keep_str_list
,
}
}
return
data
return
data
...
@@ -176,6 +176,6 @@ class PGNet_PostProcess(object):
...
@@ -176,6 +176,6 @@ class PGNet_PostProcess(object):
exit
(
-
1
)
exit
(
-
1
)
data
=
{
data
=
{
'points'
:
poly_list
,
'points'
:
poly_list
,
'
str
s'
:
keep_str_list
,
'
text
s'
:
keep_str_list
,
}
}
return
data
return
data
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录