Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
aa48cda3
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
1 年多 前同步成功
通知
1534
Star
32963
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
aa48cda3
编写于
4年前
作者:
T
tink2123
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
save for tensorrt
上级
5fb3c419
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
77 addition
and
56 deletion
+77
-56
tools/infer/predict_rec.py
tools/infer/predict_rec.py
+5
-10
tools/infer/predict_system.py
tools/infer/predict_system.py
+63
-45
tools/infer/utility.py
tools/infer/utility.py
+9
-1
未找到文件。
tools/infer/predict_rec.py
浏览文件 @
aa48cda3
...
...
@@ -62,8 +62,8 @@ class TextRecognizer(object):
def
resize_norm_img
(
self
,
img
,
max_wh_ratio
):
imgC
,
imgH
,
imgW
=
self
.
rec_image_shape
assert
imgC
==
img
.
shape
[
2
]
if
self
.
character_type
==
"ch"
:
imgW
=
int
((
32
*
max_wh_ratio
))
#
if self.character_type == "ch":
#
imgW = int((32 * max_wh_ratio))
h
,
w
=
img
.
shape
[:
2
]
ratio
=
w
/
float
(
h
)
if
math
.
ceil
(
imgH
*
ratio
)
>
imgW
:
...
...
@@ -314,17 +314,12 @@ def main(args):
valid_image_file_list
.
append
(
image_file
)
img_list
.
append
(
img
)
try
:
rec_res
,
predict_time
=
text_recognizer
(
img_list
)
rec_res
,
predict_time
=
text_recognizer
(
img_list
)
"""
except Exception as e:
print(e)
logger
.
info
(
"ERROR!!!!
\n
"
"Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq
\n
"
"If your model has tps module: "
"TPS does not support variable shape.
\n
"
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' "
)
exit()
"""
for
ino
in
range
(
len
(
img_list
)):
print
(
"Predicts of %s:%s"
%
(
valid_image_file_list
[
ino
],
rec_res
[
ino
]))
print
(
"Total predict time for %d images:%.3f"
%
...
...
This diff is collapsed.
Click to expand it.
tools/infer/predict_system.py
浏览文件 @
aa48cda3
...
...
@@ -123,50 +123,68 @@ def main(args):
text_sys
=
TextSystem
(
args
)
is_visualize
=
True
tackle_img_num
=
0
for
image_file
in
image_file_list
:
img
,
flag
=
check_and_read_gif
(
image_file
)
if
not
flag
:
img
=
cv2
.
imread
(
image_file
)
if
img
is
None
:
logger
.
info
(
"error in loading image:{}"
.
format
(
image_file
))
continue
starttime
=
time
.
time
()
tackle_img_num
+=
1
if
not
args
.
use_gpu
and
args
.
enable_mkldnn
and
tackle_img_num
%
30
==
0
:
text_sys
=
TextSystem
(
args
)
dt_boxes
,
rec_res
=
text_sys
(
img
)
elapse
=
time
.
time
()
-
starttime
print
(
"Predict time of %s: %.3fs"
%
(
image_file
,
elapse
))
drop_score
=
0.5
dt_num
=
len
(
dt_boxes
)
for
dno
in
range
(
dt_num
):
text
,
score
=
rec_res
[
dno
]
if
score
>=
drop_score
:
text_str
=
"%s, %.3f"
%
(
text
,
score
)
print
(
text_str
)
if
is_visualize
:
image
=
Image
.
fromarray
(
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2RGB
))
boxes
=
dt_boxes
txts
=
[
rec_res
[
i
][
0
]
for
i
in
range
(
len
(
rec_res
))]
scores
=
[
rec_res
[
i
][
1
]
for
i
in
range
(
len
(
rec_res
))]
draw_img
=
draw_ocr
(
image
,
boxes
,
txts
,
scores
,
drop_score
=
drop_score
)
draw_img_save
=
"./inference_results/"
if
not
os
.
path
.
exists
(
draw_img_save
):
os
.
makedirs
(
draw_img_save
)
cv2
.
imwrite
(
os
.
path
.
join
(
draw_img_save
,
os
.
path
.
basename
(
image_file
)),
draw_img
[:,
:,
::
-
1
])
print
(
"The visualized image saved in {}"
.
format
(
os
.
path
.
join
(
draw_img_save
,
os
.
path
.
basename
(
image_file
))))
if
not
args
.
enable_benchmark
:
for
image_file
in
image_file_list
:
img
,
flag
=
check_and_read_gif
(
image_file
)
if
not
flag
:
img
=
cv2
.
imread
(
image_file
)
if
img
is
None
:
logger
.
info
(
"error in loading image:{}"
.
format
(
image_file
))
continue
starttime
=
time
.
time
()
tackle_img_num
+=
1
if
not
args
.
use_gpu
and
args
.
enable_mkldnn
and
tackle_img_num
%
30
==
0
:
text_sys
=
TextSystem
(
args
)
dt_boxes
,
rec_res
=
text_sys
(
img
)
elapse
=
time
.
time
()
-
starttime
print
(
"Predict time of %s: %.3fs"
%
(
image_file
,
elapse
))
drop_score
=
0.5
dt_num
=
len
(
dt_boxes
)
for
dno
in
range
(
dt_num
):
text
,
score
=
rec_res
[
dno
]
if
score
>=
drop_score
:
text_str
=
"%s, %.3f"
%
(
text
,
score
)
print
(
text_str
)
if
is_visualize
:
image
=
Image
.
fromarray
(
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2RGB
))
boxes
=
dt_boxes
txts
=
[
rec_res
[
i
][
0
]
for
i
in
range
(
len
(
rec_res
))]
scores
=
[
rec_res
[
i
][
1
]
for
i
in
range
(
len
(
rec_res
))]
draw_img
=
draw_ocr
(
image
,
boxes
,
txts
,
scores
,
drop_score
=
drop_score
)
draw_img_save
=
"./inference_results/"
if
not
os
.
path
.
exists
(
draw_img_save
):
os
.
makedirs
(
draw_img_save
)
cv2
.
imwrite
(
os
.
path
.
join
(
draw_img_save
,
os
.
path
.
basename
(
image_file
)),
draw_img
[:,
:,
::
-
1
])
print
(
"The visualized image saved in {}"
.
format
(
os
.
path
.
join
(
draw_img_save
,
os
.
path
.
basename
(
image_file
))))
else
:
test_num
=
10
test_time
=
0.0
for
i
in
range
(
0
,
test_num
+
10
):
#inputs = np.random.rand(640, 640, 3).astype(np.float32)
#print(image_file_list)
image_file
=
image_file_list
[
0
]
inputs
=
cv2
.
imread
(
image_file
)
inputs
=
cv2
.
resize
(
inputs
,
(
int
(
640
),
int
(
640
)))
start_time
=
time
.
time
()
dt_boxes
,
rec_res
=
text_sys
(
inputs
)
if
i
>=
10
:
test_time
+=
time
.
time
()
-
start_time
time
.
sleep
(
0.01
)
fp_message
=
"FP16"
if
args
.
use_fp16
else
"FP32"
trt_msg
=
"using tensorrt"
if
args
.
use_tensorrt
else
"not using tensorrt"
print
(
"model
\t
{0}
\t
{1}
\t
batch size: {2}
\t
time(ms): {3}"
.
format
(
trt_msg
,
fp_message
,
args
.
max_batch_size
,
1000
*
test_time
/
test_num
))
if
__name__
==
"__main__"
:
main
(
utility
.
parse_args
())
This diff is collapsed.
Click to expand it.
tools/infer/utility.py
浏览文件 @
aa48cda3
...
...
@@ -36,7 +36,9 @@ def parse_args():
parser
.
add_argument
(
"--ir_optim"
,
type
=
str2bool
,
default
=
True
)
parser
.
add_argument
(
"--use_tensorrt"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--gpu_mem"
,
type
=
int
,
default
=
8000
)
parser
.
add_argument
(
"--use_fp16"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--max_batch_size"
,
type
=
int
,
default
=
10
)
parser
.
add_argument
(
"--enable_benchmark"
,
type
=
str2bool
,
default
=
True
)
#params for text detector
parser
.
add_argument
(
"--image_dir"
,
type
=
str
)
parser
.
add_argument
(
"--det_algorithm"
,
type
=
str
,
default
=
'DB'
)
...
...
@@ -112,6 +114,12 @@ def create_predictor(args, mode):
else
:
config
.
switch_use_feed_fetch_ops
(
True
)
if
args
.
use_tensorrt
:
config
.
enable_tensorrt_engine
(
precision_mode
=
AnalysisConfig
.
Precision
.
Half
if
args
.
use_fp16
else
AnalysisConfig
.
Precision
.
Float32
,
max_batch_size
=
args
.
max_batch_size
)
predictor
=
create_paddle_predictor
(
config
)
input_names
=
predictor
.
get_input_names
()
for
name
in
input_names
:
...
...
This diff is collapsed.
Click to expand it.
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录
新手
引导
客服
返回
顶部