Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
03d38486
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
1 年多 前同步成功
通知
1533
Star
32963
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
03d38486
编写于
8月 08, 2022
作者:
A
andyjpaddle
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix amp train for re
上级
071d8327
变更
3
展开全部
隐藏空白更改
内联
并排
Showing
3 changed file
with
902 addition
and
3 deletion
+902
-3
tools/infer/predict_det_eval.py
tools/infer/predict_det_eval.py
+363
-0
tools/infer/predict_rec_eval.py
tools/infer/predict_rec_eval.py
+534
-0
tools/program.py
tools/program.py
+5
-3
未找到文件。
tools/infer/predict_det_eval.py
0 → 100755
浏览文件 @
03d38486
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
insert
(
0
,
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../..'
)))
os
.
environ
[
"FLAGS_allocator_strategy"
]
=
'auto_growth'
import
cv2
import
numpy
as
np
import
time
import
sys
import
tools.infer.utility
as
utility
from
ppocr.utils.logging
import
get_logger
from
ppocr.utils.utility
import
get_image_file_list
,
check_and_read_gif
from
ppocr.data
import
create_operators
,
transform
from
ppocr.postprocess
import
build_post_process
import
json
logger
=
get_logger
()
class
TextDetector
(
object
):
def
__init__
(
self
,
args
):
self
.
args
=
args
self
.
det_algorithm
=
args
.
det_algorithm
self
.
use_onnx
=
args
.
use_onnx
pre_process_list
=
[{
'DetResizeForTest'
:
{
'limit_side_len'
:
args
.
det_limit_side_len
,
'limit_type'
:
args
.
det_limit_type
,
}
},
{
'NormalizeImage'
:
{
'std'
:
[
0.229
,
0.224
,
0.225
],
'mean'
:
[
0.485
,
0.456
,
0.406
],
'scale'
:
'1./255.'
,
'order'
:
'hwc'
}
},
{
'ToCHWImage'
:
None
},
{
'KeepKeys'
:
{
'keep_keys'
:
[
'image'
,
'shape'
]
}
}]
postprocess_params
=
{}
if
self
.
det_algorithm
==
"DB"
:
postprocess_params
[
'name'
]
=
'DBPostProcess'
postprocess_params
[
"thresh"
]
=
args
.
det_db_thresh
postprocess_params
[
"box_thresh"
]
=
args
.
det_db_box_thresh
postprocess_params
[
"max_candidates"
]
=
1000
postprocess_params
[
"unclip_ratio"
]
=
args
.
det_db_unclip_ratio
postprocess_params
[
"use_dilation"
]
=
args
.
use_dilation
postprocess_params
[
"score_mode"
]
=
args
.
det_db_score_mode
elif
self
.
det_algorithm
==
"EAST"
:
postprocess_params
[
'name'
]
=
'EASTPostProcess'
postprocess_params
[
"score_thresh"
]
=
args
.
det_east_score_thresh
postprocess_params
[
"cover_thresh"
]
=
args
.
det_east_cover_thresh
postprocess_params
[
"nms_thresh"
]
=
args
.
det_east_nms_thresh
elif
self
.
det_algorithm
==
"SAST"
:
pre_process_list
[
0
]
=
{
'DetResizeForTest'
:
{
'resize_long'
:
args
.
det_limit_side_len
}
}
postprocess_params
[
'name'
]
=
'SASTPostProcess'
postprocess_params
[
"score_thresh"
]
=
args
.
det_sast_score_thresh
postprocess_params
[
"nms_thresh"
]
=
args
.
det_sast_nms_thresh
self
.
det_sast_polygon
=
args
.
det_sast_polygon
if
self
.
det_sast_polygon
:
postprocess_params
[
"sample_pts_num"
]
=
6
postprocess_params
[
"expand_scale"
]
=
1.2
postprocess_params
[
"shrink_ratio_of_width"
]
=
0.2
else
:
postprocess_params
[
"sample_pts_num"
]
=
2
postprocess_params
[
"expand_scale"
]
=
1.0
postprocess_params
[
"shrink_ratio_of_width"
]
=
0.3
elif
self
.
det_algorithm
==
"PSE"
:
postprocess_params
[
'name'
]
=
'PSEPostProcess'
postprocess_params
[
"thresh"
]
=
args
.
det_pse_thresh
postprocess_params
[
"box_thresh"
]
=
args
.
det_pse_box_thresh
postprocess_params
[
"min_area"
]
=
args
.
det_pse_min_area
postprocess_params
[
"box_type"
]
=
args
.
det_pse_box_type
postprocess_params
[
"scale"
]
=
args
.
det_pse_scale
self
.
det_pse_box_type
=
args
.
det_pse_box_type
elif
self
.
det_algorithm
==
"FCE"
:
pre_process_list
[
0
]
=
{
'DetResizeForTest'
:
{
'rescale_img'
:
[
1080
,
736
]
}
}
postprocess_params
[
'name'
]
=
'FCEPostProcess'
postprocess_params
[
"scales"
]
=
args
.
scales
postprocess_params
[
"alpha"
]
=
args
.
alpha
postprocess_params
[
"beta"
]
=
args
.
beta
postprocess_params
[
"fourier_degree"
]
=
args
.
fourier_degree
postprocess_params
[
"box_type"
]
=
args
.
det_fce_box_type
else
:
logger
.
info
(
"unknown det_algorithm:{}"
.
format
(
self
.
det_algorithm
))
sys
.
exit
(
0
)
self
.
preprocess_op
=
create_operators
(
pre_process_list
)
self
.
postprocess_op
=
build_post_process
(
postprocess_params
)
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
,
self
.
config
=
utility
.
create_predictor
(
args
,
'det'
,
logger
)
if
self
.
use_onnx
:
img_h
,
img_w
=
self
.
input_tensor
.
shape
[
2
:]
if
img_h
is
not
None
and
img_w
is
not
None
and
img_h
>
0
and
img_w
>
0
:
pre_process_list
[
0
]
=
{
'DetResizeForTest'
:
{
'image_shape'
:
[
img_h
,
img_w
]
}
}
self
.
preprocess_op
=
create_operators
(
pre_process_list
)
if
args
.
benchmark
:
import
auto_log
pid
=
os
.
getpid
()
gpu_id
=
utility
.
get_infer_gpuid
()
self
.
autolog
=
auto_log
.
AutoLogger
(
model_name
=
"det"
,
model_precision
=
args
.
precision
,
batch_size
=
1
,
data_shape
=
"dynamic"
,
save_path
=
None
,
inference_config
=
self
.
config
,
pids
=
pid
,
process_name
=
None
,
gpu_ids
=
gpu_id
if
args
.
use_gpu
else
None
,
time_keys
=
[
'preprocess_time'
,
'inference_time'
,
'postprocess_time'
],
warmup
=
2
,
logger
=
logger
)
def
order_points_clockwise
(
self
,
pts
):
rect
=
np
.
zeros
((
4
,
2
),
dtype
=
"float32"
)
s
=
pts
.
sum
(
axis
=
1
)
rect
[
0
]
=
pts
[
np
.
argmin
(
s
)]
rect
[
2
]
=
pts
[
np
.
argmax
(
s
)]
tmp
=
np
.
delete
(
pts
,
(
np
.
argmin
(
s
),
np
.
argmax
(
s
)),
axis
=
0
)
diff
=
np
.
diff
(
np
.
array
(
tmp
),
axis
=
1
)
rect
[
1
]
=
tmp
[
np
.
argmin
(
diff
)]
rect
[
3
]
=
tmp
[
np
.
argmax
(
diff
)]
return
rect
def
clip_det_res
(
self
,
points
,
img_height
,
img_width
):
for
pno
in
range
(
points
.
shape
[
0
]):
points
[
pno
,
0
]
=
int
(
min
(
max
(
points
[
pno
,
0
],
0
),
img_width
-
1
))
points
[
pno
,
1
]
=
int
(
min
(
max
(
points
[
pno
,
1
],
0
),
img_height
-
1
))
return
points
def
filter_tag_det_res
(
self
,
dt_boxes
,
image_shape
):
img_height
,
img_width
=
image_shape
[
0
:
2
]
dt_boxes_new
=
[]
for
box
in
dt_boxes
:
box
=
self
.
order_points_clockwise
(
box
)
box
=
self
.
clip_det_res
(
box
,
img_height
,
img_width
)
rect_width
=
int
(
np
.
linalg
.
norm
(
box
[
0
]
-
box
[
1
]))
rect_height
=
int
(
np
.
linalg
.
norm
(
box
[
0
]
-
box
[
3
]))
if
rect_width
<=
3
or
rect_height
<=
3
:
continue
dt_boxes_new
.
append
(
box
)
dt_boxes
=
np
.
array
(
dt_boxes_new
)
return
dt_boxes
def
filter_tag_det_res_only_clip
(
self
,
dt_boxes
,
image_shape
):
img_height
,
img_width
=
image_shape
[
0
:
2
]
dt_boxes_new
=
[]
for
box
in
dt_boxes
:
box
=
self
.
clip_det_res
(
box
,
img_height
,
img_width
)
dt_boxes_new
.
append
(
box
)
dt_boxes
=
np
.
array
(
dt_boxes_new
)
return
dt_boxes
def
__call__
(
self
,
img
):
ori_im
=
img
.
copy
()
data
=
{
'image'
:
img
}
st
=
time
.
time
()
if
self
.
args
.
benchmark
:
self
.
autolog
.
times
.
start
()
data
=
transform
(
data
,
self
.
preprocess_op
)
img
,
shape_list
=
data
if
img
is
None
:
return
None
,
0
img
=
np
.
expand_dims
(
img
,
axis
=
0
)
shape_list
=
np
.
expand_dims
(
shape_list
,
axis
=
0
)
img
=
img
.
copy
()
if
self
.
args
.
benchmark
:
self
.
autolog
.
times
.
stamp
()
if
self
.
use_onnx
:
input_dict
=
{}
input_dict
[
self
.
input_tensor
.
name
]
=
img
outputs
=
self
.
predictor
.
run
(
self
.
output_tensors
,
input_dict
)
else
:
self
.
input_tensor
.
copy_from_cpu
(
img
)
self
.
predictor
.
run
()
outputs
=
[]
for
output_tensor
in
self
.
output_tensors
:
output
=
output_tensor
.
copy_to_cpu
()
outputs
.
append
(
output
)
if
self
.
args
.
benchmark
:
self
.
autolog
.
times
.
stamp
()
preds
=
{}
if
self
.
det_algorithm
==
"EAST"
:
preds
[
'f_geo'
]
=
outputs
[
0
]
preds
[
'f_score'
]
=
outputs
[
1
]
elif
self
.
det_algorithm
==
'SAST'
:
preds
[
'f_border'
]
=
outputs
[
0
]
preds
[
'f_score'
]
=
outputs
[
1
]
preds
[
'f_tco'
]
=
outputs
[
2
]
preds
[
'f_tvo'
]
=
outputs
[
3
]
elif
self
.
det_algorithm
in
[
'DB'
,
'PSE'
]:
preds
[
'maps'
]
=
outputs
[
0
]
elif
self
.
det_algorithm
==
'FCE'
:
for
i
,
output
in
enumerate
(
outputs
):
preds
[
'level_{}'
.
format
(
i
)]
=
output
else
:
raise
NotImplementedError
#self.predictor.try_shrink_memory()
post_result
=
self
.
postprocess_op
(
preds
,
shape_list
)
dt_boxes
=
post_result
[
0
][
'points'
]
if
(
self
.
det_algorithm
==
"SAST"
and
self
.
det_sast_polygon
)
or
(
self
.
det_algorithm
in
[
"PSE"
,
"FCE"
]
and
self
.
postprocess_op
.
box_type
==
'poly'
):
dt_boxes
=
self
.
filter_tag_det_res_only_clip
(
dt_boxes
,
ori_im
.
shape
)
else
:
dt_boxes
=
self
.
filter_tag_det_res
(
dt_boxes
,
ori_im
.
shape
)
if
self
.
args
.
benchmark
:
self
.
autolog
.
times
.
end
(
stamp
=
True
)
et
=
time
.
time
()
return
dt_boxes
,
et
-
st
if
__name__
==
"__main__"
:
from
ppocr.metrics.eval_det_iou
import
DetectionIoUEvaluator
evaluator
=
DetectionIoUEvaluator
()
args
=
utility
.
parse_args
()
# image_file_list = get_image_file_list(args.image_dir)
def
_check_image_file
(
path
):
img_end
=
{
'jpg'
,
'bmp'
,
'png'
,
'jpeg'
,
'rgb'
,
'tif'
,
'tiff'
,
'gif'
}
return
any
([
path
.
lower
().
endswith
(
e
)
for
e
in
img_end
])
def
get_image_file_list_from_txt
(
img_file
):
imgs_lists
=
[]
label_lists
=
[]
if
img_file
is
None
or
not
os
.
path
.
exists
(
img_file
):
raise
Exception
(
"not found any img file in {}"
.
format
(
img_file
))
img_end
=
{
'jpg'
,
'bmp'
,
'png'
,
'jpeg'
,
'rgb'
,
'tif'
,
'tiff'
,
'gif'
}
root_dir
=
img_file
.
split
(
'/'
)[
0
]
with
open
(
img_file
,
'r'
)
as
f
:
lines
=
f
.
readlines
()
for
line
in
lines
:
line
=
line
.
replace
(
'
\n
'
,
''
).
split
(
'
\t
'
)
file_path
,
label
=
line
[
0
],
line
[
1
]
file_path
=
os
.
path
.
join
(
root_dir
,
file_path
)
if
os
.
path
.
isfile
(
file_path
)
and
_check_image_file
(
file_path
):
imgs_lists
.
append
(
file_path
)
label_lists
.
append
(
label
)
if
len
(
imgs_lists
)
==
0
:
raise
Exception
(
"not found any img file in {}"
.
format
(
img_file
))
return
imgs_lists
,
label_lists
image_file_list
,
label_list
=
get_image_file_list_from_txt
(
args
.
image_dir
)
text_detector
=
TextDetector
(
args
)
count
=
0
total_time
=
0
draw_img_save
=
"./inference_results"
if
args
.
warmup
:
img
=
np
.
random
.
uniform
(
0
,
255
,
[
640
,
640
,
3
]).
astype
(
np
.
uint8
)
for
i
in
range
(
2
):
res
=
text_detector
(
img
)
if
not
os
.
path
.
exists
(
draw_img_save
):
os
.
makedirs
(
draw_img_save
)
save_results
=
[]
results
=
[]
for
idx
in
range
(
len
(
image_file_list
)):
image_file
=
image_file_list
[
idx
]
label
=
json
.
loads
(
label_list
[
idx
])
img
,
flag
=
check_and_read_gif
(
image_file
)
if
not
flag
:
img
=
cv2
.
imread
(
image_file
)
if
img
is
None
:
logger
.
info
(
"error in loading image:{}"
.
format
(
image_file
))
continue
st
=
time
.
time
()
dt_boxes
,
_
=
text_detector
(
img
)
elapse
=
time
.
time
()
-
st
if
count
>
0
:
total_time
+=
elapse
count
+=
1
save_pred
=
os
.
path
.
basename
(
image_file
)
+
"
\t
"
+
str
(
json
.
dumps
([
x
.
tolist
()
for
x
in
dt_boxes
]))
+
"
\n
"
save_results
.
append
(
save_pred
)
# for eval
gt_info_list
=
[]
det_info_list
=
[]
for
dt_box
in
dt_boxes
:
det_info
=
{
'points'
:
np
.
array
(
dt_box
,
dtype
=
np
.
float32
),
'text'
:
''
}
det_info_list
.
append
(
det_info
)
for
lab
in
label
:
gt_info
=
{
'points'
:
np
.
array
(
lab
[
'points'
],
dtype
=
np
.
float32
),
'text'
:
''
,
'ignore'
:
False
}
gt_info_list
.
append
(
gt_info
)
result
=
evaluator
.
evaluate_image
(
gt_info_list
,
det_info_list
)
results
.
append
(
result
)
metrics
=
evaluator
.
combine_results
(
results
)
print
(
'predict det eval on '
,
args
.
image_dir
)
print
(
'metrics: '
,
metrics
)
# logger.info(save_pred)
# logger.info("The predict time of {}: {}".format(image_file, elapse))
# src_im = utility.draw_text_det_res(dt_boxes, image_file)
# img_name_pure = os.path.split(image_file)[-1]
# img_path = os.path.join(draw_img_save,
# "det_res_{}".format(img_name_pure))
# cv2.imwrite(img_path, src_im)
# logger.info("The visualized image saved in {}".format(img_path))
# with open(os.path.join(draw_img_save, "det_results.txt"), 'w') as f:
# f.writelines(save_results)
# f.close()
# if args.benchmark:
# text_detector.autolog.report()
tools/infer/predict_rec_eval.py
0 → 100755
浏览文件 @
03d38486
此差异已折叠。
点击以展开。
tools/program.py
浏览文件 @
03d38486
...
...
@@ -154,13 +154,14 @@ def check_xpu(use_xpu):
except
Exception
as
e
:
pass
def
to_float32
(
preds
):
if
isinstance
(
preds
,
dict
):
for
k
in
preds
:
if
isinstance
(
preds
[
k
],
dict
)
or
isinstance
(
preds
[
k
],
list
):
preds
[
k
]
=
to_float32
(
preds
[
k
])
else
:
preds
[
k
]
=
p
reds
[
k
].
astype
(
paddle
.
float32
)
preds
[
k
]
=
p
addle
.
to_tensor
(
preds
[
k
],
dtype
=
'float32'
)
elif
isinstance
(
preds
,
list
):
for
k
in
range
(
len
(
preds
)):
if
isinstance
(
preds
[
k
],
dict
):
...
...
@@ -168,11 +169,12 @@ def to_float32(preds):
elif
isinstance
(
preds
[
k
],
list
):
preds
[
k
]
=
to_float32
(
preds
[
k
])
else
:
preds
[
k
]
=
p
reds
[
k
].
astype
(
paddle
.
float32
)
preds
[
k
]
=
p
addle
.
to_tensor
(
preds
[
k
],
dtype
=
'float32'
)
else
:
preds
=
preds
.
astype
(
paddle
.
float32
)
preds
[
k
]
=
paddle
.
to_tensor
(
preds
[
k
],
dtype
=
'float32'
)
return
preds
def
train
(
config
,
train_dataloader
,
valid_dataloader
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录