Commit 03d38486
Authored on Aug 08, 2022 by andyjpaddle
fix amp train for re
Parent: 071d8327

Showing 3 changed files with 902 additions and 3 deletions (+902 -3):
  tools/infer/predict_det_eval.py  +363  -0
  tools/infer/predict_rec_eval.py  +534  -0
  tools/program.py                 +5    -3
tools/infer/predict_det_eval.py (new file, mode 0 → 100755)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))

os.environ["FLAGS_allocator_strategy"] = 'auto_growth'

import cv2
import numpy as np
import time
import sys

import tools.infer.utility as utility
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process
import json

logger = get_logger()


class TextDetector(object):
    def __init__(self, args):
        self.args = args
        self.det_algorithm = args.det_algorithm
        self.use_onnx = args.use_onnx
        pre_process_list = [{
            'DetResizeForTest': {
                'limit_side_len': args.det_limit_side_len,
                'limit_type': args.det_limit_type,
            }
        }, {
            'NormalizeImage': {
                'std': [0.229, 0.224, 0.225],
                'mean': [0.485, 0.456, 0.406],
                'scale': '1./255.',
                'order': 'hwc'
            }
        }, {
            'ToCHWImage': None
        }, {
            'KeepKeys': {
                'keep_keys': ['image', 'shape']
            }
        }]
        postprocess_params = {}
        if self.det_algorithm == "DB":
            postprocess_params['name'] = 'DBPostProcess'
            postprocess_params["thresh"] = args.det_db_thresh
            postprocess_params["box_thresh"] = args.det_db_box_thresh
            postprocess_params["max_candidates"] = 1000
            postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
            postprocess_params["use_dilation"] = args.use_dilation
            postprocess_params["score_mode"] = args.det_db_score_mode
        elif self.det_algorithm == "EAST":
            postprocess_params['name'] = 'EASTPostProcess'
            postprocess_params["score_thresh"] = args.det_east_score_thresh
            postprocess_params["cover_thresh"] = args.det_east_cover_thresh
            postprocess_params["nms_thresh"] = args.det_east_nms_thresh
        elif self.det_algorithm == "SAST":
            pre_process_list[0] = {
                'DetResizeForTest': {
                    'resize_long': args.det_limit_side_len
                }
            }
            postprocess_params['name'] = 'SASTPostProcess'
            postprocess_params["score_thresh"] = args.det_sast_score_thresh
            postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
            self.det_sast_polygon = args.det_sast_polygon
            if self.det_sast_polygon:
                postprocess_params["sample_pts_num"] = 6
                postprocess_params["expand_scale"] = 1.2
                postprocess_params["shrink_ratio_of_width"] = 0.2
            else:
                postprocess_params["sample_pts_num"] = 2
                postprocess_params["expand_scale"] = 1.0
                postprocess_params["shrink_ratio_of_width"] = 0.3
        elif self.det_algorithm == "PSE":
            postprocess_params['name'] = 'PSEPostProcess'
            postprocess_params["thresh"] = args.det_pse_thresh
            postprocess_params["box_thresh"] = args.det_pse_box_thresh
            postprocess_params["min_area"] = args.det_pse_min_area
            postprocess_params["box_type"] = args.det_pse_box_type
            postprocess_params["scale"] = args.det_pse_scale
            self.det_pse_box_type = args.det_pse_box_type
        elif self.det_algorithm == "FCE":
            pre_process_list[0] = {
                'DetResizeForTest': {
                    'rescale_img': [1080, 736]
                }
            }
            postprocess_params['name'] = 'FCEPostProcess'
            postprocess_params["scales"] = args.scales
            postprocess_params["alpha"] = args.alpha
            postprocess_params["beta"] = args.beta
            postprocess_params["fourier_degree"] = args.fourier_degree
            postprocess_params["box_type"] = args.det_fce_box_type
        else:
            logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
            sys.exit(0)

        self.preprocess_op = create_operators(pre_process_list)
        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor, self.output_tensors, self.config = \
            utility.create_predictor(args, 'det', logger)

        if self.use_onnx:
            img_h, img_w = self.input_tensor.shape[2:]
            if img_h is not None and img_w is not None and img_h > 0 and img_w > 0:
                pre_process_list[0] = {
                    'DetResizeForTest': {
                        'image_shape': [img_h, img_w]
                    }
                }
            self.preprocess_op = create_operators(pre_process_list)

        if args.benchmark:
            import auto_log
            pid = os.getpid()
            gpu_id = utility.get_infer_gpuid()
            self.autolog = auto_log.AutoLogger(
                model_name="det",
                model_precision=args.precision,
                batch_size=1,
                data_shape="dynamic",
                save_path=None,
                inference_config=self.config,
                pids=pid,
                process_name=None,
                gpu_ids=gpu_id if args.use_gpu else None,
                time_keys=[
                    'preprocess_time', 'inference_time', 'postprocess_time'
                ],
                warmup=2,
                logger=logger)

    def order_points_clockwise(self, pts):
        rect = np.zeros((4, 2), dtype="float32")
        s = pts.sum(axis=1)
        rect[0] = pts[np.argmin(s)]
        rect[2] = pts[np.argmax(s)]
        tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0)
        diff = np.diff(np.array(tmp), axis=1)
        rect[1] = tmp[np.argmin(diff)]
        rect[3] = tmp[np.argmax(diff)]
        return rect

    def clip_det_res(self, points, img_height, img_width):
        for pno in range(points.shape[0]):
            points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
            points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
        return points

    def filter_tag_det_res(self, dt_boxes, image_shape):
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            box = self.order_points_clockwise(box)
            box = self.clip_det_res(box, img_height, img_width)
            rect_width = int(np.linalg.norm(box[0] - box[1]))
            rect_height = int(np.linalg.norm(box[0] - box[3]))
            if rect_width <= 3 or rect_height <= 3:
                continue
            dt_boxes_new.append(box)
        dt_boxes = np.array(dt_boxes_new)
        return dt_boxes

    def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            box = self.clip_det_res(box, img_height, img_width)
            dt_boxes_new.append(box)
        dt_boxes = np.array(dt_boxes_new)
        return dt_boxes

    def __call__(self, img):
        ori_im = img.copy()
        data = {'image': img}

        st = time.time()

        if self.args.benchmark:
            self.autolog.times.start()

        data = transform(data, self.preprocess_op)
        img, shape_list = data
        if img is None:
            return None, 0
        img = np.expand_dims(img, axis=0)
        shape_list = np.expand_dims(shape_list, axis=0)
        img = img.copy()

        if self.args.benchmark:
            self.autolog.times.stamp()
        if self.use_onnx:
            input_dict = {}
            input_dict[self.input_tensor.name] = img
            outputs = self.predictor.run(self.output_tensors, input_dict)
        else:
            self.input_tensor.copy_from_cpu(img)
            self.predictor.run()
            outputs = []
            for output_tensor in self.output_tensors:
                output = output_tensor.copy_to_cpu()
                outputs.append(output)
            if self.args.benchmark:
                self.autolog.times.stamp()

        preds = {}
        if self.det_algorithm == "EAST":
            preds['f_geo'] = outputs[0]
            preds['f_score'] = outputs[1]
        elif self.det_algorithm == 'SAST':
            preds['f_border'] = outputs[0]
            preds['f_score'] = outputs[1]
            preds['f_tco'] = outputs[2]
            preds['f_tvo'] = outputs[3]
        elif self.det_algorithm in ['DB', 'PSE']:
            preds['maps'] = outputs[0]
        elif self.det_algorithm == 'FCE':
            for i, output in enumerate(outputs):
                preds['level_{}'.format(i)] = output
        else:
            raise NotImplementedError

        #self.predictor.try_shrink_memory()
        post_result = self.postprocess_op(preds, shape_list)
        dt_boxes = post_result[0]['points']
        if (self.det_algorithm == "SAST" and self.det_sast_polygon) or (
                self.det_algorithm in ["PSE", "FCE"] and
                self.postprocess_op.box_type == 'poly'):
            dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
        else:
            dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)

        if self.args.benchmark:
            self.autolog.times.end(stamp=True)
        et = time.time()
        return dt_boxes, et - st


if __name__ == "__main__":
    from ppocr.metrics.eval_det_iou import DetectionIoUEvaluator
    evaluator = DetectionIoUEvaluator()
    args = utility.parse_args()
    # image_file_list = get_image_file_list(args.image_dir)

    def _check_image_file(path):
        img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif'}
        return any([path.lower().endswith(e) for e in img_end])

    def get_image_file_list_from_txt(img_file):
        imgs_lists = []
        label_lists = []
        if img_file is None or not os.path.exists(img_file):
            raise Exception("not found any img file in {}".format(img_file))

        img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif'}
        root_dir = img_file.split('/')[0]
        with open(img_file, 'r') as f:
            lines = f.readlines()
            for line in lines:
                line = line.replace('\n', '').split('\t')
                file_path, label = line[0], line[1]
                file_path = os.path.join(root_dir, file_path)
                if os.path.isfile(file_path) and _check_image_file(file_path):
                    imgs_lists.append(file_path)
                    label_lists.append(label)
        if len(imgs_lists) == 0:
            raise Exception("not found any img file in {}".format(img_file))
        return imgs_lists, label_lists

    image_file_list, label_list = get_image_file_list_from_txt(args.image_dir)
    text_detector = TextDetector(args)
    count = 0
    total_time = 0
    draw_img_save = "./inference_results"

    if args.warmup:
        img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
        for i in range(2):
            res = text_detector(img)

    if not os.path.exists(draw_img_save):
        os.makedirs(draw_img_save)
    save_results = []
    results = []
    for idx in range(len(image_file_list)):
        image_file = image_file_list[idx]
        label = json.loads(label_list[idx])
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        st = time.time()
        dt_boxes, _ = text_detector(img)
        elapse = time.time() - st
        if count > 0:
            total_time += elapse
        count += 1
        save_pred = os.path.basename(image_file) + "\t" + str(
            json.dumps([x.tolist() for x in dt_boxes])) + "\n"
        save_results.append(save_pred)

        # for eval
        gt_info_list = []
        det_info_list = []
        for dt_box in dt_boxes:
            det_info = {
                'points': np.array(dt_box, dtype=np.float32),
                'text': ''
            }
            det_info_list.append(det_info)
        for lab in label:
            gt_info = {
                'points': np.array(lab['points'], dtype=np.float32),
                'text': '',
                'ignore': False
            }
            gt_info_list.append(gt_info)
        result = evaluator.evaluate_image(gt_info_list, det_info_list)
        results.append(result)

    metrics = evaluator.combine_results(results)
    print('predict det eval on ', args.image_dir)
    print('metrics: ', metrics)

    # logger.info(save_pred)
    # logger.info("The predict time of {}: {}".format(image_file, elapse))
    # src_im = utility.draw_text_det_res(dt_boxes, image_file)
    # img_name_pure = os.path.split(image_file)[-1]
    # img_path = os.path.join(draw_img_save,
    #                         "det_res_{}".format(img_name_pure))
    # cv2.imwrite(img_path, src_im)
    # logger.info("The visualized image saved in {}".format(img_path))
    # with open(os.path.join(draw_img_save, "det_results.txt"), 'w') as f:
    #     f.writelines(save_results)
    #     f.close()
    # if args.benchmark:
    #     text_detector.autolog.report()
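The snippet below is a hedged usage sketch, not part of the commit: it shows one way to drive the new TextDetector class directly from Python instead of launching the script from the command line. The test-image path and any model-directory flags read through utility.parse_args() are illustrative assumptions.

# Hedged sketch: programmatic use of the detection eval predictor.
# Run from the repository root so that `tools` and `ppocr` are importable;
# flag values (e.g. --det_model_dir) are assumed, not taken from the commit.
import cv2

import tools.infer.utility as utility
from tools.infer.predict_det_eval import TextDetector

args = utility.parse_args()              # reads --det_model_dir etc. from argv
text_detector = TextDetector(args)

img = cv2.imread("doc/imgs_en/img_10.jpg")   # illustrative test image path
dt_boxes, elapse = text_detector(img)        # quadrilateral boxes, elapsed seconds
print("detected {} boxes in {:.3f}s".format(len(dt_boxes), elapse))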
tools/infer/predict_rec_eval.py (new file, mode 0 → 100755)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

from PIL import Image

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))

os.environ["FLAGS_allocator_strategy"] = 'auto_growth'

import cv2
import numpy as np
import math
import time
import traceback
import paddle

import tools.infer.utility as utility
from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif

logger = get_logger()


class TextRecognizer(object):
    def __init__(self, args):
        self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
        self.rec_batch_num = args.rec_batch_num
        self.rec_algorithm = args.rec_algorithm
        postprocess_params = {
            'name': 'CTCLabelDecode',
            "character_dict_path": args.rec_char_dict_path,
            "use_space_char": args.use_space_char
        }
        if self.rec_algorithm == "SRN":
            postprocess_params = {
                'name': 'SRNLabelDecode',
                "character_dict_path": args.rec_char_dict_path,
                "use_space_char": args.use_space_char
            }
        elif self.rec_algorithm == "RARE":
            postprocess_params = {
                'name': 'AttnLabelDecode',
                "character_dict_path": args.rec_char_dict_path,
                "use_space_char": args.use_space_char
            }
        elif self.rec_algorithm == 'NRTR':
            postprocess_params = {
                'name': 'NRTRLabelDecode',
                "character_dict_path": args.rec_char_dict_path,
                "use_space_char": args.use_space_char
            }
        elif self.rec_algorithm == "SAR":
            postprocess_params = {
                'name': 'SARLabelDecode',
                "character_dict_path": args.rec_char_dict_path,
                "use_space_char": args.use_space_char
            }
        elif self.rec_algorithm == 'ViTSTR':
            postprocess_params = {
                'name': 'ViTSTRLabelDecode',
                "character_dict_path": args.rec_char_dict_path,
                "use_space_char": args.use_space_char
            }
        elif self.rec_algorithm == 'ABINet':
            postprocess_params = {
                'name': 'ABINetLabelDecode',
                "character_dict_path": args.rec_char_dict_path,
                "use_space_char": args.use_space_char
            }
        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor, self.output_tensors, self.config = \
            utility.create_predictor(args, 'rec', logger)
        self.benchmark = args.benchmark
        self.use_onnx = args.use_onnx
        if args.benchmark:
            import auto_log
            pid = os.getpid()
            gpu_id = utility.get_infer_gpuid()
            self.autolog = auto_log.AutoLogger(
                model_name="rec",
                model_precision=args.precision,
                batch_size=args.rec_batch_num,
                data_shape="dynamic",
                save_path=None,  #args.save_log_path,
                inference_config=self.config,
                pids=pid,
                process_name=None,
                gpu_ids=gpu_id if args.use_gpu else None,
                time_keys=[
                    'preprocess_time', 'inference_time', 'postprocess_time'
                ],
                warmup=0,
                logger=logger)

    def resize_norm_img(self, img, max_wh_ratio):
        imgC, imgH, imgW = self.rec_image_shape
        if self.rec_algorithm == 'NRTR' or self.rec_algorithm == 'ViTSTR':
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            # return padding_im
            image_pil = Image.fromarray(np.uint8(img))
            if self.rec_algorithm == 'ViTSTR':
                img = image_pil.resize([imgW, imgH], Image.BICUBIC)
            else:
                img = image_pil.resize([imgW, imgH], Image.ANTIALIAS)
            img = np.array(img)
            norm_img = np.expand_dims(img, -1)
            norm_img = norm_img.transpose((2, 0, 1))
            if self.rec_algorithm == 'ViTSTR':
                norm_img = norm_img.astype(np.float32) / 255.
            else:
                norm_img = norm_img.astype(np.float32) / 128. - 1.
            return norm_img

        assert imgC == img.shape[2]
        imgW = int((imgH * max_wh_ratio))
        if self.use_onnx:
            w = self.input_tensor.shape[3:][0]
            if w is not None and w > 0:
                imgW = w

        h, w = img.shape[:2]
        ratio = w / float(h)
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            resized_w = int(math.ceil(imgH * ratio))
        if self.rec_algorithm == 'RARE':
            if resized_w > self.rec_image_shape[2]:
                resized_w = self.rec_image_shape[2]
            imgW = self.rec_image_shape[2]
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def resize_norm_img_srn(self, img, image_shape):
        imgC, imgH, imgW = image_shape

        img_black = np.zeros((imgH, imgW))
        im_hei = img.shape[0]
        im_wid = img.shape[1]

        if im_wid <= im_hei * 1:
            img_new = cv2.resize(img, (imgH * 1, imgH))
        elif im_wid <= im_hei * 2:
            img_new = cv2.resize(img, (imgH * 2, imgH))
        elif im_wid <= im_hei * 3:
            img_new = cv2.resize(img, (imgH * 3, imgH))
        else:
            img_new = cv2.resize(img, (imgW, imgH))

        img_np = np.asarray(img_new)
        img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
        img_black[:, 0:img_np.shape[1]] = img_np
        img_black = img_black[:, :, np.newaxis]

        row, col, c = img_black.shape
        c = 1

        return np.reshape(img_black, (c, row, col)).astype(np.float32)

    def srn_other_inputs(self, image_shape, num_heads, max_text_length):
        imgC, imgH, imgW = image_shape
        feature_dim = int((imgH / 8) * (imgW / 8))

        encoder_word_pos = np.array(range(0, feature_dim)).reshape(
            (feature_dim, 1)).astype('int64')
        gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
            (max_text_length, 1)).astype('int64')

        gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
        gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
            [-1, 1, max_text_length, max_text_length])
        gsrm_slf_attn_bias1 = np.tile(
            gsrm_slf_attn_bias1,
            [1, num_heads, 1, 1]).astype('float32') * [-1e9]

        gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
            [-1, 1, max_text_length, max_text_length])
        gsrm_slf_attn_bias2 = np.tile(
            gsrm_slf_attn_bias2,
            [1, num_heads, 1, 1]).astype('float32') * [-1e9]

        encoder_word_pos = encoder_word_pos[np.newaxis, :]
        gsrm_word_pos = gsrm_word_pos[np.newaxis, :]

        return [
            encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
            gsrm_slf_attn_bias2
        ]

    def process_image_srn(self, img, image_shape, num_heads, max_text_length):
        norm_img = self.resize_norm_img_srn(img, image_shape)
        norm_img = norm_img[np.newaxis, :]

        [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
            self.srn_other_inputs(image_shape, num_heads, max_text_length)

        gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
        gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
        encoder_word_pos = encoder_word_pos.astype(np.int64)
        gsrm_word_pos = gsrm_word_pos.astype(np.int64)

        return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
                gsrm_slf_attn_bias2)

    def resize_norm_img_sar(self, img, image_shape,
                            width_downsample_ratio=0.25):
        imgC, imgH, imgW_min, imgW_max = image_shape
        h = img.shape[0]
        w = img.shape[1]
        valid_ratio = 1.0
        # make sure new_width is an integral multiple of width_divisor.
        width_divisor = int(1 / width_downsample_ratio)
        # resize
        ratio = w / float(h)
        resize_w = math.ceil(imgH * ratio)
        if resize_w % width_divisor != 0:
            resize_w = round(resize_w / width_divisor) * width_divisor
        if imgW_min is not None:
            resize_w = max(imgW_min, resize_w)
        if imgW_max is not None:
            valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
            resize_w = min(imgW_max, resize_w)
        resized_image = cv2.resize(img, (resize_w, imgH))
        resized_image = resized_image.astype('float32')
        # norm
        if image_shape[0] == 1:
            resized_image = resized_image / 255
            resized_image = resized_image[np.newaxis, :]
        else:
            resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        resize_shape = resized_image.shape
        padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
        padding_im[:, :, 0:resize_w] = resized_image
        pad_shape = padding_im.shape

        return padding_im, resize_shape, pad_shape, valid_ratio

    def resize_norm_img_svtr(self, img, image_shape):
        imgC, imgH, imgW = image_shape
        resized_image = cv2.resize(
            img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
        resized_image = resized_image.astype('float32')
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        return resized_image

    def resize_norm_img_abinet(self, img, image_shape):
        imgC, imgH, imgW = image_shape

        resized_image = cv2.resize(
            img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
        resized_image = resized_image.astype('float32')
        resized_image = resized_image / 255.

        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        resized_image = (
            resized_image - mean[None, None, ...]) / std[None, None, ...]
        resized_image = resized_image.transpose((2, 0, 1))
        resized_image = resized_image.astype('float32')

        return resized_image

    def __call__(self, img_list):
        img_num = len(img_list)
        # Calculate the aspect ratio of all text bars
        width_list = []
        for img in img_list:
            width_list.append(img.shape[1] / float(img.shape[0]))
        # Sorting can speed up the recognition process
        indices = np.argsort(np.array(width_list))
        rec_res = [['', 0.0]] * img_num
        batch_num = self.rec_batch_num
        st = time.time()
        if self.benchmark:
            self.autolog.times.start()
        for beg_img_no in range(0, img_num, batch_num):
            end_img_no = min(img_num, beg_img_no + batch_num)
            norm_img_batch = []
            imgC, imgH, imgW = self.rec_image_shape
            max_wh_ratio = imgW / imgH
            # max_wh_ratio = 0
            for ino in range(beg_img_no, end_img_no):
                h, w = img_list[indices[ino]].shape[0:2]
                wh_ratio = w * 1.0 / h
                max_wh_ratio = max(max_wh_ratio, wh_ratio)
            for ino in range(beg_img_no, end_img_no):
                if self.rec_algorithm == "SAR":
                    norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
                        img_list[indices[ino]], self.rec_image_shape)
                    norm_img = norm_img[np.newaxis, :]
                    valid_ratio = np.expand_dims(valid_ratio, axis=0)
                    valid_ratios = []
                    valid_ratios.append(valid_ratio)
                    norm_img_batch.append(norm_img)
                elif self.rec_algorithm == "SRN":
                    norm_img = self.process_image_srn(
                        img_list[indices[ino]], self.rec_image_shape, 8, 25)
                    encoder_word_pos_list = []
                    gsrm_word_pos_list = []
                    gsrm_slf_attn_bias1_list = []
                    gsrm_slf_attn_bias2_list = []
                    encoder_word_pos_list.append(norm_img[1])
                    gsrm_word_pos_list.append(norm_img[2])
                    gsrm_slf_attn_bias1_list.append(norm_img[3])
                    gsrm_slf_attn_bias2_list.append(norm_img[4])
                    norm_img_batch.append(norm_img[0])
                elif self.rec_algorithm == "SVTR":
                    norm_img = self.resize_norm_img_svtr(
                        img_list[indices[ino]], self.rec_image_shape)
                    norm_img = norm_img[np.newaxis, :]
                    norm_img_batch.append(norm_img)
                elif self.rec_algorithm == "ABINet":
                    norm_img = self.resize_norm_img_abinet(
                        img_list[indices[ino]], self.rec_image_shape)
                    norm_img = norm_img[np.newaxis, :]
                    norm_img_batch.append(norm_img)
                else:
                    norm_img = self.resize_norm_img(
                        img_list[indices[ino]], max_wh_ratio)
                    norm_img = norm_img[np.newaxis, :]
                    norm_img_batch.append(norm_img)
            norm_img_batch = np.concatenate(norm_img_batch)
            norm_img_batch = norm_img_batch.copy()
            if self.benchmark:
                self.autolog.times.stamp()

            if self.rec_algorithm == "SRN":
                encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
                gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
                gsrm_slf_attn_bias1_list = np.concatenate(
                    gsrm_slf_attn_bias1_list)
                gsrm_slf_attn_bias2_list = np.concatenate(
                    gsrm_slf_attn_bias2_list)

                inputs = [
                    norm_img_batch,
                    encoder_word_pos_list,
                    gsrm_word_pos_list,
                    gsrm_slf_attn_bias1_list,
                    gsrm_slf_attn_bias2_list,
                ]
                if self.use_onnx:
                    input_dict = {}
                    input_dict[self.input_tensor.name] = norm_img_batch
                    outputs = self.predictor.run(self.output_tensors,
                                                 input_dict)
                    preds = {"predict": outputs[2]}
                else:
                    input_names = self.predictor.get_input_names()
                    for i in range(len(input_names)):
                        input_tensor = self.predictor.get_input_handle(
                            input_names[i])
                        input_tensor.copy_from_cpu(inputs[i])
                    self.predictor.run()
                    outputs = []
                    for output_tensor in self.output_tensors:
                        output = output_tensor.copy_to_cpu()
                        outputs.append(output)
                    if self.benchmark:
                        self.autolog.times.stamp()
                    preds = {"predict": outputs[2]}
            elif self.rec_algorithm == "SAR":
                valid_ratios = np.concatenate(valid_ratios)
                inputs = [
                    norm_img_batch,
                    valid_ratios,
                ]
                if self.use_onnx:
                    input_dict = {}
                    input_dict[self.input_tensor.name] = norm_img_batch
                    outputs = self.predictor.run(self.output_tensors,
                                                 input_dict)
                    preds = outputs[0]
                else:
                    input_names = self.predictor.get_input_names()
                    for i in range(len(input_names)):
                        input_tensor = self.predictor.get_input_handle(
                            input_names[i])
                        input_tensor.copy_from_cpu(inputs[i])
                    self.predictor.run()
                    outputs = []
                    for output_tensor in self.output_tensors:
                        output = output_tensor.copy_to_cpu()
                        outputs.append(output)
                    if self.benchmark:
                        self.autolog.times.stamp()
                    preds = outputs[0]
            else:
                if self.use_onnx:
                    input_dict = {}
                    input_dict[self.input_tensor.name] = norm_img_batch
                    outputs = self.predictor.run(self.output_tensors,
                                                 input_dict)
                    preds = outputs[0]
                else:
                    self.input_tensor.copy_from_cpu(norm_img_batch)
                    self.predictor.run()
                    outputs = []
                    for output_tensor in self.output_tensors:
                        output = output_tensor.copy_to_cpu()
                        outputs.append(output)
                    if self.benchmark:
                        self.autolog.times.stamp()
                    if len(outputs) != 1:
                        preds = outputs
                    else:
                        preds = outputs[0]
            rec_result = self.postprocess_op(preds)
            for rno in range(len(rec_result)):
                rec_res[indices[beg_img_no + rno]] = rec_result[rno]
            if self.benchmark:
                self.autolog.times.end(stamp=True)
        return rec_res, time.time() - st


def main(args):
    # image_file_list = get_image_file_list(args.image_dir)
    def _check_image_file(path):
        img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif'}
        return any([path.lower().endswith(e) for e in img_end])

    def get_image_file_list_from_txt(img_file):
        imgs_lists = []
        label_lists = []
        if img_file is None or not os.path.exists(img_file):
            raise Exception("not found any img file in {}".format(img_file))

        img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif'}
        root_dir = img_file.split('/')[0]
        with open(img_file, 'r') as f:
            lines = f.readlines()
            for line in lines:
                line = line.replace('\n', '').split('\t')
                file_path, label = line[0], line[1]
                file_path = os.path.join(root_dir, file_path)
                if os.path.isfile(file_path) and _check_image_file(file_path):
                    imgs_lists.append(file_path)
                    label_lists.append(label)
        if len(imgs_lists) == 0:
            raise Exception("not found any img file in {}".format(img_file))
        return imgs_lists, label_lists

    image_file_list, label_list = get_image_file_list_from_txt(args.image_dir)
    text_recognizer = TextRecognizer(args)
    valid_image_file_list = []
    img_list = []

    logger.info(
        "In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320', "
        "if you are using recognition model with PP-OCRv2 or an older version, please set --rec_image_shape='3,32,320"
    )
    # warmup 2 times
    if args.warmup:
        img = np.random.uniform(0, 255, [48, 320, 3]).astype(np.uint8)
        for i in range(2):
            res = text_recognizer([img] * int(args.rec_batch_num))

    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        valid_image_file_list.append(image_file)
        img_list.append(img)
    try:
        rec_res, _ = text_recognizer(img_list)
    except Exception as E:
        logger.info(traceback.format_exc())
        logger.info(E)
        exit()

    correct_num = 0
    for ino in range(len(img_list)):
        pred = rec_res[ino][0]
        gt = label_list[ino]
        if pred == gt:
            correct_num += 1
    acc = correct_num * 1.0 / len(img_list)
    print('predict rec eval on ', args.image_dir)
    print('acc: ', acc)

    # for debug bad case
    bad_case_lines = []
    for ino in range(len(img_list)):
        pred = rec_res[ino][0]
        gt = label_list[ino]
        if pred != gt and len(gt) <= 25:
            bad_case = valid_image_file_list[ino] + '\t' + 'pred:' + pred + \
                '\t' + 'gt:' + gt + '\n'
            bad_case_lines.append(bad_case)
    with open('bad_case_hwdb2.txt', 'a+') as f:
        f.writelines(bad_case_lines)
    # end debug case

    if args.benchmark:
        text_recognizer.autolog.report()


if __name__ == "__main__":
    main(utility.parse_args())
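Similarly, a hedged sketch (not part of the commit) of using the new TextRecognizer class directly on a few cropped text images; the crop path is illustrative, and recognition flags such as --rec_model_dir and --rec_image_shape are assumed to be supplied on the command line and read via utility.parse_args().

# Hedged sketch: programmatic use of the recognition eval predictor.
# Run from the repository root; flag values are assumptions, not from the commit.
import cv2

import tools.infer.utility as utility
from tools.infer.predict_rec_eval import TextRecognizer

args = utility.parse_args()                  # reads --rec_model_dir etc. from argv
text_recognizer = TextRecognizer(args)

crops = [cv2.imread("doc/imgs_words_en/word_10.png")]   # illustrative crop path
rec_res, elapse = text_recognizer(crops)
for text, score in rec_res:
    print(text, score)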
tools/program.py

@@ -154,13 +154,14 @@ def check_xpu(use_xpu):
     except Exception as e:
         pass


 def to_float32(preds):
     if isinstance(preds, dict):
         for k in preds:
             if isinstance(preds[k], dict) or isinstance(preds[k], list):
                 preds[k] = to_float32(preds[k])
             else:
-                preds[k] = preds[k].astype(paddle.float32)
+                preds[k] = paddle.to_tensor(preds[k], dtype='float32')
     elif isinstance(preds, list):
         for k in range(len(preds)):
             if isinstance(preds[k], dict):
@@ -168,11 +169,12 @@ def to_float32(preds):
             elif isinstance(preds[k], list):
                 preds[k] = to_float32(preds[k])
             else:
-                preds[k] = preds[k].astype(paddle.float32)
+                preds[k] = paddle.to_tensor(preds[k], dtype='float32')
     else:
-        preds = preds.astype(paddle.float32)
+        preds[k] = paddle.to_tensor(preds[k], dtype='float32')
     return preds


 def train(config,
           train_dataloader,
           valid_dataloader,
...
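For context on the change above, here is a hedged sketch (not from the commit) of what the patched to_float32 helper does around a mixed-precision forward pass: predictions produced under AMP may come back as float16 tensors or as plain numpy arrays (e.g. for the RE head), and replacing .astype(paddle.float32) with paddle.to_tensor(..., dtype='float32') lets both cases be normalized into float32 paddle tensors before loss computation and postprocessing. The import path and sample data below are illustrative assumptions.

# Hedged sketch: exercising the patched to_float32 on AMP-style outputs.
import numpy as np
import paddle

from tools.program import to_float32   # assumed import path; repo root on sys.path

# A nested prediction structure such as an AMP forward pass might return:
# a float16 tensor plus a list/dict of auxiliary numpy outputs.
preds = {
    "backbone_out": paddle.to_tensor(np.ones([2, 3], dtype=np.float16)),
    "heads": [{"aux": np.ones([2], dtype=np.float16)}],
}
preds = to_float32(preds)
print(preds["backbone_out"].dtype)        # paddle.float32
print(preds["heads"][0]["aux"].dtype)     # paddle.float32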