未验证 提交 a784e4fe 编写于 作者: W whs 提交者: GitHub

Fix demo of pruning to load pretrained model. (#115)

上级 eac4f3b2
...@@ -17,7 +17,20 @@ ...@@ -17,7 +17,20 @@
1). 根据分类模型中[ImageNet数据准备文档](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification#%E6%95%B0%E6%8D%AE%E5%87%86%E5%A4%87)下载数据到`PaddleSlim/demo/data/ILSVRC2012`路径下。 1). 根据分类模型中[ImageNet数据准备文档](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification#%E6%95%B0%E6%8D%AE%E5%87%86%E5%A4%87)下载数据到`PaddleSlim/demo/data/ILSVRC2012`路径下。
2). 使用`train.py`脚本时,指定`--data`选项为`imagenet`. 2). 使用`train.py`脚本时,指定`--data`选项为`imagenet`.
## 2. 启动剪裁任务 ## 2. 下载预训练模型
如果使用`ImageNet`数据,建议在预训练模型的基础上进行剪裁,请从[分类库](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification#%E5%B7%B2%E5%8F%91%E5%B8%83%E6%A8%A1%E5%9E%8B%E5%8F%8A%E5%85%B6%E6%80%A7%E8%83%BD)中下载合适的预训练模型。
这里以`MobileNetV1`为例,下载并解压预训练模型到当前路径:
```
wget http://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV1_pretrained.tar
tar -xf MobileNetV1_pretrained.tar
```
使用`train.py`脚本时,指定`--pretrained_model`加载预训练模型。
## 3. 启动剪裁任务
通过以下命令启动裁剪任务: 通过以下命令启动裁剪任务:
...@@ -25,8 +38,8 @@ ...@@ -25,8 +38,8 @@
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
python train.py \ python train.py \
--model "MobileNet" \ --model "MobileNet" \
--pruned_ratio 0.33 \ --pruned_ratio 0.31 \
--data "imagenet" --data "mnist"
``` ```
其中,`model`用于指定待裁剪的模型。`pruned_ratio`用于指定各个卷积层通道数被裁剪的比例。`data`选项用于指定使用的数据集。 其中,`model`用于指定待裁剪的模型。`pruned_ratio`用于指定各个卷积层通道数被裁剪的比例。`data`选项用于指定使用的数据集。
...@@ -35,7 +48,7 @@ python train.py \ ...@@ -35,7 +48,7 @@ python train.py \
在本示例中,会在日志中输出剪裁前后的`FLOPs`,并且每训练一轮就会保存一个模型到文件系统。 在本示例中,会在日志中输出剪裁前后的`FLOPs`,并且每训练一轮就会保存一个模型到文件系统。
## 3. 加载和评估模型 ## 4. 加载和评估模型
本节介绍如何加载训练过程中保存的模型。 本节介绍如何加载训练过程中保存的模型。
...@@ -43,14 +56,14 @@ python train.py \ ...@@ -43,14 +56,14 @@ python train.py \
``` ```
python eval.py \ python eval.py \
--model "mobilenet" \ --model "MobileNet" \
--data "mnist" \ --data "mnist" \
--model_path "./models/0" --model_path "./models/0"
``` ```
在脚本`eval.py`中,使用`paddleslim.prune.load_model`接口加载剪裁得到的模型。 在脚本`eval.py`中,使用`paddleslim.prune.load_model`接口加载剪裁得到的模型。
## 4. 接口介绍 ## 5. 接口介绍
该示例使用了`paddleslim.Pruner`工具类,用户接口使用介绍请参考:[API文档](https://paddlepaddle.github.io/PaddleSlim/api/prune_api/) 该示例使用了`paddleslim.Pruner`工具类,用户接口使用介绍请参考:[API文档](https://paddlepaddle.github.io/PaddleSlim/api/prune_api/)
......
...@@ -68,7 +68,7 @@ def eval(args): ...@@ -68,7 +68,7 @@ def eval(args):
val_feeder = feeder = fluid.DataFeeder( val_feeder = feeder = fluid.DataFeeder(
[image, label], place, program=val_program) [image, label], place, program=val_program)
load_model(val_program, "./model/mobilenetv1_prune_50") load_model(exe, val_program, args.model_path)
batch_id = 0 batch_id = 0
acc_top1_ns = [] acc_top1_ns = []
......
...@@ -136,6 +136,8 @@ def compress(args): ...@@ -136,6 +136,8 @@ def compress(args):
return os.path.exists( return os.path.exists(
os.path.join(args.pretrained_model, var.name)) os.path.join(args.pretrained_model, var.name))
_logger.info("Load pretrained model from {}".format(
args.pretrained_model))
fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist) fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist)
val_reader = paddle.batch(val_reader, batch_size=args.batch_size) val_reader = paddle.batch(val_reader, batch_size=args.batch_size)
...@@ -200,6 +202,8 @@ def compress(args): ...@@ -200,6 +202,8 @@ def compress(args):
end_time - start_time)) end_time - start_time))
batch_id += 1 batch_id += 1
test(0, val_program)
params = get_pruned_params(args, fluid.default_main_program()) params = get_pruned_params(args, fluid.default_main_program())
_logger.info("FLOPs before pruning: {}".format( _logger.info("FLOPs before pruning: {}".format(
flops(fluid.default_main_program()))) flops(fluid.default_main_program())))
......
...@@ -378,7 +378,7 @@ load_sensitivities ...@@ -378,7 +378,7 @@ load_sensitivities
} }
} }
sensitivities_file = "sensitive_api_demo.data" sensitivities_file = "sensitive_api_demo.data"
with open(sensitivities_file, 'w') as f: with open(sensitivities_file, 'wb') as f:
pickle.dump(sen, f) pickle.dump(sen, f)
sensitivities = load_sensitivities(sensitivities_file) sensitivities = load_sensitivities(sensitivities_file)
print(sensitivities) print(sensitivities)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册