Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
weixin_41840029
PaddleOCR
提交
d4f1758d
P
PaddleOCR
项目概览
weixin_41840029
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d4f1758d
编写于
9月 15, 2020
作者:
Y
yukavio
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add slim/prune
上级
ed6b2f0c
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
609 addition
and
4 deletion
+609
-4
deploy/slim/prune/README.md
deploy/slim/prune/README.md
+40
-0
deploy/slim/prune/eval_det_utils.py
deploy/slim/prune/eval_det_utils.py
+156
-0
deploy/slim/prune/export_prune_model.py
deploy/slim/prune/export_prune_model.py
+81
-0
deploy/slim/prune/pruning_and_finetune.py
deploy/slim/prune/pruning_and_finetune.py
+188
-0
deploy/slim/prune/sensitivity_anal.py
deploy/slim/prune/sensitivity_anal.py
+121
-0
tools/program.py
tools/program.py
+23
-4
未找到文件。
deploy/slim/prune/README.md
0 → 100644
浏览文件 @
d4f1758d
> 运行示例前请先安装develop版本PaddleSlim
# 模型裁剪压缩教程
## 概述
该示例使用PaddleSlim提供的
[
裁剪压缩API
](
https://paddlepaddle.github.io/PaddleSlim/api/prune_api/
)
对OCR模型进行压缩。
在阅读该示例前,建议您先了解以下内容:
-
[
OCR模型的常规训练方法
](
https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/detection.md
)
-
[
PaddleSlim使用文档
](
https://paddlepaddle.github.io/PaddleSlim/
)
## 安装PaddleSlim
可按照
[
PaddleSlim使用文档
](
https://paddlepaddle.github.io/PaddleSlim/
)
中的步骤安装PaddleSlim。
## 敏感度分析训练
进入PaddleOCR根目录,通过以下命令对模型进行敏感度分析:
```
bash
python deploy/slim/prune/sensitivity_anal.py
-c
configs/det/det_mv3_db.yml
-o
Global.pretrain_weights
=
./deploy/slim/prune/pretrain_models/det_mv3_db/best_accuracy Global.test_batch_size_per_card
=
1
```
## 裁剪模型与fine-tune
```
bash
python deploy/slim/prune/pruning_and_finetune.py
-c
configs/det/det_mv3_db.yml
-o
Global.pretrain_weights
=
./deploy/slim/prune/pretrain_models/det_mv3_db/best_accuracy Global.test_batch_size_per_card
=
1
```
## 评估并导出
在得到裁剪训练保存的模型后,我们可以将其导出为inference_model,用于预测部署:
```
bash
python deploy/slim/prune/export_prune_model.py
-c
configs/det/det_mv3_db.yml
-o
Global.pretrain_weights
=
./output/det_db/best_accuracy Global.test_batch_size_per_card
=
1 Global.save_inference_dir
=
inference_model
```
deploy/slim/prune/eval_det_utils.py
0 → 100644
浏览文件 @
d4f1758d
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
sys
import
logging
import
numpy
as
np
import
paddle.fluid
as
fluid
__dir__
=
os
.
path
.
dirname
(
__file__
)
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
join
(
__dir__
,
'..'
,
'..'
,
'..'
))
__all__
=
[
'eval_det_run'
]
import
logging
FORMAT
=
'%(asctime)s-%(levelname)s: %(message)s'
logging
.
basicConfig
(
level
=
logging
.
INFO
,
format
=
FORMAT
)
logger
=
logging
.
getLogger
(
__name__
)
import
cv2
import
json
from
copy
import
deepcopy
from
ppocr.utils.utility
import
create_module
from
ppocr.data.reader_main
import
reader_main
from
tools.eval_utils.eval_det_iou
import
DetectionIoUEvaluator
def
cal_det_res
(
exe
,
config
,
eval_info_dict
):
global_params
=
config
[
'Global'
]
save_res_path
=
global_params
[
'save_res_path'
]
postprocess_params
=
deepcopy
(
config
[
"PostProcess"
])
postprocess_params
.
update
(
global_params
)
postprocess
=
create_module
(
postprocess_params
[
'function'
])
\
(
params
=
postprocess_params
)
if
not
os
.
path
.
exists
(
os
.
path
.
dirname
(
save_res_path
)):
os
.
makedirs
(
os
.
path
.
dirname
(
save_res_path
))
with
open
(
save_res_path
,
"wb"
)
as
fout
:
tackling_num
=
0
for
data
in
eval_info_dict
[
'reader'
]():
img_num
=
len
(
data
)
tackling_num
=
tackling_num
+
img_num
logger
.
info
(
"test tackling num:%d"
,
tackling_num
)
img_list
=
[]
ratio_list
=
[]
img_name_list
=
[]
for
ino
in
range
(
img_num
):
img_list
.
append
(
data
[
ino
][
0
])
ratio_list
.
append
(
data
[
ino
][
1
])
img_name_list
.
append
(
data
[
ino
][
2
])
try
:
img_list
=
np
.
concatenate
(
img_list
,
axis
=
0
)
except
:
err
=
"concatenate error usually caused by different input image shapes in evaluation or testing.
\n
\
Please set
\"
test_batch_size_per_card
\"
in main yml as 1
\n
\
or add
\"
test_image_shape: [h, w]
\"
in reader yml for EvalReader."
raise
Exception
(
err
)
outs
=
exe
.
run
(
eval_info_dict
[
'program'
],
\
feed
=
{
'image'
:
img_list
},
\
fetch_list
=
eval_info_dict
[
'fetch_varname_list'
])
outs_dict
=
{}
for
tno
in
range
(
len
(
outs
)):
fetch_name
=
eval_info_dict
[
'fetch_name_list'
][
tno
]
fetch_value
=
np
.
array
(
outs
[
tno
])
outs_dict
[
fetch_name
]
=
fetch_value
dt_boxes_list
=
postprocess
(
outs_dict
,
ratio_list
)
for
ino
in
range
(
img_num
):
dt_boxes
=
dt_boxes_list
[
ino
]
img_name
=
img_name_list
[
ino
]
dt_boxes_json
=
[]
for
box
in
dt_boxes
:
tmp_json
=
{
"transcription"
:
""
}
tmp_json
[
'points'
]
=
box
.
tolist
()
dt_boxes_json
.
append
(
tmp_json
)
otstr
=
img_name
+
"
\t
"
+
json
.
dumps
(
dt_boxes_json
)
+
"
\n
"
fout
.
write
(
otstr
.
encode
())
return
def
load_label_infor
(
label_file_path
,
do_ignore
=
False
):
img_name_label_dict
=
{}
with
open
(
label_file_path
,
"rb"
)
as
fin
:
lines
=
fin
.
readlines
()
for
line
in
lines
:
substr
=
line
.
decode
().
strip
(
"
\n
"
).
split
(
"
\t
"
)
bbox_infor
=
json
.
loads
(
substr
[
1
])
bbox_num
=
len
(
bbox_infor
)
for
bno
in
range
(
bbox_num
):
text
=
bbox_infor
[
bno
][
'transcription'
]
ignore
=
False
if
text
==
"###"
and
do_ignore
:
ignore
=
True
bbox_infor
[
bno
][
'ignore'
]
=
ignore
img_name_label_dict
[
os
.
path
.
basename
(
substr
[
0
])]
=
bbox_infor
return
img_name_label_dict
def
cal_det_metrics
(
gt_label_path
,
save_res_path
):
"""
calculate the detection metrics
Args:
gt_label_path(string): The groundtruth detection label file path
save_res_path(string): The saved predicted detection label path
return:
claculated metrics including Hmean, precision and recall
"""
evaluator
=
DetectionIoUEvaluator
()
gt_label_infor
=
load_label_infor
(
gt_label_path
,
do_ignore
=
True
)
dt_label_infor
=
load_label_infor
(
save_res_path
)
results
=
[]
for
img_name
in
gt_label_infor
:
gt_label
=
gt_label_infor
[
img_name
]
if
img_name
not
in
dt_label_infor
:
dt_label
=
[]
else
:
dt_label
=
dt_label_infor
[
img_name
]
result
=
evaluator
.
evaluate_image
(
gt_label
,
dt_label
)
results
.
append
(
result
)
methodMetrics
=
evaluator
.
combine_results
(
results
)
return
methodMetrics
def
eval_det_run
(
eval_args
,
mode
=
'eval'
):
exe
=
eval_args
[
'exe'
]
config
=
eval_args
[
'config'
]
eval_info_dict
=
eval_args
[
'eval_info_dict'
]
cal_det_res
(
exe
,
config
,
eval_info_dict
)
save_res_path
=
config
[
'Global'
][
'save_res_path'
]
if
mode
==
"eval"
:
gt_label_path
=
config
[
'EvalReader'
][
'label_file_path'
]
metrics
=
cal_det_metrics
(
gt_label_path
,
save_res_path
)
else
:
gt_label_path
=
config
[
'TestReader'
][
'label_file_path'
]
do_eval
=
config
[
'TestReader'
][
'do_eval'
]
if
do_eval
:
metrics
=
cal_det_metrics
(
gt_label_path
,
save_res_path
)
else
:
metrics
=
{}
return
metrics
[
'hmean'
]
deploy/slim/prune/export_prune_model.py
0 → 100644
浏览文件 @
d4f1758d
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
join
(
__dir__
,
'..'
,
'..'
,
'..'
))
sys
.
path
.
append
(
os
.
path
.
join
(
__dir__
,
'..'
,
'..'
,
'..'
,
'tools'
))
def
set_paddle_flags
(
**
kwargs
):
for
key
,
value
in
kwargs
.
items
():
if
os
.
environ
.
get
(
key
,
None
)
is
None
:
os
.
environ
[
key
]
=
str
(
value
)
# NOTE(paddle-dev): All of these flags should be
# set before `import paddle`. Otherwise, it would
# not take any effect.
set_paddle_flags
(
FLAGS_eager_delete_tensor_gb
=
0
,
# enable GC to save memory
)
import
program
from
paddle
import
fluid
from
ppocr.utils.utility
import
initial_logger
logger
=
initial_logger
()
from
ppocr.utils.save_load
import
init_model
from
paddleslim.prune
import
load_model
def
main
():
startup_prog
,
eval_program
,
place
,
config
,
_
=
program
.
preprocess
()
feeded_var_names
,
target_vars
,
fetches_var_name
=
program
.
build_export
(
config
,
eval_program
,
startup_prog
)
eval_program
=
eval_program
.
clone
(
for_test
=
True
)
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
startup_prog
)
if
config
[
'Global'
][
'checkpoints'
]
is
not
None
:
path
=
config
[
'Global'
][
'checkpoints'
]
else
:
path
=
config
[
'Global'
][
'pretrain_weights'
]
load_model
(
exe
,
eval_program
,
path
)
save_inference_dir
=
config
[
'Global'
][
'save_inference_dir'
]
if
not
os
.
path
.
exists
(
save_inference_dir
):
os
.
makedirs
(
save_inference_dir
)
fluid
.
io
.
save_inference_model
(
dirname
=
save_inference_dir
,
feeded_var_names
=
feeded_var_names
,
main_program
=
eval_program
,
target_vars
=
target_vars
,
executor
=
exe
,
model_filename
=
'model'
,
params_filename
=
'params'
)
print
(
"inference model saved in {}/model and {}/params"
.
format
(
save_inference_dir
,
save_inference_dir
))
print
(
"save success, output_name_list:"
,
fetches_var_name
)
if
__name__
==
'__main__'
:
main
()
deploy/slim/prune/pruning_and_finetune.py
0 → 100644
浏览文件 @
d4f1758d
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
sys
import
numpy
as
np
__dir__
=
os
.
path
.
dirname
(
__file__
)
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
join
(
__dir__
,
'..'
,
'..'
,
'..'
))
sys
.
path
.
append
(
os
.
path
.
join
(
__dir__
,
'..'
,
'..'
,
'..'
,
'tools'
))
def
set_paddle_flags
(
**
kwargs
):
for
key
,
value
in
kwargs
.
items
():
if
os
.
environ
.
get
(
key
,
None
)
is
None
:
os
.
environ
[
key
]
=
str
(
value
)
# NOTE(paddle-dev): All of these flags should be
# set before `import paddle`. Otherwise, it would
# not take any effect.
set_paddle_flags
(
FLAGS_eager_delete_tensor_gb
=
0
,
# enable GC to save memory
)
import
tools.program
as
program
from
paddle
import
fluid
from
ppocr.utils.utility
import
initial_logger
logger
=
initial_logger
()
from
ppocr.data.reader_main
import
reader_main
from
ppocr.utils.save_load
import
init_model
from
ppocr.utils.character
import
CharacterOps
from
ppocr.utils.utility
import
initial_logger
from
paddleslim.prune
import
Pruner
,
save_model
from
paddleslim.analysis
import
flops
from
paddleslim.core.graph_wrapper
import
*
from
paddleslim.prune
import
load_sensitivities
,
get_ratios_by_loss
,
merge_sensitive
logger
=
initial_logger
()
skip_list
=
[
'conv10_linear_weights'
,
'conv11_linear_weights'
,
'conv12_expand_weights'
,
'conv12_linear_weights'
,
'conv12_se_2_weights'
,
'conv13_linear_weights'
,
'conv2_linear_weights'
,
'conv4_linear_weights'
,
'conv5_expand_weights'
,
'conv5_linear_weights'
,
'conv5_se_2_weights'
,
'conv6_linear_weights'
,
'conv7_linear_weights'
,
'conv8_expand_weights'
,
'conv8_linear_weights'
,
'conv9_expand_weights'
,
'conv9_linear_weights'
]
def
main
():
config
=
program
.
load_config
(
FLAGS
.
config
)
program
.
merge_config
(
FLAGS
.
opt
)
logger
.
info
(
config
)
# check if set use_gpu=True in paddlepaddle cpu version
use_gpu
=
config
[
'Global'
][
'use_gpu'
]
program
.
check_gpu
(
use_gpu
)
alg
=
config
[
'Global'
][
'algorithm'
]
assert
alg
in
[
'EAST'
,
'DB'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
]
if
alg
in
[
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
]:
config
[
'Global'
][
'char_ops'
]
=
CharacterOps
(
config
[
'Global'
])
place
=
fluid
.
CUDAPlace
(
0
)
if
use_gpu
else
fluid
.
CPUPlace
()
startup_program
=
fluid
.
Program
()
train_program
=
fluid
.
Program
()
train_build_outputs
=
program
.
build
(
config
,
train_program
,
startup_program
,
mode
=
'train'
)
train_loader
=
train_build_outputs
[
0
]
train_fetch_name_list
=
train_build_outputs
[
1
]
train_fetch_varname_list
=
train_build_outputs
[
2
]
train_opt_loss_name
=
train_build_outputs
[
3
]
eval_program
=
fluid
.
Program
()
eval_build_outputs
=
program
.
build
(
config
,
eval_program
,
startup_program
,
mode
=
'eval'
)
eval_fetch_name_list
=
eval_build_outputs
[
1
]
eval_fetch_varname_list
=
eval_build_outputs
[
2
]
eval_program
=
eval_program
.
clone
(
for_test
=
True
)
train_reader
=
reader_main
(
config
=
config
,
mode
=
"train"
)
train_loader
.
set_sample_list_generator
(
train_reader
,
places
=
place
)
eval_reader
=
reader_main
(
config
=
config
,
mode
=
"eval"
)
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
startup_program
)
# compile program for multi-devices
init_model
(
config
,
train_program
,
exe
)
# params = get_pruned_params(train_program)
'''
sens_file = ['sensitivities_'+ str(x) for x in range(0,4)]
sens = []
for f in sens_file:
sens.append(load_sensitivities(f+'.data'))
sen = merge_sensitive(sens)
'''
sen
=
load_sensitivities
(
"sensitivities_0.data"
)
for
i
in
skip_list
:
sen
.
pop
(
i
)
back_bone_list
=
[
'conv'
+
str
(
x
)
for
x
in
range
(
1
,
5
)]
for
i
in
back_bone_list
:
for
key
in
list
(
sen
.
keys
()):
if
i
+
'_'
in
key
:
sen
.
pop
(
key
)
ratios
=
get_ratios_by_loss
(
sen
,
0.03
)
logger
.
info
(
"FLOPs before pruning: {}"
.
format
(
flops
(
eval_program
)))
pruner
=
Pruner
(
criterion
=
'geometry_median'
)
print
(
"ratios: {}"
.
format
(
ratios
))
pruned_val_program
,
_
,
_
=
pruner
.
prune
(
eval_program
,
fluid
.
global_scope
(),
params
=
ratios
.
keys
(),
ratios
=
ratios
.
values
(),
place
=
place
,
only_graph
=
True
)
pruned_program
,
_
,
_
=
pruner
.
prune
(
train_program
,
fluid
.
global_scope
(),
params
=
ratios
.
keys
(),
ratios
=
ratios
.
values
(),
place
=
place
)
logger
.
info
(
"FLOPs after pruning: {}"
.
format
(
flops
(
pruned_val_program
)))
train_compile_program
=
program
.
create_multi_devices_program
(
pruned_program
,
train_opt_loss_name
)
train_info_dict
=
{
'compile_program'
:
train_compile_program
,
\
'train_program'
:
pruned_program
,
\
'reader'
:
train_loader
,
\
'fetch_name_list'
:
train_fetch_name_list
,
\
'fetch_varname_list'
:
train_fetch_varname_list
}
eval_info_dict
=
{
'program'
:
pruned_val_program
,
\
'reader'
:
eval_reader
,
\
'fetch_name_list'
:
eval_fetch_name_list
,
\
'fetch_varname_list'
:
eval_fetch_varname_list
}
if
alg
in
[
'EAST'
,
'DB'
]:
program
.
train_eval_det_run
(
config
,
exe
,
train_info_dict
,
eval_info_dict
,
is_pruning
=
True
)
else
:
program
.
train_eval_rec_run
(
config
,
exe
,
train_info_dict
,
eval_info_dict
)
def
test_reader
():
config
=
program
.
load_config
(
FLAGS
.
config
)
program
.
merge_config
(
FLAGS
.
opt
)
print
(
config
)
train_reader
=
reader_main
(
config
=
config
,
mode
=
"train"
)
import
time
starttime
=
time
.
time
()
count
=
0
try
:
for
data
in
train_reader
():
count
+=
1
if
count
%
1
==
0
:
batch_time
=
time
.
time
()
-
starttime
starttime
=
time
.
time
()
print
(
"reader:"
,
count
,
len
(
data
),
batch_time
)
except
Exception
as
e
:
logger
.
info
(
e
)
logger
.
info
(
"finish reader: {}, Success!"
.
format
(
count
))
if
__name__
==
'__main__'
:
parser
=
program
.
ArgsParser
()
FLAGS
=
parser
.
parse_args
()
main
()
# test_reader()
deploy/slim/prune/sensitivity_anal.py
0 → 100644
浏览文件 @
d4f1758d
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
__file__
)
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
join
(
__dir__
,
'..'
,
'..'
,
'..'
))
sys
.
path
.
append
(
os
.
path
.
join
(
__dir__
,
'..'
,
'..'
,
'..'
,
'tools'
))
def
set_paddle_flags
(
**
kwargs
):
for
key
,
value
in
kwargs
.
items
():
if
os
.
environ
.
get
(
key
,
None
)
is
None
:
os
.
environ
[
key
]
=
str
(
value
)
# NOTE(paddle-dev): All of these flags should be
# set before `import paddle`. Otherwise, it would
# not take any effect.
set_paddle_flags
(
FLAGS_eager_delete_tensor_gb
=
0
,
# enable GC to save memory
)
import
json
import
cv2
from
paddle
import
fluid
import
paddleslim
as
slim
from
copy
import
deepcopy
from
eval_det_utils
import
eval_det_run
from
tools
import
program
from
ppocr.utils.utility
import
initial_logger
from
ppocr.data.reader_main
import
reader_main
from
ppocr.utils.save_load
import
init_model
from
ppocr.utils.character
import
CharacterOps
from
ppocr.utils.utility
import
create_module
from
ppocr.data.reader_main
import
reader_main
logger
=
initial_logger
()
def
get_pruned_params
(
program
):
params
=
[]
for
param
in
program
.
global_block
().
all_parameters
():
if
len
(
param
.
shape
)
==
4
and
'depthwise'
not
in
param
.
name
and
'transpose'
not
in
param
.
name
:
params
.
append
(
param
.
name
)
return
params
def
main
():
config
=
program
.
load_config
(
FLAGS
.
config
)
program
.
merge_config
(
FLAGS
.
opt
)
logger
.
info
(
config
)
# check if set use_gpu=True in paddlepaddle cpu version
use_gpu
=
config
[
'Global'
][
'use_gpu'
]
program
.
check_gpu
(
use_gpu
)
alg
=
config
[
'Global'
][
'algorithm'
]
assert
alg
in
[
'EAST'
,
'DB'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
]
if
alg
in
[
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
]:
config
[
'Global'
][
'char_ops'
]
=
CharacterOps
(
config
[
'Global'
])
place
=
fluid
.
CUDAPlace
(
0
)
if
use_gpu
else
fluid
.
CPUPlace
()
startup_prog
=
fluid
.
Program
()
eval_program
=
fluid
.
Program
()
eval_build_outputs
=
program
.
build
(
config
,
eval_program
,
startup_prog
,
mode
=
'test'
)
eval_fetch_name_list
=
eval_build_outputs
[
1
]
eval_fetch_varname_list
=
eval_build_outputs
[
2
]
eval_program
=
eval_program
.
clone
(
for_test
=
True
)
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
startup_prog
)
init_model
(
config
,
eval_program
,
exe
)
eval_reader
=
reader_main
(
config
=
config
,
mode
=
"eval"
)
eval_info_dict
=
{
'program'
:
eval_program
,
\
'reader'
:
eval_reader
,
\
'fetch_name_list'
:
eval_fetch_name_list
,
\
'fetch_varname_list'
:
eval_fetch_varname_list
}
eval_args
=
dict
()
eval_args
=
{
'exe'
:
exe
,
'config'
:
config
,
'eval_info_dict'
:
eval_info_dict
}
metrics
=
eval_det_run
(
eval_args
)
print
(
"Baseline: {}"
.
format
(
metrics
))
params
=
get_pruned_params
(
eval_program
)
print
(
'Start to analyze'
)
sens_0
=
slim
.
prune
.
sensitivity
(
eval_program
,
place
,
params
,
eval_det_run
,
sensitivities_file
=
"sensitivities_0.data"
,
pruned_ratios
=
[
0.1
],
eval_args
=
eval_args
,
criterion
=
'geometry_median'
)
if
__name__
==
'__main__'
:
parser
=
program
.
ArgsParser
()
FLAGS
=
parser
.
parse_args
()
main
()
tools/program.py
浏览文件 @
d4f1758d
...
...
@@ -33,6 +33,7 @@ from eval_utils.eval_rec_utils import eval_rec_run
from
ppocr.utils.save_load
import
save_model
import
numpy
as
np
from
ppocr.utils.character
import
cal_predicts_accuracy
,
cal_predicts_accuracy_srn
,
CharacterOps
import
paddleslim
as
slim
class
ArgsParser
(
ArgumentParser
):
...
...
@@ -238,7 +239,11 @@ def create_multi_devices_program(program, loss_var_name):
return
compile_program
def
train_eval_det_run
(
config
,
exe
,
train_info_dict
,
eval_info_dict
):
def
train_eval_det_run
(
config
,
exe
,
train_info_dict
,
eval_info_dict
,
is_pruning
=
False
):
train_batch_id
=
0
log_smooth_window
=
config
[
'Global'
][
'log_smooth_window'
]
epoch_num
=
config
[
'Global'
][
'epoch_num'
]
...
...
@@ -294,7 +299,13 @@ def train_eval_det_run(config, exe, train_info_dict, eval_info_dict):
best_batch_id
=
train_batch_id
best_epoch
=
epoch
save_path
=
save_model_dir
+
"/best_accuracy"
save_model
(
train_info_dict
[
'train_program'
],
save_path
)
if
is_pruning
:
slim
.
prune
.
save_model
(
exe
,
train_info_dict
[
'train_program'
],
save_path
)
else
:
save_model
(
train_info_dict
[
'train_program'
],
save_path
)
strs
=
'Test iter: {}, metrics:{}, best_hmean:{:.6f}, best_epoch:{}, best_batch_id:{}'
.
format
(
train_batch_id
,
metrics
,
best_eval_hmean
,
best_epoch
,
best_batch_id
)
...
...
@@ -305,10 +316,18 @@ def train_eval_det_run(config, exe, train_info_dict, eval_info_dict):
train_loader
.
reset
()
if
epoch
==
0
and
save_epoch_step
==
1
:
save_path
=
save_model_dir
+
"/iter_epoch_0"
save_model
(
train_info_dict
[
'train_program'
],
save_path
)
if
is_pruning
:
slim
.
prune
.
save_model
(
exe
,
train_info_dict
[
'train_program'
],
save_path
)
else
:
save_model
(
train_info_dict
[
'train_program'
],
save_path
)
if
epoch
>
0
and
epoch
%
save_epoch_step
==
0
:
save_path
=
save_model_dir
+
"/iter_epoch_%d"
%
(
epoch
)
save_model
(
train_info_dict
[
'train_program'
],
save_path
)
if
is_pruning
:
slim
.
prune
.
save_model
(
exe
,
train_info_dict
[
'train_program'
],
save_path
)
else
:
save_model
(
train_info_dict
[
'train_program'
],
save_path
)
return
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录