PaddlePaddle / PaddleHub

Commit 8f5e5177
Authored May 11, 2022 by chenjian
Parent: 9b3119df

fix

Showing 4 changed files with 36 additions and 37 deletions (+36, -37):
modules/image/text_recognition/ppocrv3_det_ch/module.py      +11  -10
modules/image/text_recognition/ppocrv3_det_ch/processor.py    +7   -8
modules/image/text_recognition/ppocrv3_rec_ch/character.py    +1   -6
modules/image/text_recognition/ppocrv3_rec_ch/module.py      +17  -13

modules/image/text_recognition/ppocrv3_det_ch/module.py

@@ -34,14 +34,14 @@ def base64_to_cv2(b64str):
 @moduleinfo(
-    name="ppocrv3_det_ch",
+    name="ch_pp-ocrv3_det",
     version="1.0.0",
     summary="The module aims to detect chinese text position in the image, which is based on differentiable_binarization algorithm.",
     author="paddle-dev",
     author_email="paddle-dev@baidu.com",
     type="cv/text_recognition")
-class ChineseTextDetectionDB(hub.Module):
+class ChPPOCRv3Det(hub.Module):

     def _initialize(self, enable_mkldnn=False):
         """
@@ -155,7 +155,8 @@ class ChineseTextDetectionDB(hub.Module):
                     use_gpu=False,
                     output_dir='detection_result',
                     visualization=False,
-                    box_thresh=0.5):
+                    box_thresh=0.5,
+                    det_db_unclip_ratio=1.5):
         """
         Get the text box in the predicted images.
         Args:
@@ -165,6 +166,7 @@ class ChineseTextDetectionDB(hub.Module):
             output_dir (str): The directory to store output images.
             visualization (bool): Whether to save image or not.
             box_thresh(float): the threshold of the detected text box's confidence
+            det_db_unclip_ratio(float): unclip ratio for post processing in DB detection.
         Returns:
             res (list): The result of text detection box and save path of images.
         """
@@ -195,7 +197,7 @@ class ChineseTextDetectionDB(hub.Module):
             'thresh': 0.3,
             'box_thresh': 0.6,
             'max_candidates': 1000,
-            'unclip_ratio': 1.5
+            'unclip_ratio': det_db_unclip_ratio
         })
         all_imgs = []
@@ -204,7 +206,6 @@ class ChineseTextDetectionDB(hub.Module):
         for original_image in predicted_data:
             ori_im = original_image.copy()
             im, ratio_list = preprocessor(original_image)
-            print('after preprocess int det, shape{}'.format(im.shape))
             res = {'save_path': ''}
             if im is None:
                 res['data'] = []
@@ -222,15 +223,10 @@ class ChineseTextDetectionDB(hub.Module):
             outs_dict = {}
             outs_dict['maps'] = outputs[0]
-            # data_out = self.output_tensors[0].copy_to_cpu()
-            print('Outputs[0] in det, shape: {}'.format(outputs[0].shape))
             dt_boxes_list = postprocessor(outs_dict, [ratio_list])
             dt_boxes = dt_boxes_list[0]
-            print('after postprocess int det, shape{}'.format(dt_boxes.shape))
             boxes = self.filter_tag_det_res(dt_boxes_list[0], original_image.shape)
-            print('after fitler tag int det, shape{}'.format(boxes.shape))
             res['data'] = boxes.astype(np.int).tolist()
-            print('boxes: {}'.format(boxes))
             all_imgs.append(im)
             all_ratios.append(ratio_list)
             if visualization:
@@ -278,6 +274,7 @@ class ChineseTextDetectionDB(hub.Module):
         results = self.detect_text(paths=[args.input_path],
                                    use_gpu=args.use_gpu,
                                    output_dir=args.output_dir,
+                                   det_db_unclip_ratio=args.det_db_unclip_ratio,
                                    visualization=args.visualization)
         return results
@@ -297,6 +294,10 @@ class ChineseTextDetectionDB(hub.Module):
                                            type=ast.literal_eval,
                                            default=False,
                                            help="whether to save output as images.")
+        self.arg_config_group.add_argument('--det_db_unclip_ratio',
+                                           type=float,
+                                           default=1.5,
+                                           help="unclip ratio for post processing in DB detection.")

     def add_module_input_arg(self):
         """
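
This file renames the detection module from ppocrv3_det_ch to ch_pp-ocrv3_det, threads a new det_db_unclip_ratio argument from the command line down to the DB post-processor, and strips leftover debug prints. A minimal usage sketch under the new interface (the image path and keyword values are illustrative assumptions, not part of this commit):

    # Hypothetical call into the renamed detector; 'doc_image.jpg' is a placeholder path.
    import cv2
    import paddlehub as hub

    detector = hub.Module(name='ch_pp-ocrv3_det')        # new module name introduced in this diff
    results = detector.detect_text(
        images=[cv2.imread('doc_image.jpg')],            # or paths=['doc_image.jpg']
        use_gpu=False,
        visualization=False,
        box_thresh=0.5,
        det_db_unclip_ratio=1.5)                         # new knob: larger values expand detected boxes
    print(results[0]['data'])                            # quadrilateral text boxes as integer points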

modules/image/text_recognition/ppocrv3_det_ch/processor.py

@@ -25,7 +25,6 @@ class DBProcessTest(object):
         self.resize_type = 0
         if 'test_image_shape' in params:
             self.image_shape = params['test_image_shape']
-            # print(self.image_shape)
             self.resize_type = 1
         if 'max_side_len' in params:
             self.max_side_len = params['max_side_len']
@@ -54,15 +53,14 @@ class DBProcessTest(object):
         resize_h = int(h * ratio)
         resize_w = int(w * ratio)
-        resize_h = int(round(resize_h / 32) * 32)
-        resize_w = int(round(resize_w / 32) * 32)
+        resize_h = max(int(round(resize_h / 32) * 32), 32)
+        resize_w = max(int(round(resize_w / 32) * 32), 32)
         try:
             if int(resize_w) <= 0 or int(resize_h) <= 0:
                 return None, (None, None)
             img = cv2.resize(img, (int(resize_w), int(resize_h)))
         except:
-            print(img.shape, resize_w, resize_h)
             sys.exit(0)
         ratio_h = resize_h / float(h)
         ratio_w = resize_w / float(w)
@@ -93,13 +91,14 @@ class DBProcessTest(object):
         return im

     def __call__(self, im):
+        src_h, src_w, _ = im.shape
         if self.resize_type == 0:
             im, (ratio_h, ratio_w) = self.resize_image_type0(im)
         else:
             im, (ratio_h, ratio_w) = self.resize_image_type1(im)
         im = self.normalize(im)
         im = im[np.newaxis, :]
-        return [im, (ratio_h, ratio_w)]
+        return [im, (src_h, src_w, ratio_h, ratio_w)]


 class DBPostProcess(object):
@@ -228,7 +227,7 @@ class DBPostProcess(object):
         cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
         return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]

-    def __call__(self, outs_dict, ratio_list):
+    def __call__(self, outs_dict, shape_list):
         pred = outs_dict['maps']
         pred = pred[:, 0, :, :]
@@ -236,10 +235,10 @@ class DBPostProcess(object):
         boxes_batch = []
         for batch_index in range(pred.shape[0]):
-            height, width = pred.shape[-2:]
+            src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
             mask = segmentation[batch_index]
-            tmp_boxes, tmp_scores = self.boxes_from_bitmap(pred[batch_index], mask, width, height)
+            tmp_boxes, tmp_scores = self.boxes_from_bitmap(pred[batch_index], mask, src_w, src_h)

             boxes_batch.append(tmp_boxes)
         return boxes_batch
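
Two behavioural changes here: the resized side lengths are now floored at 32, so rounding to a multiple of 32 can no longer collapse a very small side to zero, and the preprocessor returns the source height and width alongside the resize ratios so DBPostProcess can map boxes back to the original image size. A small standalone check of the new clamp (the input dimensions and max side length are made-up example values, and the scaling line is a simplified stand-in for the max-side scaling in resize_image_type0):

    # Worked example of the clamp added above; values are illustrative only.
    h, w, max_side_len = 20, 1500, 960
    ratio = float(max_side_len) / max(h, w)              # simplified max-side scaling step
    resize_h, resize_w = int(h * ratio), int(w * ratio)  # 12, 960
    resize_h = max(int(round(resize_h / 32) * 32), 32)   # old code gave 0 here, new code gives 32
    resize_w = max(int(round(resize_w / 32) * 32), 32)   # 960 either way
    print(resize_h, resize_w)                            # -> 32 960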

modules/image/text_recognition/ppocrv3_rec_ch/character.py

@@ -59,6 +59,7 @@ class CharacterOps(object):
         self.character = dict_character

     def add_special_char(self, dict_character):
+        dict_character = ['blank'] + dict_character
         return dict_character

     def encode(self, text):
@@ -93,12 +94,6 @@ class CharacterOps(object):
             selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1]
             for ignored_token in ignored_tokens:
                 selection &= text_index[batch_idx] != ignored_token
-            # print(text_index)
-            # print(batch_idx)
-            # print(selection)
-            # for text_id in text_index[batch_idx][selection]:
-            #     print(text_id)
-            #     print(self.character[text_id])
             char_list = [self.character[text_id] for text_id in text_index[batch_idx][selection]]
             if text_prob is not None:
                 conf_list = text_prob[batch_idx][selection]
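
The functional change here is that add_special_char now prepends a 'blank' token, reserving index 0 of the character table for the CTC blank; the commented-out debug prints in decode are removed. A self-contained sketch of the greedy CTC-style decode this layout supports (the alphabet and prediction sequence are invented for illustration):

    # Toy greedy CTC decode with 'blank' at index 0, mirroring the selection
    # logic in CharacterOps.decode; alphabet and indices are made up.
    import numpy as np

    character = ['blank'] + list('abc')                  # blank prepended, as in add_special_char
    text_index = np.array([0, 1, 1, 0, 2, 2, 2, 3])      # per-timestep argmax ids

    selection = np.ones(len(text_index), dtype=bool)
    selection[1:] = text_index[1:] != text_index[:-1]    # drop repeated predictions
    selection &= text_index != 0                         # drop the blank (ignored token)
    print(''.join(character[i] for i in text_index[selection]))   # -> abc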

modules/image/text_recognition/ppocrv3_rec_ch/module.py

@@ -29,14 +29,14 @@ from paddlehub.module.module import serving
 @moduleinfo(
-    name="ppocrv3_rec_ch",
+    name="ch_pp-ocrv3",
     version="1.0.0",
     summary="The module can recognize the chinese texts in an image. Firstly, it will detect the text box positions \
         based on the differentiable_binarization_chn module. Then it classifies the text angle and recognizes the chinese texts. ",
     author="paddle-dev",
     author_email="paddle-dev@baidu.com",
     type="cv/text_recognition")
-class ChineseOCRDBCRNN(hub.Module):
+class ChPPOCRv3(hub.Module):

     def _initialize(self, text_detector_module=None, enable_mkldnn=False):
         """
@@ -51,7 +51,7 @@ class ChineseOCRDBCRNN(hub.Module):
             'use_space_char': True
         }
         self.char_ops = CharacterOps(char_ops_params)
-        self.rec_image_shape = [3, 32, 320]
+        self.rec_image_shape = [3, 48, 320]
         self._text_detector_module = text_detector_module
         self.font_file = os.path.join(self.directory, 'assets', 'simfang.ttf')
         self.enable_mkldnn = enable_mkldnn
@@ -109,7 +109,7 @@ class ChineseOCRDBCRNN(hub.Module):
             text detect module
         """
         if not self._text_detector_module:
-            self._text_detector_module = hub.Module(name='ppocrv3_det_ch',
+            self._text_detector_module = hub.Module(name='ch_pp-ocrv3_det',
                                                     enable_mkldnn=self.enable_mkldnn,
                                                     version='1.0.0')
         return self._text_detector_module
@@ -152,7 +152,7 @@ class ChineseOCRDBCRNN(hub.Module):
     def resize_norm_img_rec(self, img, max_wh_ratio):
         imgC, imgH, imgW = self.rec_image_shape
         assert imgC == img.shape[2]
-        imgW = int((32 * max_wh_ratio))
+        imgW = int((imgH * max_wh_ratio))
         h, w = img.shape[:2]
         ratio = w / float(h)
         if math.ceil(imgH * ratio) > imgW:
@@ -199,7 +199,8 @@ class ChineseOCRDBCRNN(hub.Module):
                        visualization=False,
                        box_thresh=0.5,
                        text_thresh=0.5,
-                       angle_classification_thresh=0.9):
+                       angle_classification_thresh=0.9,
+                       det_db_unclip_ratio=1.5):
         """
         Get the chinese texts in the predicted images.
         Args:
@@ -212,7 +213,7 @@ class ChineseOCRDBCRNN(hub.Module):
             box_thresh(float): the threshold of the detected text box's confidence
             text_thresh(float): the threshold of the chinese text recognition confidence
             angle_classification_thresh(float): the threshold of the angle classification confidence
+            det_db_unclip_ratio(float): unclip ratio for post processing in DB detection.
         Returns:
             res (list): The result of chinese texts and save path of images.
         """
@@ -238,10 +239,10 @@ class ChineseOCRDBCRNN(hub.Module):
         detection_results = self.text_detector_module.detect_text(images=predicted_data,
                                                                    use_gpu=self.use_gpu,
-                                                                   box_thresh=box_thresh)
+                                                                   box_thresh=box_thresh,
+                                                                   det_db_unclip_ratio=det_db_unclip_ratio)
         boxes = [np.array(item['data']).astype(np.float32) for item in detection_results]
-        print("dt_boxes num : {}".format(len(boxes[0])))
         all_results = []
         for index, img_boxes in enumerate(boxes):
             original_image = predicted_data[index].copy()
@@ -255,7 +256,6 @@ class ChineseOCRDBCRNN(hub.Module):
                 tmp_box = copy.deepcopy(boxes[num_box])
                 img_crop = self.get_rotate_crop_image(original_image, tmp_box)
                 img_crop_list.append(img_crop)
-                print('img_crop shape {}'.format(img_crop.shape))
             img_crop_list, angle_list = self._classify_text(img_crop_list,
                                                             angle_classification_thresh=angle_classification_thresh)
             rec_results = self._recognize_text(img_crop_list)
@@ -371,7 +371,8 @@ class ChineseOCRDBCRNN(hub.Module):
         for beg_img_no in range(0, img_num, batch_num):
             end_img_no = min(img_num, beg_img_no + batch_num)
             norm_img_batch = []
-            max_wh_ratio = 0
+            imgC, imgH, imgW = self.rec_image_shape
+            max_wh_ratio = imgW / imgH
             for ino in range(beg_img_no, end_img_no):
                 h, w = img_list[indices[ino]].shape[0:2]
                 wh_ratio = w * 1.0 / h
@@ -400,10 +401,8 @@ class ChineseOCRDBCRNN(hub.Module):
             preds = preds[-1]
             if isinstance(preds, paddle.Tensor):
                 preds = preds.numpy()
-            print('preds.shape: {}', preds.shape)
             preds_idx = preds.argmax(axis=2)
             preds_prob = preds.max(axis=2)
-            # print('preds_idx: {} \n preds_prob: {}'.format(preds_idx, preds_prob) )
             rec_result = self.char_ops.decode(preds_idx, preds_prob, is_remove_duplicate=True)
             for rno in range(len(rec_result)):
                 rec_res[indices[beg_img_no + rno]] = rec_result[rno]
@@ -431,6 +430,7 @@ class ChineseOCRDBCRNN(hub.Module):
         results = self.recognize_text(paths=[args.input_path],
                                       use_gpu=args.use_gpu,
                                       output_dir=args.output_dir,
+                                      det_db_unclip_ratio=args.det_db_unclip_ratio,
                                       visualization=args.visualization)
         return results
@@ -450,6 +450,10 @@ class ChineseOCRDBCRNN(hub.Module):
                                            type=ast.literal_eval,
                                            default=False,
                                            help="whether to save output as images.")
+        self.arg_config_group.add_argument('--det_db_unclip_ratio',
+                                           type=float,
+                                           default=1.5,
+                                           help="unclip ratio for post processing in DB detection.")

     def add_module_input_arg(self):
         """
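
The recognition module mirrors the detector rename (ch_pp-ocrv3 now loads ch_pp-ocrv3_det), raises the recognition input height from 32 to 48 pixels, derives the crop width from imgH instead of a hard-coded 32, and forwards det_db_unclip_ratio to the detector. A usage sketch under the new names (the image path and keyword values are placeholders, not part of this commit):

    # Hypothetical end-to-end OCR call with the renamed module; 'doc_image.jpg'
    # is a placeholder and the keyword values simply echo the new defaults.
    import paddlehub as hub

    ocr = hub.Module(name='ch_pp-ocrv3')                 # loads ch_pp-ocrv3_det internally
    results = ocr.recognize_text(
        paths=['doc_image.jpg'],
        use_gpu=False,
        visualization=False,
        box_thresh=0.5,
        text_thresh=0.5,
        det_db_unclip_ratio=1.5)                         # forwarded to the detector's detect_text
    for item in results:
        print(item)                                      # result dict per input image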