Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
d9d51f7d
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
1 年多 前同步成功
通知
1532
Star
32963
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d9d51f7d
编写于
9月 16, 2020
作者:
Y
yukavio
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
rm eval_det_utils
上级
d4f1758d
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
11 addition
and
188 deletion
+11
-188
deploy/slim/prune/eval_det_utils.py
deploy/slim/prune/eval_det_utils.py
+0
-156
deploy/slim/prune/pruning_and_finetune.py
deploy/slim/prune/pruning_and_finetune.py
+0
-29
deploy/slim/prune/sensitivity_anal.py
deploy/slim/prune/sensitivity_anal.py
+11
-3
未找到文件。
deploy/slim/prune/eval_det_utils.py
已删除
100644 → 0
浏览文件 @
d4f1758d
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
sys
import
logging
import
numpy
as
np
import
paddle.fluid
as
fluid
__dir__
=
os
.
path
.
dirname
(
__file__
)
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
join
(
__dir__
,
'..'
,
'..'
,
'..'
))
__all__
=
[
'eval_det_run'
]
import
logging
FORMAT
=
'%(asctime)s-%(levelname)s: %(message)s'
logging
.
basicConfig
(
level
=
logging
.
INFO
,
format
=
FORMAT
)
logger
=
logging
.
getLogger
(
__name__
)
import
cv2
import
json
from
copy
import
deepcopy
from
ppocr.utils.utility
import
create_module
from
ppocr.data.reader_main
import
reader_main
from
tools.eval_utils.eval_det_iou
import
DetectionIoUEvaluator
def
cal_det_res
(
exe
,
config
,
eval_info_dict
):
global_params
=
config
[
'Global'
]
save_res_path
=
global_params
[
'save_res_path'
]
postprocess_params
=
deepcopy
(
config
[
"PostProcess"
])
postprocess_params
.
update
(
global_params
)
postprocess
=
create_module
(
postprocess_params
[
'function'
])
\
(
params
=
postprocess_params
)
if
not
os
.
path
.
exists
(
os
.
path
.
dirname
(
save_res_path
)):
os
.
makedirs
(
os
.
path
.
dirname
(
save_res_path
))
with
open
(
save_res_path
,
"wb"
)
as
fout
:
tackling_num
=
0
for
data
in
eval_info_dict
[
'reader'
]():
img_num
=
len
(
data
)
tackling_num
=
tackling_num
+
img_num
logger
.
info
(
"test tackling num:%d"
,
tackling_num
)
img_list
=
[]
ratio_list
=
[]
img_name_list
=
[]
for
ino
in
range
(
img_num
):
img_list
.
append
(
data
[
ino
][
0
])
ratio_list
.
append
(
data
[
ino
][
1
])
img_name_list
.
append
(
data
[
ino
][
2
])
try
:
img_list
=
np
.
concatenate
(
img_list
,
axis
=
0
)
except
:
err
=
"concatenate error usually caused by different input image shapes in evaluation or testing.
\n
\
Please set
\"
test_batch_size_per_card
\"
in main yml as 1
\n
\
or add
\"
test_image_shape: [h, w]
\"
in reader yml for EvalReader."
raise
Exception
(
err
)
outs
=
exe
.
run
(
eval_info_dict
[
'program'
],
\
feed
=
{
'image'
:
img_list
},
\
fetch_list
=
eval_info_dict
[
'fetch_varname_list'
])
outs_dict
=
{}
for
tno
in
range
(
len
(
outs
)):
fetch_name
=
eval_info_dict
[
'fetch_name_list'
][
tno
]
fetch_value
=
np
.
array
(
outs
[
tno
])
outs_dict
[
fetch_name
]
=
fetch_value
dt_boxes_list
=
postprocess
(
outs_dict
,
ratio_list
)
for
ino
in
range
(
img_num
):
dt_boxes
=
dt_boxes_list
[
ino
]
img_name
=
img_name_list
[
ino
]
dt_boxes_json
=
[]
for
box
in
dt_boxes
:
tmp_json
=
{
"transcription"
:
""
}
tmp_json
[
'points'
]
=
box
.
tolist
()
dt_boxes_json
.
append
(
tmp_json
)
otstr
=
img_name
+
"
\t
"
+
json
.
dumps
(
dt_boxes_json
)
+
"
\n
"
fout
.
write
(
otstr
.
encode
())
return
def
load_label_infor
(
label_file_path
,
do_ignore
=
False
):
img_name_label_dict
=
{}
with
open
(
label_file_path
,
"rb"
)
as
fin
:
lines
=
fin
.
readlines
()
for
line
in
lines
:
substr
=
line
.
decode
().
strip
(
"
\n
"
).
split
(
"
\t
"
)
bbox_infor
=
json
.
loads
(
substr
[
1
])
bbox_num
=
len
(
bbox_infor
)
for
bno
in
range
(
bbox_num
):
text
=
bbox_infor
[
bno
][
'transcription'
]
ignore
=
False
if
text
==
"###"
and
do_ignore
:
ignore
=
True
bbox_infor
[
bno
][
'ignore'
]
=
ignore
img_name_label_dict
[
os
.
path
.
basename
(
substr
[
0
])]
=
bbox_infor
return
img_name_label_dict
def
cal_det_metrics
(
gt_label_path
,
save_res_path
):
"""
calculate the detection metrics
Args:
gt_label_path(string): The groundtruth detection label file path
save_res_path(string): The saved predicted detection label path
return:
claculated metrics including Hmean, precision and recall
"""
evaluator
=
DetectionIoUEvaluator
()
gt_label_infor
=
load_label_infor
(
gt_label_path
,
do_ignore
=
True
)
dt_label_infor
=
load_label_infor
(
save_res_path
)
results
=
[]
for
img_name
in
gt_label_infor
:
gt_label
=
gt_label_infor
[
img_name
]
if
img_name
not
in
dt_label_infor
:
dt_label
=
[]
else
:
dt_label
=
dt_label_infor
[
img_name
]
result
=
evaluator
.
evaluate_image
(
gt_label
,
dt_label
)
results
.
append
(
result
)
methodMetrics
=
evaluator
.
combine_results
(
results
)
return
methodMetrics
def
eval_det_run
(
eval_args
,
mode
=
'eval'
):
exe
=
eval_args
[
'exe'
]
config
=
eval_args
[
'config'
]
eval_info_dict
=
eval_args
[
'eval_info_dict'
]
cal_det_res
(
exe
,
config
,
eval_info_dict
)
save_res_path
=
config
[
'Global'
][
'save_res_path'
]
if
mode
==
"eval"
:
gt_label_path
=
config
[
'EvalReader'
][
'label_file_path'
]
metrics
=
cal_det_metrics
(
gt_label_path
,
save_res_path
)
else
:
gt_label_path
=
config
[
'TestReader'
][
'label_file_path'
]
do_eval
=
config
[
'TestReader'
][
'do_eval'
]
if
do_eval
:
metrics
=
cal_det_metrics
(
gt_label_path
,
save_res_path
)
else
:
metrics
=
{}
return
metrics
[
'hmean'
]
deploy/slim/prune/pruning_and_finetune.py
浏览文件 @
d9d51f7d
...
...
@@ -104,14 +104,6 @@ def main():
# compile program for multi-devices
init_model
(
config
,
train_program
,
exe
)
# params = get_pruned_params(train_program)
'''
sens_file = ['sensitivities_'+ str(x) for x in range(0,4)]
sens = []
for f in sens_file:
sens.append(load_sensitivities(f+'.data'))
sen = merge_sensitive(sens)
'''
sen
=
load_sensitivities
(
"sensitivities_0.data"
)
for
i
in
skip_list
:
sen
.
pop
(
i
)
...
...
@@ -161,28 +153,7 @@ def main():
program
.
train_eval_rec_run
(
config
,
exe
,
train_info_dict
,
eval_info_dict
)
def
test_reader
():
config
=
program
.
load_config
(
FLAGS
.
config
)
program
.
merge_config
(
FLAGS
.
opt
)
print
(
config
)
train_reader
=
reader_main
(
config
=
config
,
mode
=
"train"
)
import
time
starttime
=
time
.
time
()
count
=
0
try
:
for
data
in
train_reader
():
count
+=
1
if
count
%
1
==
0
:
batch_time
=
time
.
time
()
-
starttime
starttime
=
time
.
time
()
print
(
"reader:"
,
count
,
len
(
data
),
batch_time
)
except
Exception
as
e
:
logger
.
info
(
e
)
logger
.
info
(
"finish reader: {}, Success!"
.
format
(
count
))
if
__name__
==
'__main__'
:
parser
=
program
.
ArgsParser
()
FLAGS
=
parser
.
parse_args
()
main
()
# test_reader()
deploy/slim/prune/sensitivity_anal.py
浏览文件 @
d9d51f7d
...
...
@@ -42,7 +42,7 @@ import cv2
from
paddle
import
fluid
import
paddleslim
as
slim
from
copy
import
deepcopy
from
eval_det_utils
import
eval_det_run
from
tools.eval_utils.
eval_det_utils
import
eval_det_run
from
tools
import
program
from
ppocr.utils.utility
import
initial_logger
...
...
@@ -65,6 +65,14 @@ def get_pruned_params(program):
return
params
def
eval_function
(
eval_args
,
mode
=
'eval'
):
exe
=
eval_args
[
'exe'
]
config
=
eval_args
[
'config'
]
eval_info_dict
=
eval_args
[
'eval_info_dict'
]
metrics
=
eval_det_run
(
exe
,
config
,
eval_info_dict
,
mode
=
mode
)
return
metrics
[
'hmean'
]
def
main
():
config
=
program
.
load_config
(
FLAGS
.
config
)
program
.
merge_config
(
FLAGS
.
opt
)
...
...
@@ -99,7 +107,7 @@ def main():
'fetch_varname_list'
:
eval_fetch_varname_list
}
eval_args
=
dict
()
eval_args
=
{
'exe'
:
exe
,
'config'
:
config
,
'eval_info_dict'
:
eval_info_dict
}
metrics
=
eval_
det_ru
n
(
eval_args
)
metrics
=
eval_
functio
n
(
eval_args
)
print
(
"Baseline: {}"
.
format
(
metrics
))
params
=
get_pruned_params
(
eval_program
)
...
...
@@ -108,7 +116,7 @@ def main():
eval_program
,
place
,
params
,
eval_
det_ru
n
,
eval_
functio
n
,
sensitivities_file
=
"sensitivities_0.data"
,
pruned_ratios
=
[
0.1
],
eval_args
=
eval_args
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录