Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleOCR
提交
9717944c
P
PaddleOCR
项目概览
s920243400
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9717944c
编写于
6月 23, 2020
作者:
张欣-男
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
识别文本框时,对文本框按宽高比进行排序。
上级
4ca78a07
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
33 addition
and
23 deletion
+33
-23
tools/infer/predict_rec.py
tools/infer/predict_rec.py
+31
-23
tools/infer/predict_system.py
tools/infer/predict_system.py
+2
-0
未找到文件。
tools/infer/predict_rec.py
浏览文件 @
9717944c
...
...
@@ -13,9 +13,9 @@
# limitations under the License.
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
__file__
)
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)
)
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
join
(
__dir__
,
'../..'
))
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../..'
)
))
import
tools.infer.utility
as
utility
from
ppocr.utils.utility
import
initial_logger
...
...
@@ -33,14 +33,12 @@ class TextRecognizer(object):
def
__init__
(
self
,
args
):
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
=
\
utility
.
create_predictor
(
args
,
mode
=
"rec"
)
image_shape
=
[
int
(
v
)
for
v
in
args
.
rec_image_shape
.
split
(
","
)]
self
.
rec_image_shape
=
image_shape
self
.
rec_image_shape
=
[
int
(
v
)
for
v
in
args
.
rec_image_shape
.
split
(
","
)]
self
.
character_type
=
args
.
rec_char_type
self
.
rec_batch_num
=
args
.
rec_batch_num
self
.
rec_algorithm
=
args
.
rec_algorithm
char_ops_params
=
{}
char_ops_params
[
"character_type"
]
=
args
.
rec_char_type
char_ops_params
[
"character_dict_path"
]
=
args
.
rec_char_dict_path
char_ops_params
=
{
"character_type"
:
args
.
rec_char_type
,
"character_dict_path"
:
args
.
rec_char_dict_path
}
if
self
.
rec_algorithm
!=
"RARE"
:
char_ops_params
[
'loss_type'
]
=
'ctc'
self
.
loss_type
=
'ctc'
...
...
@@ -51,16 +49,11 @@ class TextRecognizer(object):
def
resize_norm_img
(
self
,
img
,
max_wh_ratio
):
imgC
,
imgH
,
imgW
=
self
.
rec_image_shape
if
self
.
character_type
==
"ch"
:
imgW
=
int
(
32
*
max_wh_ratio
)
h
=
img
.
shape
[
0
]
w
=
img
.
shape
[
1
]
ratio
=
w
/
float
(
h
)
if
math
.
ceil
(
imgH
*
ratio
)
>
imgW
:
resized_w
=
imgW
else
:
resized_w
=
int
(
math
.
ceil
(
imgH
*
ratio
))
resized_image
=
cv2
.
resize
(
img
,
(
resized_w
,
imgH
))
assert
imgC
==
img
.
shape
[
2
]
imgW
=
int
(
math
.
ceil
(
32
*
max_wh_ratio
))
h
,
w
=
img
.
shape
[:
2
]
resized_w
=
int
(
math
.
ceil
(
imgH
*
w
/
float
(
h
)))
resized_image
=
cv2
.
resize
(
img
,
(
resized_w
,
imgH
),
interpolation
=
cv2
.
INTER_CUBIC
)
resized_image
=
resized_image
.
astype
(
'float32'
)
resized_image
=
resized_image
.
transpose
((
2
,
0
,
1
))
/
255
resized_image
-=
0.5
...
...
@@ -71,7 +64,15 @@ class TextRecognizer(object):
def
__call__
(
self
,
img_list
):
img_num
=
len
(
img_list
)
rec_res
=
[]
# 统计所有文本条的宽高比
width_list
=
[]
for
img
in
img_list
:
width_list
.
append
(
img
.
shape
[
1
]
/
float
(
img
.
shape
[
0
]))
# 对于文本框比较多且长短差异较大的情况下,通过排序再组合batch可以明显加速识别
indices
=
np
.
argsort
(
np
.
array
(
width_list
))
# rec_res = []
rec_res
=
[[
''
,
0.0
]]
*
img_num
batch_num
=
self
.
rec_batch_num
predict_time
=
0
for
beg_img_no
in
range
(
0
,
img_num
,
batch_num
):
...
...
@@ -80,10 +81,12 @@ class TextRecognizer(object):
max_wh_ratio
=
0
for
ino
in
range
(
beg_img_no
,
end_img_no
):
h
,
w
=
img_list
[
ino
].
shape
[
0
:
2
]
# h, w = img_list[indices[ino]].shape[0:2]
wh_ratio
=
w
*
1.0
/
h
max_wh_ratio
=
max
(
max_wh_ratio
,
wh_ratio
)
for
ino
in
range
(
beg_img_no
,
end_img_no
):
norm_img
=
self
.
resize_norm_img
(
img_list
[
ino
],
max_wh_ratio
)
# norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio)
norm_img
=
norm_img
[
np
.
newaxis
,
:]
norm_img_batch
.
append
(
norm_img
)
norm_img_batch
=
np
.
concatenate
(
norm_img_batch
)
...
...
@@ -111,7 +114,8 @@ class TextRecognizer(object):
blank
=
probs
.
shape
[
1
]
valid_ind
=
np
.
where
(
ind
!=
(
blank
-
1
))[
0
]
score
=
np
.
mean
(
probs
[
valid_ind
,
ind
[
valid_ind
]])
rec_res
.
append
([
preds_text
,
score
])
# rec_res.append([preds_text, score])
rec_res
[
indices
[
beg_img_no
+
rno
]]
=
[
preds_text
,
score
]
else
:
rec_idx_batch
=
self
.
output_tensors
[
0
].
copy_to_cpu
()
predict_batch
=
self
.
output_tensors
[
1
].
copy_to_cpu
()
...
...
@@ -126,19 +130,19 @@ class TextRecognizer(object):
preds
=
rec_idx_batch
[
rno
,
1
:
end_pos
[
1
]]
score
=
np
.
mean
(
predict_batch
[
rno
,
1
:
end_pos
[
1
]])
preds_text
=
self
.
char_ops
.
decode
(
preds
)
rec_res
.
append
([
preds_text
,
score
])
# rec_res.append([preds_text, score])
rec_res
[
indices
[
beg_img_no
+
rno
]]
=
[
preds_text
,
score
]
return
rec_res
,
predict_time
if
__name__
==
"__main__"
:
args
=
utility
.
parse_args
()
def
main
(
args
):
image_file_list
=
get_image_file_list
(
args
.
image_dir
)
text_recognizer
=
TextRecognizer
(
args
)
valid_image_file_list
=
[]
img_list
=
[]
for
image_file
in
image_file_list
:
img
=
cv2
.
imread
(
image_file
)
img
=
cv2
.
imread
(
image_file
,
cv2
.
IMREAD_COLOR
)
if
img
is
None
:
logger
.
info
(
"error in loading image:{}"
.
format
(
image_file
))
continue
...
...
@@ -159,3 +163,7 @@ if __name__ == "__main__":
print
(
"Predicts of %s:%s"
%
(
valid_image_file_list
[
ino
],
rec_res
[
ino
]))
print
(
"Total predict time for %d images:%.3f"
%
(
len
(
img_list
),
predict_time
))
if
__name__
==
"__main__"
:
main
(
utility
.
parse_args
())
tools/infer/predict_system.py
浏览文件 @
9717944c
...
...
@@ -75,6 +75,7 @@ class TextSystem(object):
def
__call__
(
self
,
img
):
ori_im
=
img
.
copy
()
dt_boxes
,
elapse
=
self
.
text_detector
(
img
)
print
(
"dt_boxes num : {}, elapse : {}"
.
format
(
len
(
dt_boxes
),
elapse
))
if
dt_boxes
is
None
:
return
None
,
None
img_crop_list
=
[]
...
...
@@ -86,6 +87,7 @@ class TextSystem(object):
img_crop
=
self
.
get_rotate_crop_image
(
ori_im
,
tmp_box
)
img_crop_list
.
append
(
img_crop
)
rec_res
,
elapse
=
self
.
text_recognizer
(
img_crop_list
)
print
(
"rec_res num : {}, elapse : {}"
.
format
(
len
(
rec_res
),
elapse
))
# self.print_draw_crop_rec_res(img_crop_list, rec_res)
return
dt_boxes
,
rec_res
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录