Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
38f27a53
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
38f27a53
编写于
12月 07, 2020
作者:
W
WenmuZhou
浏览文件
操作
浏览文件
下载
差异文件
merge upstream
上级
cb7afb85
99ee41d8
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
63 addition
and
80 deletion
+63
-80
ppocr/data/simple_dataset.py
ppocr/data/simple_dataset.py
+17
-49
tools/infer/predict_cls.py
tools/infer/predict_cls.py
+10
-10
tools/infer/predict_det.py
tools/infer/predict_det.py
+3
-2
tools/infer/predict_rec.py
tools/infer/predict_rec.py
+5
-5
tools/infer/predict_system.py
tools/infer/predict_system.py
+28
-14
未找到文件。
ppocr/data/simple_dataset.py
浏览文件 @
38f27a53
...
...
@@ -32,12 +32,10 @@ class SimpleDataSet(Dataset):
self
.
delimiter
=
dataset_config
.
get
(
'delimiter'
,
'
\t
'
)
label_file_list
=
dataset_config
.
pop
(
'label_file_list'
)
data_source_num
=
len
(
label_file_list
)
if
data_source_num
==
1
:
ratio_list
=
[
1.0
]
else
:
ratio_list
=
dataset_config
.
pop
(
'ratio_list'
)
ratio_list
=
dataset_config
.
get
(
"ratio_list"
,
[
1.0
])
if
isinstance
(
ratio_list
,
(
float
,
int
)):
ratio_list
=
[
float
(
ratio_list
)]
*
len
(
data_source_num
)
assert
sum
(
ratio_list
)
==
1
,
"The sum of the ratio_list should be 1."
assert
len
(
ratio_list
)
==
data_source_num
,
"The length of ratio_list should be the same as the file_list."
...
...
@@ -45,62 +43,32 @@ class SimpleDataSet(Dataset):
self
.
do_shuffle
=
loader_config
[
'shuffle'
]
logger
.
info
(
"Initialize indexs of datasets:%s"
%
label_file_list
)
self
.
data_lines_list
,
data_num_list
=
self
.
get_image_info_list
(
label_file_list
)
self
.
data_idx_order_list
=
self
.
dataset_traversal
(
data_num_list
,
ratio_list
,
batch_size
)
self
.
shuffle_data_random
()
self
.
data_lines
=
self
.
get_image_info_list
(
label_file_list
,
ratio_list
)
self
.
data_idx_order_list
=
list
(
range
(
len
(
self
.
data_lines
)))
if
mode
.
lower
()
==
"train"
:
self
.
shuffle_data_random
()
self
.
ops
=
create_operators
(
dataset_config
[
'transforms'
],
global_config
)
def
get_image_info_list
(
self
,
file_list
):
def
get_image_info_list
(
self
,
file_list
,
ratio_list
):
if
isinstance
(
file_list
,
str
):
file_list
=
[
file_list
]
data_lines_list
=
[]
data_num_list
=
[]
for
file
in
file_list
:
data_lines
=
[]
for
idx
,
file
in
enumerate
(
file_list
):
with
open
(
file
,
"rb"
)
as
f
:
lines
=
f
.
readlines
()
data_lines_list
.
append
(
lines
)
data_num_list
.
append
(
len
(
lines
))
return
data_lines_list
,
data_num_list
def
dataset_traversal
(
self
,
data_num_list
,
ratio_list
,
batch_size
):
select_num_list
=
[]
dataset_num
=
len
(
data_num_list
)
for
dno
in
range
(
dataset_num
):
select_num
=
round
(
batch_size
*
ratio_list
[
dno
])
select_num
=
max
(
select_num
,
1
)
select_num_list
.
append
(
select_num
)
data_idx_order_list
=
[]
cur_index_sets
=
[
0
]
*
dataset_num
while
True
:
finish_read_num
=
0
for
dataset_idx
in
range
(
dataset_num
):
cur_index
=
cur_index_sets
[
dataset_idx
]
if
cur_index
>=
data_num_list
[
dataset_idx
]:
finish_read_num
+=
1
else
:
select_num
=
select_num_list
[
dataset_idx
]
for
sno
in
range
(
select_num
):
cur_index
=
cur_index_sets
[
dataset_idx
]
if
cur_index
>=
data_num_list
[
dataset_idx
]:
break
data_idx_order_list
.
append
((
dataset_idx
,
cur_index
))
cur_index_sets
[
dataset_idx
]
+=
1
if
finish_read_num
==
dataset_num
:
break
return
data_idx_order_list
lines
=
random
.
sample
(
lines
,
round
(
len
(
lines
)
*
ratio_list
[
idx
]))
data_lines
.
extend
(
lines
)
return
data_lines
def
shuffle_data_random
(
self
):
if
self
.
do_shuffle
:
for
dno
in
range
(
len
(
self
.
data_lines_list
)):
random
.
shuffle
(
self
.
data_lines_list
[
dno
])
random
.
shuffle
(
self
.
data_lines
)
return
def
__getitem__
(
self
,
idx
):
dataset_idx
,
file_idx
=
self
.
data_idx_order_list
[
idx
]
data_line
=
self
.
data_lines
_list
[
dataset_idx
]
[
file_idx
]
file_idx
=
self
.
data_idx_order_list
[
idx
]
data_line
=
self
.
data_lines
[
file_idx
]
try
:
data_line
=
data_line
.
decode
(
'utf-8'
)
substr
=
data_line
.
strip
(
"
\n
"
).
split
(
self
.
delimiter
)
...
...
tools/infer/predict_cls.py
浏览文件 @
38f27a53
...
...
@@ -23,7 +23,7 @@ import copy
import
numpy
as
np
import
math
import
time
import
traceback
import
paddle.fluid
as
fluid
import
tools.infer.utility
as
utility
...
...
@@ -106,10 +106,10 @@ class TextClassifier(object):
norm_img_batch
=
fluid
.
core
.
PaddleTensor
(
norm_img_batch
)
self
.
predictor
.
run
([
norm_img_batch
])
prob_out
=
self
.
output_tensors
[
0
].
copy_to_cpu
()
cls_res
=
self
.
postprocess_op
(
prob_out
)
cls_res
ult
=
self
.
postprocess_op
(
prob_out
)
elapse
+=
time
.
time
()
-
starttime
for
rno
in
range
(
len
(
cls_res
)):
label
,
score
=
cls_res
[
rno
]
for
rno
in
range
(
len
(
cls_res
ult
)):
label
,
score
=
cls_res
ult
[
rno
]
cls_res
[
indices
[
beg_img_no
+
rno
]]
=
[
label
,
score
]
if
'180'
in
label
and
score
>
self
.
cls_thresh
:
img_list
[
indices
[
beg_img_no
+
rno
]]
=
cv2
.
rotate
(
...
...
@@ -133,8 +133,8 @@ def main(args):
img_list
.
append
(
img
)
try
:
img_list
,
cls_res
,
predict_time
=
text_classifier
(
img_list
)
except
Exception
as
e
:
print
(
e
)
except
:
logger
.
info
(
traceback
.
format_exc
()
)
logger
.
info
(
"ERROR!!!!
\n
"
"Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq
\n
"
...
...
@@ -143,10 +143,10 @@ def main(args):
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' "
)
exit
()
for
ino
in
range
(
len
(
img_list
)):
print
(
"Predicts of {}:{}"
.
format
(
valid_image_file_list
[
ino
],
cls_res
[
logger
.
info
(
"Predicts of {}:{}"
.
format
(
valid_image_file_list
[
ino
],
cls_res
[
ino
]))
print
(
"Total predict time for {} images, cost: {:.3f}"
.
format
(
logger
.
info
(
"Total predict time for {} images, cost: {:.3f}"
.
format
(
len
(
img_list
),
predict_time
))
if
__name__
==
"__main__"
:
main
(
utility
.
parse_args
())
if
__name__
==
"__main__"
:
main
(
utility
.
parse_args
())
tools/infer/predict_det.py
浏览文件 @
38f27a53
...
...
@@ -178,11 +178,12 @@ if __name__ == "__main__":
if
count
>
0
:
total_time
+=
elapse
count
+=
1
print
(
"Predict time of {}: {}"
.
format
(
image_file
,
elapse
))
logger
.
info
(
"Predict time of {}: {}"
.
format
(
image_file
,
elapse
))
src_im
=
utility
.
draw_text_det_res
(
dt_boxes
,
image_file
)
img_name_pure
=
os
.
path
.
split
(
image_file
)[
-
1
]
img_path
=
os
.
path
.
join
(
draw_img_save
,
"det_res_{}"
.
format
(
img_name_pure
))
cv2
.
imwrite
(
img_path
,
src_im
)
logger
.
info
(
"The visualized image saved in {}"
.
format
(
img_path
))
if
count
>
1
:
print
(
"Avg Time:"
,
total_time
/
(
count
-
1
))
logger
.
info
(
"Avg Time:"
,
total_time
/
(
count
-
1
))
tools/infer/predict_rec.py
浏览文件 @
38f27a53
...
...
@@ -22,7 +22,7 @@ import cv2
import
numpy
as
np
import
math
import
time
import
traceback
import
paddle.fluid
as
fluid
import
tools.infer.utility
as
utility
...
...
@@ -135,8 +135,8 @@ def main(args):
img_list
.
append
(
img
)
try
:
rec_res
,
predict_time
=
text_recognizer
(
img_list
)
except
Exception
as
e
:
print
(
e
)
except
:
logger
.
info
(
traceback
.
format_exc
()
)
logger
.
info
(
"ERROR!!!!
\n
"
"Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq
\n
"
...
...
@@ -145,9 +145,9 @@ def main(args):
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' "
)
exit
()
for
ino
in
range
(
len
(
img_list
)):
print
(
"Predicts of {}:{}"
.
format
(
valid_image_file_list
[
ino
],
rec_res
[
logger
.
info
(
"Predicts of {}:{}"
.
format
(
valid_image_file_list
[
ino
],
rec_res
[
ino
]))
print
(
"Total predict time for {} images, cost: {:.3f}"
.
format
(
logger
.
info
(
"Total predict time for {} images, cost: {:.3f}"
.
format
(
len
(
img_list
),
predict_time
))
...
...
tools/infer/predict_system.py
浏览文件 @
38f27a53
...
...
@@ -23,17 +23,21 @@ import numpy as np
import
time
from
PIL
import
Image
import
tools.infer.utility
as
utility
from
tools.infer.utility
import
draw_ocr
import
tools.infer.predict_rec
as
predict_rec
import
tools.infer.predict_det
as
predict_det
import
tools.infer.predict_cls
as
predict_cls
from
ppocr.utils.utility
import
get_image_file_list
,
check_and_read_gif
from
ppocr.utils.logging
import
get_logger
from
tools.infer.utility
import
draw_ocr_box_txt
class
TextSystem
(
object
):
def
__init__
(
self
,
args
):
self
.
text_detector
=
predict_det
.
TextDetector
(
args
)
self
.
text_recognizer
=
predict_rec
.
TextRecognizer
(
args
)
self
.
use_angle_cls
=
args
.
use_angle_cls
if
self
.
use_angle_cls
:
self
.
text_classifier
=
predict_cls
.
TextClassifier
(
args
)
def
get_rotate_crop_image
(
self
,
img
,
points
):
'''
...
...
@@ -72,12 +76,13 @@ class TextSystem(object):
bbox_num
=
len
(
img_crop_list
)
for
bno
in
range
(
bbox_num
):
cv2
.
imwrite
(
"./output/img_crop_%d.jpg"
%
bno
,
img_crop_list
[
bno
])
print
(
bno
,
rec_res
[
bno
])
logger
.
info
(
bno
,
rec_res
[
bno
])
def
__call__
(
self
,
img
):
ori_im
=
img
.
copy
()
dt_boxes
,
elapse
=
self
.
text_detector
(
img
)
print
(
"dt_boxes num : {}, elapse : {}"
.
format
(
len
(
dt_boxes
),
elapse
))
logger
.
info
(
"dt_boxes num : {}, elapse : {}"
.
format
(
len
(
dt_boxes
),
elapse
))
if
dt_boxes
is
None
:
return
None
,
None
img_crop_list
=
[]
...
...
@@ -88,8 +93,15 @@ class TextSystem(object):
tmp_box
=
copy
.
deepcopy
(
dt_boxes
[
bno
])
img_crop
=
self
.
get_rotate_crop_image
(
ori_im
,
tmp_box
)
img_crop_list
.
append
(
img_crop
)
if
self
.
use_angle_cls
:
img_crop_list
,
angle_list
,
elapse
=
self
.
text_classifier
(
img_crop_list
)
logger
.
info
(
"cls num : {}, elapse : {}"
.
format
(
len
(
img_crop_list
),
elapse
))
rec_res
,
elapse
=
self
.
text_recognizer
(
img_crop_list
)
print
(
"rec_res num : {}, elapse : {}"
.
format
(
len
(
rec_res
),
elapse
))
logger
.
info
(
"rec_res num : {}, elapse : {}"
.
format
(
len
(
rec_res
),
elapse
))
# self.print_draw_crop_rec_res(img_crop_list, rec_res)
return
dt_boxes
,
rec_res
...
...
@@ -119,7 +131,8 @@ def main(args):
image_file_list
=
get_image_file_list
(
args
.
image_dir
)
text_sys
=
TextSystem
(
args
)
is_visualize
=
True
tackle_img_num
=
0
font_path
=
args
.
vis_font_path
drop_score
=
args
.
drop_score
for
image_file
in
image_file_list
:
img
,
flag
=
check_and_read_gif
(
image_file
)
if
not
flag
:
...
...
@@ -128,20 +141,16 @@ def main(args):
logger
.
info
(
"error in loading image:{}"
.
format
(
image_file
))
continue
starttime
=
time
.
time
()
tackle_img_num
+=
1
if
not
args
.
use_gpu
and
args
.
enable_mkldnn
and
tackle_img_num
%
30
==
0
:
text_sys
=
TextSystem
(
args
)
dt_boxes
,
rec_res
=
text_sys
(
img
)
elapse
=
time
.
time
()
-
starttime
print
(
"Predict time of %s: %.3fs"
%
(
image_file
,
elapse
))
logger
.
info
(
"Predict time of %s: %.3fs"
%
(
image_file
,
elapse
))
drop_score
=
0.5
dt_num
=
len
(
dt_boxes
)
for
dno
in
range
(
dt_num
):
text
,
score
=
rec_res
[
dno
]
if
score
>=
drop_score
:
text_str
=
"%s, %.3f"
%
(
text
,
score
)
print
(
text_str
)
logger
.
info
(
text_str
)
if
is_visualize
:
image
=
Image
.
fromarray
(
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2RGB
))
...
...
@@ -149,15 +158,20 @@ def main(args):
txts
=
[
rec_res
[
i
][
0
]
for
i
in
range
(
len
(
rec_res
))]
scores
=
[
rec_res
[
i
][
1
]
for
i
in
range
(
len
(
rec_res
))]
draw_img
=
draw_ocr
(
image
,
boxes
,
txts
,
scores
,
drop_score
=
drop_score
)
draw_img
=
draw_ocr_box_txt
(
image
,
boxes
,
txts
,
scores
,
drop_score
=
drop_score
,
font_path
=
font_path
)
draw_img_save
=
"./inference_results/"
if
not
os
.
path
.
exists
(
draw_img_save
):
os
.
makedirs
(
draw_img_save
)
cv2
.
imwrite
(
os
.
path
.
join
(
draw_img_save
,
os
.
path
.
basename
(
image_file
)),
draw_img
[:,
:,
::
-
1
])
print
(
"The visualized image saved in {}"
.
format
(
logger
.
info
(
"The visualized image saved in {}"
.
format
(
os
.
path
.
join
(
draw_img_save
,
os
.
path
.
basename
(
image_file
))))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录