Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
a784e4fe
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
1 年多 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a784e4fe
编写于
2月 17, 2020
作者:
W
whs
提交者:
GitHub
2月 17, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix demo of pruning to load pretrained model. (#115)
上级
eac4f3b2
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
25 addition
and
8 deletion
+25
-8
demo/prune/README.md
demo/prune/README.md
+19
-6
demo/prune/eval.py
demo/prune/eval.py
+1
-1
demo/prune/train.py
demo/prune/train.py
+4
-0
docs/zh_cn/api_cn/prune_api.rst
docs/zh_cn/api_cn/prune_api.rst
+1
-1
未找到文件。
demo/prune/README.md
浏览文件 @
a784e4fe
...
@@ -17,7 +17,20 @@
...
@@ -17,7 +17,20 @@
1). 根据分类模型中
[
ImageNet数据准备文档
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification#%E6%95%B0%E6%8D%AE%E5%87%86%E5%A4%87
)
下载数据到
`PaddleSlim/demo/data/ILSVRC2012`
路径下。
1). 根据分类模型中
[
ImageNet数据准备文档
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification#%E6%95%B0%E6%8D%AE%E5%87%86%E5%A4%87
)
下载数据到
`PaddleSlim/demo/data/ILSVRC2012`
路径下。
2). 使用
`train.py`
脚本时,指定
`--data`
选项为
`imagenet`
.
2). 使用
`train.py`
脚本时,指定
`--data`
选项为
`imagenet`
.
## 2. 启动剪裁任务
## 2. 下载预训练模型
如果使用
`ImageNet`
数据,建议在预训练模型的基础上进行剪裁,请从
[
分类库
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification#%E5%B7%B2%E5%8F%91%E5%B8%83%E6%A8%A1%E5%9E%8B%E5%8F%8A%E5%85%B6%E6%80%A7%E8%83%BD
)
中下载合适的预训练模型。
这里以
`MobileNetV1`
为例,下载并解压预训练模型到当前路径:
```
wget http://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV1_pretrained.tar
tar -xf MobileNetV1_pretrained.tar
```
使用
`train.py`
脚本时,指定
`--pretrained_model`
加载预训练模型。
## 3. 启动剪裁任务
通过以下命令启动裁剪任务:
通过以下命令启动裁剪任务:
...
@@ -25,8 +38,8 @@
...
@@ -25,8 +38,8 @@
export CUDA_VISIBLE_DEVICES=0
export CUDA_VISIBLE_DEVICES=0
python train.py \
python train.py \
--model "MobileNet" \
--model "MobileNet" \
--pruned_ratio 0.3
3
\
--pruned_ratio 0.3
1
\
--data "
imagene
t"
--data "
mnis
t"
```
```
其中,
`model`
用于指定待裁剪的模型。
`pruned_ratio`
用于指定各个卷积层通道数被裁剪的比例。
`data`
选项用于指定使用的数据集。
其中,
`model`
用于指定待裁剪的模型。
`pruned_ratio`
用于指定各个卷积层通道数被裁剪的比例。
`data`
选项用于指定使用的数据集。
...
@@ -35,7 +48,7 @@ python train.py \
...
@@ -35,7 +48,7 @@ python train.py \
在本示例中,会在日志中输出剪裁前后的
`FLOPs`
,并且每训练一轮就会保存一个模型到文件系统。
在本示例中,会在日志中输出剪裁前后的
`FLOPs`
,并且每训练一轮就会保存一个模型到文件系统。
##
3
. 加载和评估模型
##
4
. 加载和评估模型
本节介绍如何加载训练过程中保存的模型。
本节介绍如何加载训练过程中保存的模型。
...
@@ -43,14 +56,14 @@ python train.py \
...
@@ -43,14 +56,14 @@ python train.py \
```
```
python eval.py \
python eval.py \
--model "
mobilen
et" \
--model "
MobileN
et" \
--data "mnist" \
--data "mnist" \
--model_path "./models/0"
--model_path "./models/0"
```
```
在脚本
`eval.py`
中,使用
`paddleslim.prune.load_model`
接口加载剪裁得到的模型。
在脚本
`eval.py`
中,使用
`paddleslim.prune.load_model`
接口加载剪裁得到的模型。
##
4
. 接口介绍
##
5
. 接口介绍
该示例使用了
`paddleslim.Pruner`
工具类,用户接口使用介绍请参考:
[
API文档
](
https://paddlepaddle.github.io/PaddleSlim/api/prune_api/
)
该示例使用了
`paddleslim.Pruner`
工具类,用户接口使用介绍请参考:
[
API文档
](
https://paddlepaddle.github.io/PaddleSlim/api/prune_api/
)
...
...
demo/prune/eval.py
浏览文件 @
a784e4fe
...
@@ -68,7 +68,7 @@ def eval(args):
...
@@ -68,7 +68,7 @@ def eval(args):
val_feeder
=
feeder
=
fluid
.
DataFeeder
(
val_feeder
=
feeder
=
fluid
.
DataFeeder
(
[
image
,
label
],
place
,
program
=
val_program
)
[
image
,
label
],
place
,
program
=
val_program
)
load_model
(
val_program
,
"./model/mobilenetv1_prune_50"
)
load_model
(
exe
,
val_program
,
args
.
model_path
)
batch_id
=
0
batch_id
=
0
acc_top1_ns
=
[]
acc_top1_ns
=
[]
...
...
demo/prune/train.py
浏览文件 @
a784e4fe
...
@@ -136,6 +136,8 @@ def compress(args):
...
@@ -136,6 +136,8 @@ def compress(args):
return
os
.
path
.
exists
(
return
os
.
path
.
exists
(
os
.
path
.
join
(
args
.
pretrained_model
,
var
.
name
))
os
.
path
.
join
(
args
.
pretrained_model
,
var
.
name
))
_logger
.
info
(
"Load pretrained model from {}"
.
format
(
args
.
pretrained_model
))
fluid
.
io
.
load_vars
(
exe
,
args
.
pretrained_model
,
predicate
=
if_exist
)
fluid
.
io
.
load_vars
(
exe
,
args
.
pretrained_model
,
predicate
=
if_exist
)
val_reader
=
paddle
.
batch
(
val_reader
,
batch_size
=
args
.
batch_size
)
val_reader
=
paddle
.
batch
(
val_reader
,
batch_size
=
args
.
batch_size
)
...
@@ -200,6 +202,8 @@ def compress(args):
...
@@ -200,6 +202,8 @@ def compress(args):
end_time
-
start_time
))
end_time
-
start_time
))
batch_id
+=
1
batch_id
+=
1
test
(
0
,
val_program
)
params
=
get_pruned_params
(
args
,
fluid
.
default_main_program
())
params
=
get_pruned_params
(
args
,
fluid
.
default_main_program
())
_logger
.
info
(
"FLOPs before pruning: {}"
.
format
(
_logger
.
info
(
"FLOPs before pruning: {}"
.
format
(
flops
(
fluid
.
default_main_program
())))
flops
(
fluid
.
default_main_program
())))
...
...
docs/zh_cn/api_cn/prune_api.rst
浏览文件 @
a784e4fe
...
@@ -378,7 +378,7 @@ load_sensitivities
...
@@ -378,7 +378,7 @@ load_sensitivities
}
}
}
}
sensitivities_file = "sensitive_api_demo.data"
sensitivities_file = "sensitive_api_demo.data"
with open(sensitivities_file, 'w') as f:
with open(sensitivities_file, 'w
b
') as f:
pickle.dump(sen, f)
pickle.dump(sen, f)
sensitivities = load_sensitivities(sensitivities_file)
sensitivities = load_sensitivities(sensitivities_file)
print(sensitivities)
print(sensitivities)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录