PaddleDetection · commit 366eb59c (unverified)
Authored Mar 23, 2020 by Kaipeng Deng; committed via GitHub on Mar 23, 2020
add prune export_model (#378)
Parent: cc84b7ab
Showing 4 changed files with 183 additions and 2 deletions (+183 −2):
slim/extensions/distill_pruned_model/README.md  (+12 −0)
slim/prune/README.md  (+13 −1)
slim/prune/eval.py  (+1 −1)
slim/prune/export_model.py  (+157 −0)
slim/extensions/distill_pruned_model/README.md
@@ -67,3 +67,15 @@ python ../../prune/eval.py \
--pruned_ratios="0.2,0.3,0.4" \
-o weights=output/yolov3_mobilenet_v1_voc/model_final
```
## 6. Model export

To deploy the pruned model with the C++ inference library or a Serving service, export it with `../../prune/export_model.py`:
```
python ../../prune/export_model.py \
-c ../../../configs/yolov3_mobilenet_v1_voc.yml \
--pruned_params "yolo_block.0.0.0.conv.weights,yolo_block.0.0.1.conv.weights,yolo_block.0.1.0.conv.weights" \
--pruned_ratios="0.2,0.3,0.4" \
-o weights=output/yolov3_mobilenet_v1_voc/model_final
```
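
For reference, `save_infer_model()` in the new `export_model.py` (full file below) writes the inference program and a single merged weights file. A minimal sketch of checking the expected on-disk layout, assuming the default `--output_dir` ("output") and the config used above:

```
# Sketch: where the export lands, per save_infer_model() in export_model.py.
# Assumes the default --output_dir ("output") and yolov3_mobilenet_v1_voc.yml.
import os

save_dir = os.path.join("output", "yolov3_mobilenet_v1_voc")
# The program is saved as __model__; params_filename="__params__" merges all
# weights into one file, which C++ inference and Serving can load directly.
for fname in ("__model__", "__params__"):
    print(fname, os.path.exists(os.path.join(save_dir, fname)))
```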
slim/prune/README.md
@@ -74,7 +74,19 @@ python eval.py \
-o weights=output/yolov3_mobilenet_v1_voc/model_final
```
-## 7. Extending the model
+## 7. Model export

To deploy the pruned model with the C++ inference library or a Serving service, export it with `export_model.py`:
```
python export_model.py \
-c ../../configs/yolov3_mobilenet_v1_voc.yml \
--pruned_params "yolo_block.0.0.0.conv.weights,yolo_block.0.0.1.conv.weights,yolo_block.0.1.0.conv.weights" \
--pruned_ratios="0.2,0.3,0.4" \
-o weights=output/yolov3_mobilenet_v1_voc/model_final
```
## 8. Extending the model

To adapt this to your own model, see how `prune.py` calls the `paddleslim.prune.Pruner` API and modify your own training script accordingly.

The pruning example in this section requires you to choose each layer's pruning ratio from prior knowledge. Beyond that, PaddleSlim also provides sensitivity analysis and related tools to help pick suitable ratios; see the [PaddleSlim documentation](https://paddlepaddle.github.io/PaddleSlim/) for details.
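
As a pointer, here is a minimal sketch of the sensitivity analysis the README mentions, assuming PaddleSlim's `paddleslim.prune.sensitivity` API; `eval_prog` and `eval_fn` stand in for the user's own evaluation program and callback and are not part of this commit:

```
# Sketch of PaddleSlim sensitivity analysis; eval_prog / eval_fn are assumed
# to come from the user's own evaluation setup.
import paddle.fluid as fluid
from paddleslim.prune import sensitivity

place = fluid.CPUPlace()
sens = sensitivity(
    eval_prog,                            # evaluation Program to analyze (assumed)
    place,
    ["yolo_block.0.0.0.conv.weights"],    # parameters whose ratios you want to pick
    eval_fn,                              # assumed callback: eval_fn(program) -> metric
    sensitivities_file="sensitivities.data",  # progress cache; analysis can resume
    pruned_ratios=[0.1, 0.2, 0.3])
# sens maps each parameter to {ratio: metric loss}; lower loss at a given
# ratio suggests the layer tolerates more pruning.
```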
slim/prune/eval.py
@@ -176,7 +176,7 @@ def main():
     # load model
     exe.run(startup_prog)
     if 'weights' in cfg:
-        checkpoint.load_params(exe, eval_prog, cfg.weights)
+        checkpoint.load_checkpoint(exe, eval_prog, cfg.weights)

     results = eval_run(exe, compile_program, loader, keys, values, cls, cfg,
                        sub_eval_prog, sub_keys, sub_values)
slim/prune/export_model.py
new file (mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import logging

from paddle import fluid

from ppdet.core.workspace import load_config, merge_config, create
from ppdet.utils.cli import ArgsParser
import ppdet.utils.checkpoint as checkpoint

from paddleslim.prune import Pruner
from paddleslim.analysis import flops

FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)


def prune_feed_vars(feeded_var_names, target_vars, prog):
    """
    Filter out feed variables which are not in the program. Such feed
    variables are only used in post-processing of the model output and
    not in the program itself, e.g. im_id (to identify image order) and
    im_shape (to clip bboxes to the image).
    """
    exist_var_names = []
    prog = prog.clone()
    prog = prog._prune(targets=target_vars)
    global_block = prog.global_block()
    for name in feeded_var_names:
        try:
            v = global_block.var(name)
            exist_var_names.append(str(v.name))
        except Exception:
            logger.info('save_inference_model pruned unused feed '
                        'variables {}'.format(name))
    return exist_var_names


def save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog):
    cfg_name = os.path.basename(FLAGS.config).split('.')[0]
    save_dir = os.path.join(FLAGS.output_dir, cfg_name)
    feed_var_names = [var.name for var in feed_vars.values()]
    target_vars = list(test_fetches.values())
    feed_var_names = prune_feed_vars(feed_var_names, target_vars, infer_prog)
    logger.info("Export inference model to {}, input: {}, output: "
                "{}...".format(save_dir, feed_var_names,
                               [str(var.name) for var in target_vars]))
    fluid.io.save_inference_model(
        save_dir,
        feeded_var_names=feed_var_names,
        target_vars=target_vars,
        executor=exe,
        main_program=infer_prog,
        params_filename="__params__")


def main():
    cfg = load_config(FLAGS.config)

    if 'architecture' in cfg:
        main_arch = cfg.architecture
    else:
        raise ValueError("'architecture' not specified in config file.")

    merge_config(FLAGS.opt)

    # Use CPU for exporting inference model instead of GPU
    place = fluid.CPUPlace()
    exe = fluid.Executor(place)

    model = create(main_arch)

    # build the inference program from the test inputs defined in the config
    startup_prog = fluid.Program()
    infer_prog = fluid.Program()
    with fluid.program_guard(infer_prog, startup_prog):
        with fluid.unique_name.guard():
            inputs_def = cfg['TestReader']['inputs_def']
            inputs_def['use_dataloader'] = False
            feed_vars, _ = model.build_inputs(**inputs_def)
            test_fetches = model.test(feed_vars)
    infer_prog = infer_prog.clone(True)

    assert FLAGS.pruned_params is not None, \
        "FLAGS.pruned_params is empty!!! Please set it by '--pruned_params' option."
    pruned_params = FLAGS.pruned_params.strip().split(",")
    logger.info("pruned params: {}".format(pruned_params))
    pruned_ratios = [float(n) for n in FLAGS.pruned_ratios.strip().split(",")]
    logger.info("pruned ratios: {}".format(pruned_ratios))
    assert len(pruned_params) == len(pruned_ratios), \
        "The length of pruned params and pruned ratios should be equal."
    assert all(0 < r < 1 for r in pruned_ratios), \
        "The elements of pruned ratios should be in range (0, 1)."

    # prune the graph structure only; the pruned weights are loaded from the
    # trained checkpoint below
    base_flops = flops(infer_prog)
    pruner = Pruner()
    infer_prog, _, _ = pruner.prune(
        infer_prog,
        fluid.global_scope(),
        params=pruned_params,
        ratios=pruned_ratios,
        place=place,
        only_graph=True)
    pruned_flops = flops(infer_prog)
    # report the relative FLOPs reduction from pruning
    logger.info("pruned FLOPS: {}".format(
        float(base_flops - pruned_flops) / base_flops))

    exe.run(startup_prog)
    checkpoint.load_checkpoint(exe, infer_prog, cfg.weights)

    save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog)


if __name__ == '__main__':
    parser = ArgsParser()
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="Directory for storing the output model files.")
    parser.add_argument(
        "-p",
        "--pruned_params",
        default=None,
        type=str,
        help="The parameters to be pruned.")
    parser.add_argument(
        "--pruned_ratios",
        default=None,
        type=str,
        help="The pruned ratio for each parameter, comma-separated.")
    FLAGS = parser.parse_args()
    main()
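
As a quick check of the exported artifacts, here is a minimal sketch of loading and running them with the Paddle 1.x fluid API; the save directory and the 608x608 YOLOv3 feed shapes are assumptions based on the config used in the READMEs above, and the actual feeds depend on the TestReader:

```
# Sketch: load the exported model; directory and feed shapes are assumptions.
import numpy as np
from paddle import fluid

place = fluid.CPUPlace()
exe = fluid.Executor(place)
# Mirrors save_infer_model(): program in __model__, weights merged in __params__.
infer_prog, feed_names, fetch_targets = fluid.io.load_inference_model(
    "output/yolov3_mobilenet_v1_voc", exe, params_filename="__params__")
# YOLOv3 in PaddleDetection typically feeds "image" and "im_size"; a 608x608
# input is assumed here.
candidate_feeds = {
    "image": np.random.rand(1, 3, 608, 608).astype("float32"),
    "im_size": np.array([[608, 608]], dtype="int32"),
}
feeds = {name: candidate_feeds[name] for name in feed_names}
outs = exe.run(infer_prog,
               feed=feeds,
               fetch_list=fetch_targets,
               return_numpy=False)  # bbox results carry LoD, so keep tensors
```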