Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleOCR
提交
fffd556c
P
PaddleOCR
项目概览
s920243400
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
fffd556c
编写于
9月 03, 2021
作者:
L
LDOUBLEV
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix distill model predict
上级
5947c567
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
38 addition
and
17 deletion
+38
-17
tools/infer_det.py
tools/infer_det.py
+38
-17
未找到文件。
tools/infer_det.py
浏览文件 @
fffd556c
...
...
@@ -34,23 +34,21 @@ import paddle
from
ppocr.data
import
create_operators
,
transform
from
ppocr.modeling.architectures
import
build_model
from
ppocr.postprocess
import
build_post_process
from
ppocr.utils.save_load
import
init_model
from
ppocr.utils.save_load
import
init_model
,
load_dygraph_params
from
ppocr.utils.utility
import
get_image_file_list
import
tools.program
as
program
def
draw_det_res
(
dt_boxes
,
config
,
img
,
img_name
):
def
draw_det_res
(
dt_boxes
,
config
,
img
,
img_name
,
save_path
):
if
len
(
dt_boxes
)
>
0
:
import
cv2
src_im
=
img
for
box
in
dt_boxes
:
box
=
box
.
astype
(
np
.
int32
).
reshape
((
-
1
,
1
,
2
))
cv2
.
polylines
(
src_im
,
[
box
],
True
,
color
=
(
255
,
255
,
0
),
thickness
=
2
)
save_det_path
=
os
.
path
.
dirname
(
config
[
'Global'
][
'save_res_path'
])
+
"/det_results/"
if
not
os
.
path
.
exists
(
save_det_path
):
os
.
makedirs
(
save_det_path
)
save_path
=
os
.
path
.
join
(
save_det_path
,
os
.
path
.
basename
(
img_name
))
if
not
os
.
path
.
exists
(
save_path
):
os
.
makedirs
(
save_path
)
save_path
=
os
.
path
.
join
(
save_path
,
os
.
path
.
basename
(
img_name
))
cv2
.
imwrite
(
save_path
,
src_im
)
logger
.
info
(
"The detected Image saved in {}"
.
format
(
save_path
))
...
...
@@ -61,8 +59,7 @@ def main():
# build model
model
=
build_model
(
config
[
'Architecture'
])
init_model
(
config
,
model
)
_
=
load_dygraph_params
(
config
,
model
,
logger
,
None
)
# build post process
post_process_class
=
build_post_process
(
config
[
'PostProcess'
])
...
...
@@ -96,17 +93,41 @@ def main():
images
=
paddle
.
to_tensor
(
images
)
preds
=
model
(
images
)
post_result
=
post_process_class
(
preds
,
shape_list
)
boxes
=
post_result
[
0
][
'points'
]
# write result
src_img
=
cv2
.
imread
(
file
)
dt_boxes_json
=
[]
for
box
in
boxes
:
tmp_json
=
{
"transcription"
:
""
}
tmp_json
[
'points'
]
=
box
.
tolist
()
dt_boxes_json
.
append
(
tmp_json
)
# parser boxes if post_result is dict
if
isinstance
(
post_result
,
dict
):
det_box_json
=
{}
for
k
in
post_result
.
keys
():
boxes
=
post_result
[
k
][
0
][
'points'
]
dt_boxes_list
=
[]
for
box
in
boxes
:
tmp_json
=
{
"transcription"
:
""
}
tmp_json
[
'points'
]
=
box
.
tolist
()
dt_boxes_list
.
append
(
tmp_json
)
det_box_json
[
k
]
=
dt_boxes_list
save_det_path
=
os
.
path
.
dirname
(
config
[
'Global'
][
'save_res_path'
])
+
"/det_results_{}/"
.
format
(
k
)
draw_det_res
(
boxes
,
config
,
src_img
,
file
,
save_det_path
)
else
:
boxes
=
post_result
[
0
][
'points'
]
dt_boxes_json
=
[]
# write result
for
box
in
boxes
:
tmp_json
=
{
"transcription"
:
""
}
tmp_json
[
'points'
]
=
box
.
tolist
()
dt_boxes_json
.
append
(
tmp_json
)
save_det_path
=
os
.
path
.
dirname
(
config
[
'Global'
][
'save_res_path'
])
+
"/det_results/"
draw_det_res
(
boxes
,
config
,
src_img
,
file
,
save_det_path
)
otstr
=
file
+
"
\t
"
+
json
.
dumps
(
dt_boxes_json
)
+
"
\n
"
fout
.
write
(
otstr
.
encode
())
src_img
=
cv2
.
imread
(
file
)
draw_det_res
(
boxes
,
config
,
src_img
,
file
)
save_det_path
=
os
.
path
.
dirname
(
config
[
'Global'
][
'save_res_path'
])
+
"/det_results/"
draw_det_res
(
boxes
,
config
,
src_img
,
file
,
save_det_path
)
logger
.
info
(
"success!"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录