Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
81d8d190
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
1 年多 前同步成功
通知
1532
Star
32963
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
81d8d190
编写于
5月 15, 2020
作者:
D
dyning
提交者:
GitHub
5月 15, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #36 from LDOUBLEV/fixocr
valid det inference
上级
3b40c32a
d539508e
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
37 addition
and
16 deletion
+37
-16
configs/det/det_db_mv3.yml
configs/det/det_db_mv3.yml
+1
-1
ppocr/data/det/dataset_traversal.py
ppocr/data/det/dataset_traversal.py
+2
-2
tools/infer/predict_det.py
tools/infer/predict_det.py
+14
-2
tools/infer/utility.py
tools/infer/utility.py
+2
-2
tools/infer_det.py
tools/infer_det.py
+18
-9
未找到文件。
configs/det/det_db_mv3.yml
浏览文件 @
81d8d190
...
...
@@ -11,7 +11,7 @@ Global:
test_batch_size_per_card
:
16
image_shape
:
[
3
,
640
,
640
]
reader_yml
:
./configs/det/det_db_icdar15_reader.yml
pretrain_weights
:
./pretrain_models/MobileNetV3_
pretrained/MobileNetV3_
large_x0_5_pretrained/
pretrain_weights
:
./pretrain_models/MobileNetV3_large_x0_5_pretrained/
checkpoints
:
save_res_path
:
./output/det_db/predicts_db.txt
save_inference_dir
:
...
...
ppocr/data/det/dataset_traversal.py
浏览文件 @
81d8d190
...
...
@@ -89,13 +89,13 @@ class EvalTestReader(object):
def
batch_iter_reader
():
batch_outs
=
[]
for
img_path
,
img_name
in
img_list
:
for
img_path
in
img_list
:
img
=
cv2
.
imread
(
img_path
)
if
img
is
None
:
logger
.
info
(
"load image error:"
+
img_path
)
continue
outs
=
process_function
(
img
)
outs
.
append
(
img_
name
)
outs
.
append
(
img_
path
)
batch_outs
.
append
(
outs
)
if
len
(
batch_outs
)
==
batch_size
:
yield
batch_outs
...
...
tools/infer/predict_det.py
浏览文件 @
81d8d190
...
...
@@ -20,11 +20,14 @@ from ppocr.data.det.east_process import EASTProcessTest
from
ppocr.data.det.db_process
import
DBProcessTest
from
ppocr.postprocess.db_postprocess
import
DBPostProcess
from
ppocr.postprocess.east_postprocess
import
EASTPostPocess
from
ppocr.utils.utility
import
get_image_file_list
from
tools.infer.utility
import
draw_ocr
import
copy
import
numpy
as
np
import
math
import
time
import
sys
import
os
class
TextDetector
(
object
):
...
...
@@ -152,7 +155,7 @@ class TextDetector(object):
if
__name__
==
"__main__"
:
args
=
utility
.
parse_args
()
image_file_list
=
utility
.
get_image_file_list
(
args
.
image_dir
)
image_file_list
=
get_image_file_list
(
args
.
image_dir
)
text_detector
=
TextDetector
(
args
)
count
=
0
total_time
=
0
...
...
@@ -166,5 +169,14 @@ if __name__ == "__main__":
total_time
+=
elapse
count
+=
1
print
(
"Predict time of %s:"
%
image_file
,
elapse
)
utility
.
draw_text_det_res
(
dt_boxes
,
image_file
)
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2RGB
)
draw_img
=
draw_ocr
(
img
,
dt_boxes
,
None
,
None
,
False
)
draw_img_save
=
"./inference_results/"
if
not
os
.
path
.
exists
(
draw_img_save
):
os
.
makedirs
(
draw_img_save
)
cv2
.
imwrite
(
os
.
path
.
join
(
draw_img_save
,
os
.
path
.
basename
(
image_file
)),
draw_img
[:,
:,
::
-
1
])
print
(
"The visualized image saved in {}"
.
format
(
os
.
path
.
join
(
draw_img_save
,
os
.
path
.
basename
(
image_file
))))
print
(
"Avg Time:"
,
total_time
/
(
count
-
1
))
tools/infer/utility.py
浏览文件 @
81d8d190
...
...
@@ -127,10 +127,10 @@ def resize_img(img, input_size=600):
def
draw_ocr
(
image
,
boxes
,
txts
,
scores
,
draw_txt
=
True
,
drop_score
=
0.5
):
from
PIL
import
Image
,
ImageDraw
,
ImageFont
w
,
h
=
image
.
size
img
=
image
.
copy
()
draw
=
ImageDraw
.
Draw
(
img
)
if
scores
is
None
:
scores
=
[
1
]
*
len
(
boxes
)
for
(
box
,
score
)
in
zip
(
boxes
,
scores
):
if
score
<
drop_score
:
continue
...
...
tools/infer_det.py
浏览文件 @
81d8d190
...
...
@@ -40,7 +40,7 @@ set_paddle_flags(
)
from
paddle
import
fluid
from
ppocr.utils.utility
import
create_module
from
ppocr.utils.utility
import
create_module
,
get_image_file_list
import
program
from
ppocr.utils.save_load
import
init_model
from
ppocr.data.reader_main
import
reader_main
...
...
@@ -50,20 +50,18 @@ from ppocr.utils.utility import initial_logger
logger
=
initial_logger
()
def
draw_det_res
(
dt_boxes
,
config
,
img
_name
,
ino
):
def
draw_det_res
(
dt_boxes
,
config
,
img
,
img_name
):
if
len
(
dt_boxes
)
>
0
:
img_set_path
=
config
[
'TestReader'
][
'img_set_dir'
]
img_path
=
img_set_path
+
img_name
import
cv2
src_im
=
cv2
.
imread
(
img_path
)
src_im
=
img
for
box
in
dt_boxes
:
box
=
box
.
astype
(
np
.
int32
).
reshape
((
-
1
,
1
,
2
))
cv2
.
polylines
(
src_im
,
[
box
],
True
,
color
=
(
255
,
255
,
0
),
thickness
=
2
)
save_det_path
=
os
.
path
.
base
name
(
config
[
'Global'
][
save_det_path
=
os
.
path
.
dir
name
(
config
[
'Global'
][
'save_res_path'
])
+
"/det_results/"
if
not
os
.
path
.
exists
(
save_det_path
):
os
.
makedirs
(
save_det_path
)
save_path
=
os
.
path
.
join
(
save_det_path
,
"det_{}.jpg"
.
format
(
img_name
))
save_path
=
os
.
path
.
join
(
save_det_path
,
os
.
path
.
basename
(
img_name
))
cv2
.
imwrite
(
save_path
,
src_im
)
logger
.
info
(
"The detected Image saved in {}"
.
format
(
save_path
))
...
...
@@ -103,8 +101,12 @@ def main():
raise
Exception
(
"{} not exists!"
.
format
(
checkpoints
))
save_res_path
=
config
[
'Global'
][
'save_res_path'
]
if
not
os
.
path
.
exists
(
os
.
path
.
dirname
(
save_res_path
)):
os
.
makedirs
(
os
.
path
.
dirname
(
save_res_path
))
with
open
(
save_res_path
,
"wb"
)
as
fout
:
test_reader
=
reader_main
(
config
=
config
,
mode
=
'test'
)
# image_file_list = get_image_file_list(args.image_dir)
tackling_num
=
0
for
data
in
test_reader
():
img_num
=
len
(
data
)
...
...
@@ -128,7 +130,13 @@ def main():
postprocess_params
.
update
(
global_params
)
postprocess
=
create_module
(
postprocess_params
[
'function'
])
\
(
params
=
postprocess_params
)
dt_boxes_list
=
postprocess
({
"maps"
:
outs
[
0
]},
ratio_list
)
if
config
[
'Global'
][
'algorithm'
]
==
'EAST'
:
dic
=
{
'f_score'
:
outs
[
0
],
'f_geo'
:
outs
[
1
]}
elif
config
[
'Global'
][
'algorithm'
]
==
'DB'
:
dic
=
{
'maps'
:
outs
[
0
]}
else
:
raise
Exception
(
"only support algorithm: ['EAST', 'BD']"
)
dt_boxes_list
=
postprocess
(
dic
,
ratio_list
)
for
ino
in
range
(
img_num
):
dt_boxes
=
dt_boxes_list
[
ino
]
img_name
=
img_name_list
[
ino
]
...
...
@@ -139,7 +147,8 @@ def main():
dt_boxes_json
.
append
(
tmp_json
)
otstr
=
img_name
+
"
\t
"
+
json
.
dumps
(
dt_boxes_json
)
+
"
\n
"
fout
.
write
(
otstr
.
encode
())
draw_det_res
(
dt_boxes
,
config
,
img_name
,
ino
)
src_img
=
cv2
.
imread
(
img_name
)
draw_det_res
(
dt_boxes
,
config
,
src_img
,
img_name
)
logger
.
info
(
"success!"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录