未验证 提交 f5829b03 编写于 作者: W whs 提交者: GitHub

Refine the document of pruning demo. (#65)

上级 ff1ff03a
...@@ -14,8 +14,7 @@ DATA_DIM = 224 ...@@ -14,8 +14,7 @@ DATA_DIM = 224
THREAD = 16 THREAD = 16
BUF_SIZE = 10240 BUF_SIZE = 10240
#DATA_DIR = './data/ILSVRC2012/' DATA_DIR = './data/ILSVRC2012/'
DATA_DIR = './data/'
DATA_DIR = os.path.join(os.path.split(os.path.realpath(__file__))[0], DATA_DIR) DATA_DIR = os.path.join(os.path.split(os.path.realpath(__file__))[0], DATA_DIR)
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1)) img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
......
# 卷积通道剪裁示例 # 图像分类模型卷积层通道剪裁示例
本示例将演示如何按指定的剪裁率对每个卷积层的通道数进行剪裁。该示例默认会自动下载并使用mnist数据。 本示例将演示如何按指定的剪裁率对每个卷积层的通道数进行剪裁。该示例默认会自动下载并使用mnist数据。
...@@ -9,39 +9,33 @@ ...@@ -9,39 +9,33 @@
- ResNet50 - ResNet50
- PVANet - PVANet
## 接口介绍
该示例使用了`paddleslim.Pruner`工具类,用户接口使用介绍请参考:[API文档](https://paddlepaddle.github.io/PaddleSlim/api/prune_api/) ## 1. 数据准备
## 确定待裁参数
不同模型的参数命名不同,在剪裁前需要确定待裁卷积层的参数名称。可通过以下方法列出所有参数名:
``` 本示例支持`MNIST``ImageNet`两种数据。默认情况下,会自动下载并使用`MNIST`数据,如果需要使用`ImageNet`数据,请按以下步骤操作:
for param in program.global_block().all_parameters():
print("param name: {}; shape: {}".format(param.name, param.shape))
```
`train.py`脚本中,提供了`get_pruned_params`方法,根据用户设置的选项`--model`确定要裁剪的参数。 1). 根据分类模型中[ImageNet数据准备文档](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification#%E6%95%B0%E6%8D%AE%E5%87%86%E5%A4%87)下载数据到`PaddleSlim/demo/data/ILSVRC2012`路径下。
2). 使用`train.py`脚本时,指定`--data`选项为`imagenet`.
## 启动裁剪任务 ## 2. 启动剪裁任务
通过以下命令启动裁剪任务: 通过以下命令启动裁剪任务:
``` ```
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
python train.py python train.py \
--model "MobileNet" \
--pruned_ratio 0.33 \
--data "imagenet"
``` ```
在本示例中,每训练一轮就会保存一个模型到文件系统 其中,`model`用于指定待裁剪的模型。`pruned_ratio`用于指定各个卷积层通道数被裁剪的比例。`data`选项用于指定使用的数据集
执行`python train.py --help`查看更多选项。 执行`python train.py --help`查看更多选项。
## 注意 在本示例中,会在日志中输出剪裁前后的`FLOPs`,并且每训练一轮就会保存一个模型到文件系统。
1. 在接口`paddle.Pruner.prune`的参数中,`params``ratios`的长度需要一样。 ## 3. 加载和评估模型
## 加载和评估模型
本节介绍如何加载训练过程中保存的模型。 本节介绍如何加载训练过程中保存的模型。
...@@ -55,3 +49,10 @@ python eval.py \ ...@@ -55,3 +49,10 @@ python eval.py \
``` ```
在脚本`eval.py`中,使用`paddleslim.prune.load_model`接口加载剪裁得到的模型。 在脚本`eval.py`中,使用`paddleslim.prune.load_model`接口加载剪裁得到的模型。
## 4. 接口介绍
该示例使用了`paddleslim.Pruner`工具类,用户接口使用介绍请参考:[API文档](https://paddlepaddle.github.io/PaddleSlim/api/prune_api/)
在调用`paddleslim.Pruner`工具类时,需要指定待裁卷积层的参数名称。不同模型的参数命名不同,
`train.py`脚本中,提供了`get_pruned_params`方法,根据用户设置的选项`--model`确定要裁剪的参数。
...@@ -36,6 +36,7 @@ add_arg('data', str, "mnist", "Which data to use. 'm ...@@ -36,6 +36,7 @@ add_arg('data', str, "mnist", "Which data to use. 'm
add_arg('log_period', int, 10, "Log period in batches.") add_arg('log_period', int, 10, "Log period in batches.")
add_arg('test_period', int, 10, "Test period in epoches.") add_arg('test_period', int, 10, "Test period in epoches.")
add_arg('model_path', str, "./models", "The path to save model.") add_arg('model_path', str, "./models", "The path to save model.")
add_arg('pruned_ratio', float, None, "The ratios to be pruned.")
# yapf: enable # yapf: enable
model_list = models.__all__ model_list = models.__all__
...@@ -207,7 +208,7 @@ def compress(args): ...@@ -207,7 +208,7 @@ def compress(args):
val_program, val_program,
fluid.global_scope(), fluid.global_scope(),
params=params, params=params,
ratios=[0.33] * len(params), ratios=[FLAGS.pruned_ratio] * len(params),
place=place, place=place,
only_graph=True) only_graph=True)
...@@ -215,7 +216,7 @@ def compress(args): ...@@ -215,7 +216,7 @@ def compress(args):
fluid.default_main_program(), fluid.default_main_program(),
fluid.global_scope(), fluid.global_scope(),
params=params, params=params,
ratios=[0.33] * len(params), ratios=[FLAGS.pruned_ratio] * len(params),
place=place) place=place)
_logger.info("FLOPs after pruning: {}".format(flops(pruned_program))) _logger.info("FLOPs after pruning: {}".format(flops(pruned_program)))
for i in range(args.num_epochs): for i in range(args.num_epochs):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册