Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleOCR
提交
3667498b
P
PaddleOCR
项目概览
s920243400
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3667498b
编写于
7月 16, 2020
作者:
littletomatodonkey
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix cv exception
上级
02519d50
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
28 addition
and
25 deletion
+28
-25
deploy/cpp_infer/src/postprocess_op.cpp
deploy/cpp_infer/src/postprocess_op.cpp
+1
-1
tools/eval.py
tools/eval.py
+2
-2
tools/eval_utils/eval_rec_utils.py
tools/eval_utils/eval_rec_utils.py
+1
-0
tools/infer_det.py
tools/infer_det.py
+1
-1
tools/infer_rec.py
tools/infer_rec.py
+4
-4
tools/test_hubserving.py
tools/test_hubserving.py
+17
-15
tools/train.py
tools/train.py
+2
-2
未找到文件。
deploy/cpp_infer/src/postprocess_op.cpp
浏览文件 @
3667498b
...
@@ -219,7 +219,7 @@ PostProcessor::BoxesFromBitmap(const cv::Mat pred, const cv::Mat bitmap,
...
@@ -219,7 +219,7 @@ PostProcessor::BoxesFromBitmap(const cv::Mat pred, const cv::Mat bitmap,
std
::
vector
<
std
::
vector
<
std
::
vector
<
int
>>>
boxes
;
std
::
vector
<
std
::
vector
<
std
::
vector
<
int
>>>
boxes
;
for
(
int
_i
=
0
;
_i
<
num_contours
;
_i
++
)
{
for
(
int
_i
=
0
;
_i
<
num_contours
;
_i
++
)
{
if
(
contours
[
_i
].
size
()
<=
0
)
{
if
(
contours
[
_i
].
size
()
<=
2
)
{
continue
;
continue
;
}
}
float
ssid
;
float
ssid
;
...
...
tools/eval.py
浏览文件 @
3667498b
...
@@ -82,7 +82,7 @@ def main():
...
@@ -82,7 +82,7 @@ def main():
'fetch_name_list'
:
eval_fetch_name_list
,
\
'fetch_name_list'
:
eval_fetch_name_list
,
\
'fetch_varname_list'
:
eval_fetch_varname_list
}
'fetch_varname_list'
:
eval_fetch_varname_list
}
metrics
=
eval_det_run
(
exe
,
config
,
eval_info_dict
,
"eval"
)
metrics
=
eval_det_run
(
exe
,
config
,
eval_info_dict
,
"eval"
)
print
(
"Eval result"
,
metrics
)
logger
.
info
(
"Eval result: {}"
.
format
(
metrics
)
)
else
:
else
:
reader_type
=
config
[
'Global'
][
'reader_yml'
]
reader_type
=
config
[
'Global'
][
'reader_yml'
]
if
"benchmark"
not
in
reader_type
:
if
"benchmark"
not
in
reader_type
:
...
@@ -92,7 +92,7 @@ def main():
...
@@ -92,7 +92,7 @@ def main():
'fetch_name_list'
:
eval_fetch_name_list
,
\
'fetch_name_list'
:
eval_fetch_name_list
,
\
'fetch_varname_list'
:
eval_fetch_varname_list
}
'fetch_varname_list'
:
eval_fetch_varname_list
}
metrics
=
eval_rec_run
(
exe
,
config
,
eval_info_dict
,
"eval"
)
metrics
=
eval_rec_run
(
exe
,
config
,
eval_info_dict
,
"eval"
)
print
(
"Eval result:"
,
metrics
)
logger
.
info
(
"Eval result: {}"
.
format
(
metrics
)
)
else
:
else
:
eval_info_dict
=
{
'program'
:
eval_program
,
\
eval_info_dict
=
{
'program'
:
eval_program
,
\
'fetch_name_list'
:
eval_fetch_name_list
,
\
'fetch_name_list'
:
eval_fetch_name_list
,
\
...
...
tools/eval_utils/eval_rec_utils.py
浏览文件 @
3667498b
...
@@ -75,6 +75,7 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
...
@@ -75,6 +75,7 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
char_ops
,
preds
,
preds_lod
,
labels
,
labels_lod
,
is_remove_duplicate
)
char_ops
,
preds
,
preds_lod
,
labels
,
labels_lod
,
is_remove_duplicate
)
total_acc_num
+=
acc_num
total_acc_num
+=
acc_num
total_sample_num
+=
sample_num
total_sample_num
+=
sample_num
logger
.
info
(
"eval batch id: {}, acc: {}"
.
format
(
total_batch_num
,
acc
))
total_batch_num
+=
1
total_batch_num
+=
1
avg_acc
=
total_acc_num
*
1.0
/
total_sample_num
avg_acc
=
total_acc_num
*
1.0
/
total_sample_num
metrics
=
{
'avg_acc'
:
avg_acc
,
"total_acc_num"
:
total_acc_num
,
\
metrics
=
{
'avg_acc'
:
avg_acc
,
"total_acc_num"
:
total_acc_num
,
\
...
...
tools/infer_det.py
浏览文件 @
3667498b
...
@@ -70,7 +70,7 @@ def draw_det_res(dt_boxes, config, img, img_name):
...
@@ -70,7 +70,7 @@ def draw_det_res(dt_boxes, config, img, img_name):
def
main
():
def
main
():
config
=
program
.
load_config
(
FLAGS
.
config
)
config
=
program
.
load_config
(
FLAGS
.
config
)
program
.
merge_config
(
FLAGS
.
opt
)
program
.
merge_config
(
FLAGS
.
opt
)
print
(
config
)
logger
.
info
(
config
)
# check if set use_gpu=True in paddlepaddle cpu version
# check if set use_gpu=True in paddlepaddle cpu version
use_gpu
=
config
[
'Global'
][
'use_gpu'
]
use_gpu
=
config
[
'Global'
][
'use_gpu'
]
...
...
tools/infer_rec.py
浏览文件 @
3667498b
...
@@ -84,7 +84,7 @@ def main():
...
@@ -84,7 +84,7 @@ def main():
if
len
(
infer_list
)
==
0
:
if
len
(
infer_list
)
==
0
:
logger
.
info
(
"Can not find img in infer_img dir."
)
logger
.
info
(
"Can not find img in infer_img dir."
)
for
i
in
range
(
max_img_num
):
for
i
in
range
(
max_img_num
):
print
(
"infer_img:%s"
%
infer_list
[
i
])
logger
.
info
(
"infer_img:%s"
%
infer_list
[
i
])
img
=
next
(
blobs
)
img
=
next
(
blobs
)
predict
=
exe
.
run
(
program
=
eval_prog
,
predict
=
exe
.
run
(
program
=
eval_prog
,
feed
=
{
"image"
:
img
},
feed
=
{
"image"
:
img
},
...
@@ -115,9 +115,9 @@ def main():
...
@@ -115,9 +115,9 @@ def main():
preds
=
preds
.
reshape
(
-
1
)
preds
=
preds
.
reshape
(
-
1
)
preds_text
=
char_ops
.
decode
(
preds
)
preds_text
=
char_ops
.
decode
(
preds
)
print
(
"
\t
index:"
,
preds
)
logger
.
info
(
"
\t
index: {}"
.
format
(
preds
)
)
print
(
"
\t
word :"
,
preds_text
)
logger
.
info
(
"
\t
word : {}"
.
format
(
preds_text
)
)
print
(
"
\t
score :"
,
score
)
logger
.
info
(
"
\t
score: {}"
.
format
(
score
)
)
# save for inference model
# save for inference model
target_var
=
[]
target_var
=
[]
...
...
tools/test_hubserving.py
浏览文件 @
3667498b
...
@@ -41,19 +41,19 @@ def draw_server_result(image_file, res):
...
@@ -41,19 +41,19 @@ def draw_server_result(image_file, res):
if
len
(
res
)
==
0
:
if
len
(
res
)
==
0
:
return
np
.
array
(
image
)
return
np
.
array
(
image
)
keys
=
res
[
0
].
keys
()
keys
=
res
[
0
].
keys
()
if
'text_region'
not
in
keys
:
# for ocr_rec, draw function is invalid
if
'text_region'
not
in
keys
:
# for ocr_rec, draw function is invalid
print
(
"draw function is invalid for ocr_rec!"
)
logger
.
info
(
"draw function is invalid for ocr_rec!"
)
return
None
return
None
elif
'text'
not
in
keys
:
# for ocr_det
elif
'text'
not
in
keys
:
# for ocr_det
print
(
"draw text boxes only!"
)
logger
.
info
(
"draw text boxes only!"
)
boxes
=
[]
boxes
=
[]
for
dno
in
range
(
len
(
res
)):
for
dno
in
range
(
len
(
res
)):
boxes
.
append
(
res
[
dno
][
'text_region'
])
boxes
.
append
(
res
[
dno
][
'text_region'
])
boxes
=
np
.
array
(
boxes
)
boxes
=
np
.
array
(
boxes
)
draw_img
=
draw_boxes
(
image
,
boxes
)
draw_img
=
draw_boxes
(
image
,
boxes
)
return
draw_img
return
draw_img
else
:
# for ocr_system
else
:
# for ocr_system
print
(
"draw boxes and texts!"
)
logger
.
info
(
"draw boxes and texts!"
)
boxes
=
[]
boxes
=
[]
texts
=
[]
texts
=
[]
scores
=
[]
scores
=
[]
...
@@ -63,7 +63,8 @@ def draw_server_result(image_file, res):
...
@@ -63,7 +63,8 @@ def draw_server_result(image_file, res):
scores
.
append
(
res
[
dno
][
'confidence'
])
scores
.
append
(
res
[
dno
][
'confidence'
])
boxes
=
np
.
array
(
boxes
)
boxes
=
np
.
array
(
boxes
)
scores
=
np
.
array
(
scores
)
scores
=
np
.
array
(
scores
)
draw_img
=
draw_ocr
(
image
,
boxes
,
texts
,
scores
,
draw_txt
=
True
,
drop_score
=
0.5
)
draw_img
=
draw_ocr
(
image
,
boxes
,
texts
,
scores
,
draw_txt
=
True
,
drop_score
=
0.5
)
return
draw_img
return
draw_img
...
@@ -81,13 +82,13 @@ def main(url, image_path):
...
@@ -81,13 +82,13 @@ def main(url, image_path):
# 发送HTTP请求
# 发送HTTP请求
starttime
=
time
.
time
()
starttime
=
time
.
time
()
data
=
{
'images'
:[
cv2_to_base64
(
img
)]}
data
=
{
'images'
:
[
cv2_to_base64
(
img
)]}
r
=
requests
.
post
(
url
=
url
,
headers
=
headers
,
data
=
json
.
dumps
(
data
))
r
=
requests
.
post
(
url
=
url
,
headers
=
headers
,
data
=
json
.
dumps
(
data
))
elapse
=
time
.
time
()
-
starttime
elapse
=
time
.
time
()
-
starttime
total_time
+=
elapse
total_time
+=
elapse
print
(
"Predict time of %s: %.3fs"
%
(
image_file
,
elapse
))
logger
.
info
(
"Predict time of %s: %.3fs"
%
(
image_file
,
elapse
))
res
=
r
.
json
()[
"results"
][
0
]
res
=
r
.
json
()[
"results"
][
0
]
print
(
res
)
logger
.
info
(
res
)
if
is_visualize
:
if
is_visualize
:
draw_img
=
draw_server_result
(
image_file
,
res
)
draw_img
=
draw_server_result
(
image_file
,
res
)
...
@@ -98,16 +99,17 @@ def main(url, image_path):
...
@@ -98,16 +99,17 @@ def main(url, image_path):
cv2
.
imwrite
(
cv2
.
imwrite
(
os
.
path
.
join
(
draw_img_save
,
os
.
path
.
basename
(
image_file
)),
os
.
path
.
join
(
draw_img_save
,
os
.
path
.
basename
(
image_file
)),
draw_img
[:,
:,
::
-
1
])
draw_img
[:,
:,
::
-
1
])
print
(
"The visualized image saved in {}"
.
format
(
logger
.
info
(
"The visualized image saved in {}"
.
format
(
os
.
path
.
join
(
draw_img_save
,
os
.
path
.
basename
(
image_file
))))
os
.
path
.
join
(
draw_img_save
,
os
.
path
.
basename
(
image_file
))))
cnt
+=
1
cnt
+=
1
if
cnt
%
100
==
0
:
if
cnt
%
100
==
0
:
print
(
cnt
,
"processed"
)
logger
.
info
(
"{} processed"
.
format
(
cnt
)
)
print
(
"avg time cost: "
,
float
(
total_time
)
/
cnt
)
logger
.
info
(
"avg time cost: {}"
.
format
(
float
(
total_time
)
/
cnt
)
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
if
len
(
sys
.
argv
)
!=
3
:
if
len
(
sys
.
argv
)
!=
3
:
print
(
"Usage: %s server_url image_path"
%
sys
.
argv
[
0
])
logger
.
info
(
"Usage: %s server_url image_path"
%
sys
.
argv
[
0
])
else
:
else
:
server_url
=
sys
.
argv
[
1
]
server_url
=
sys
.
argv
[
1
]
image_path
=
sys
.
argv
[
2
]
image_path
=
sys
.
argv
[
2
]
...
...
tools/train.py
浏览文件 @
3667498b
...
@@ -118,7 +118,7 @@ def main():
...
@@ -118,7 +118,7 @@ def main():
def
test_reader
():
def
test_reader
():
config
=
program
.
load_config
(
FLAGS
.
config
)
config
=
program
.
load_config
(
FLAGS
.
config
)
program
.
merge_config
(
FLAGS
.
opt
)
program
.
merge_config
(
FLAGS
.
opt
)
print
(
config
)
logger
.
info
(
config
)
train_reader
=
reader_main
(
config
=
config
,
mode
=
"train"
)
train_reader
=
reader_main
(
config
=
config
,
mode
=
"train"
)
import
time
import
time
starttime
=
time
.
time
()
starttime
=
time
.
time
()
...
@@ -129,7 +129,7 @@ def test_reader():
...
@@ -129,7 +129,7 @@ def test_reader():
if
count
%
1
==
0
:
if
count
%
1
==
0
:
batch_time
=
time
.
time
()
-
starttime
batch_time
=
time
.
time
()
-
starttime
starttime
=
time
.
time
()
starttime
=
time
.
time
()
print
(
"reader:"
,
count
,
len
(
data
),
batch_time
)
logger
.
info
(
"reader:"
,
count
,
len
(
data
),
batch_time
)
except
Exception
as
e
:
except
Exception
as
e
:
logger
.
info
(
e
)
logger
.
info
(
e
)
logger
.
info
(
"finish reader: {}, Success!"
.
format
(
count
))
logger
.
info
(
"finish reader: {}, Success!"
.
format
(
count
))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录