Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
1bcfd9f1
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1bcfd9f1
编写于
6月 28, 2020
作者:
X
xiaoting
提交者:
GitHub
6月 28, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #263 from ZhangXinNan/zxdev
优化predict_rec.py
上级
9313bdfa
2eb6244c
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
34 addition
and
19 deletion
+34
-19
tools/infer/predict_rec.py
tools/infer/predict_rec.py
+32
-19
tools/infer/predict_system.py
tools/infer/predict_system.py
+2
-0
未找到文件。
tools/infer/predict_rec.py
浏览文件 @
1bcfd9f1
...
...
@@ -13,9 +13,9 @@
# limitations under the License.
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
__file__
)
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)
)
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
join
(
__dir__
,
'../..'
))
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../..'
)
))
import
tools.infer.utility
as
utility
from
ppocr.utils.utility
import
initial_logger
...
...
@@ -33,14 +33,12 @@ class TextRecognizer(object):
def
__init__
(
self
,
args
):
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
=
\
utility
.
create_predictor
(
args
,
mode
=
"rec"
)
image_shape
=
[
int
(
v
)
for
v
in
args
.
rec_image_shape
.
split
(
","
)]
self
.
rec_image_shape
=
image_shape
self
.
rec_image_shape
=
[
int
(
v
)
for
v
in
args
.
rec_image_shape
.
split
(
","
)]
self
.
character_type
=
args
.
rec_char_type
self
.
rec_batch_num
=
args
.
rec_batch_num
self
.
rec_algorithm
=
args
.
rec_algorithm
char_ops_params
=
{}
char_ops_params
[
"character_type"
]
=
args
.
rec_char_type
char_ops_params
[
"character_dict_path"
]
=
args
.
rec_char_dict_path
char_ops_params
=
{
"character_type"
:
args
.
rec_char_type
,
"character_dict_path"
:
args
.
rec_char_dict_path
}
if
self
.
rec_algorithm
!=
"RARE"
:
char_ops_params
[
'loss_type'
]
=
'ctc'
self
.
loss_type
=
'ctc'
...
...
@@ -51,16 +49,16 @@ class TextRecognizer(object):
def
resize_norm_img
(
self
,
img
,
max_wh_ratio
):
imgC
,
imgH
,
imgW
=
self
.
rec_image_shape
assert
imgC
==
img
.
shape
[
2
]
if
self
.
character_type
==
"ch"
:
imgW
=
int
(
32
*
max_wh_ratio
)
h
=
img
.
shape
[
0
]
w
=
img
.
shape
[
1
]
imgW
=
int
(
math
.
ceil
(
32
*
max_wh_ratio
))
h
,
w
=
img
.
shape
[:
2
]
ratio
=
w
/
float
(
h
)
if
math
.
ceil
(
imgH
*
ratio
)
>
imgW
:
resized_w
=
imgW
else
:
resized_w
=
int
(
math
.
ceil
(
imgH
*
ratio
))
resized_image
=
cv2
.
resize
(
img
,
(
resized_w
,
imgH
))
resized_image
=
cv2
.
resize
(
img
,
(
resized_w
,
imgH
)
,
interpolation
=
cv2
.
INTER_CUBIC
)
resized_image
=
resized_image
.
astype
(
'float32'
)
resized_image
=
resized_image
.
transpose
((
2
,
0
,
1
))
/
255
resized_image
-=
0.5
...
...
@@ -71,7 +69,15 @@ class TextRecognizer(object):
def
__call__
(
self
,
img_list
):
img_num
=
len
(
img_list
)
rec_res
=
[]
# Calculate the aspect ratio of all text bars
width_list
=
[]
for
img
in
img_list
:
width_list
.
append
(
img
.
shape
[
1
]
/
float
(
img
.
shape
[
0
]))
# Sorting can speed up the recognition process
indices
=
np
.
argsort
(
np
.
array
(
width_list
))
# rec_res = []
rec_res
=
[[
''
,
0.0
]]
*
img_num
batch_num
=
self
.
rec_batch_num
predict_time
=
0
for
beg_img_no
in
range
(
0
,
img_num
,
batch_num
):
...
...
@@ -79,11 +85,13 @@ class TextRecognizer(object):
norm_img_batch
=
[]
max_wh_ratio
=
0
for
ino
in
range
(
beg_img_no
,
end_img_no
):
h
,
w
=
img_list
[
ino
].
shape
[
0
:
2
]
# h, w = img_list[ino].shape[0:2]
h
,
w
=
img_list
[
indices
[
ino
]].
shape
[
0
:
2
]
wh_ratio
=
w
*
1.0
/
h
max_wh_ratio
=
max
(
max_wh_ratio
,
wh_ratio
)
for
ino
in
range
(
beg_img_no
,
end_img_no
):
norm_img
=
self
.
resize_norm_img
(
img_list
[
ino
],
max_wh_ratio
)
# norm_img = self.resize_norm_img(img_list[ino], max_wh_ratio)
norm_img
=
self
.
resize_norm_img
(
img_list
[
indices
[
ino
]],
max_wh_ratio
)
norm_img
=
norm_img
[
np
.
newaxis
,
:]
norm_img_batch
.
append
(
norm_img
)
norm_img_batch
=
np
.
concatenate
(
norm_img_batch
)
...
...
@@ -111,7 +119,8 @@ class TextRecognizer(object):
blank
=
probs
.
shape
[
1
]
valid_ind
=
np
.
where
(
ind
!=
(
blank
-
1
))[
0
]
score
=
np
.
mean
(
probs
[
valid_ind
,
ind
[
valid_ind
]])
rec_res
.
append
([
preds_text
,
score
])
# rec_res.append([preds_text, score])
rec_res
[
indices
[
beg_img_no
+
rno
]]
=
[
preds_text
,
score
]
else
:
rec_idx_batch
=
self
.
output_tensors
[
0
].
copy_to_cpu
()
predict_batch
=
self
.
output_tensors
[
1
].
copy_to_cpu
()
...
...
@@ -126,19 +135,19 @@ class TextRecognizer(object):
preds
=
rec_idx_batch
[
rno
,
1
:
end_pos
[
1
]]
score
=
np
.
mean
(
predict_batch
[
rno
,
1
:
end_pos
[
1
]])
preds_text
=
self
.
char_ops
.
decode
(
preds
)
rec_res
.
append
([
preds_text
,
score
])
# rec_res.append([preds_text, score])
rec_res
[
indices
[
beg_img_no
+
rno
]]
=
[
preds_text
,
score
]
return
rec_res
,
predict_time
if
__name__
==
"__main__"
:
args
=
utility
.
parse_args
()
def
main
(
args
):
image_file_list
=
get_image_file_list
(
args
.
image_dir
)
text_recognizer
=
TextRecognizer
(
args
)
valid_image_file_list
=
[]
img_list
=
[]
for
image_file
in
image_file_list
:
img
=
cv2
.
imread
(
image_file
)
img
=
cv2
.
imread
(
image_file
,
cv2
.
IMREAD_COLOR
)
if
img
is
None
:
logger
.
info
(
"error in loading image:{}"
.
format
(
image_file
))
continue
...
...
@@ -159,3 +168,7 @@ if __name__ == "__main__":
print
(
"Predicts of %s:%s"
%
(
valid_image_file_list
[
ino
],
rec_res
[
ino
]))
print
(
"Total predict time for %d images:%.3f"
%
(
len
(
img_list
),
predict_time
))
if
__name__
==
"__main__"
:
main
(
utility
.
parse_args
())
tools/infer/predict_system.py
浏览文件 @
1bcfd9f1
...
...
@@ -75,6 +75,7 @@ class TextSystem(object):
def
__call__
(
self
,
img
):
ori_im
=
img
.
copy
()
dt_boxes
,
elapse
=
self
.
text_detector
(
img
)
print
(
"dt_boxes num : {}, elapse : {}"
.
format
(
len
(
dt_boxes
),
elapse
))
if
dt_boxes
is
None
:
return
None
,
None
img_crop_list
=
[]
...
...
@@ -86,6 +87,7 @@ class TextSystem(object):
img_crop
=
self
.
get_rotate_crop_image
(
ori_im
,
tmp_box
)
img_crop_list
.
append
(
img_crop
)
rec_res
,
elapse
=
self
.
text_recognizer
(
img_crop_list
)
print
(
"rec_res num : {}, elapse : {}"
.
format
(
len
(
rec_res
),
elapse
))
# self.print_draw_crop_rec_res(img_crop_list, rec_res)
return
dt_boxes
,
rec_res
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录