未验证 提交 1cb8d1bd 编写于 作者: W whs 提交者: GitHub

Refine quick start tutorial of pruning. (#49)

上级 648978cd
from __future__ import absolute_import
from .mobilenet import MobileNet from .mobilenet import MobileNet
from .resnet import ResNet34, ResNet50 from .resnet import ResNet34, ResNet50
from .resnet_vd import ResNet50_vd from .resnet_vd import ResNet50_vd
from .mobilenet_v2 import MobileNetV2 from .mobilenet_v2 import MobileNetV2
from .pvanet import PVANet from .pvanet import PVANet
__all__ = [ __all__ = [
"model_list", "MobileNet", "ResNet34", "ResNet50", "MobileNetV2", "PVANet",
"ResNet50_vd"
]
model_list = [
'MobileNet', 'ResNet34', 'ResNet50', 'MobileNetV2', 'PVANet', 'ResNet50_vd' 'MobileNet', 'ResNet34', 'ResNet50', 'MobileNetV2', 'PVANet', 'ResNet50_vd'
] ]
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 图像分类模型通道剪裁-快速开始\n",
"\n",
"该教程以图像分类模型MobileNetV1为例,说明如何快速使用[PaddleSlim的卷积通道剪裁接口]()。\n",
"该示例包含以下步骤:\n",
"\n",
"1. 导入依赖\n",
"2. 构建模型\n",
"3. 剪裁\n",
"4. 训练剪裁后的模型\n",
"\n",
"以下章节依次次介绍每个步骤的内容。\n",
"\n",
"## 1. 导入依赖\n",
"\n",
"PaddleSlim依赖Paddle1.7版本,请确认已正确安装Paddle,然后按以下方式导入Paddle和PaddleSlim:"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"import paddle\n",
"import paddle.fluid as fluid\n",
"import paddleslim as slim"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. 构建网络\n",
"\n",
"该章节构造一个用于对MNIST数据进行分类的分类模型,选用`MobileNetV1`,并将输入大小设置为`[1, 28, 28]`,输出类别数为10。\n",
"为了方便展示示例,我们在`paddleslim.models`下预定义了用于构建分类模型的方法,执行以下代码构建分类模型:"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"exe, train_program, val_program, inputs, outputs = slim.models.image_classification(\"MobileNet\", [1, 28, 28], 10, use_gpu=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
">注意:paddleslim.models下的API并非PaddleSlim常规API,是为了简化示例而封装预定义的一系列方法,比如:模型结构的定义、Program的构建等。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. 剪裁卷积层通道\n",
"\n",
"### 3.1 计算剪裁之前的FLOPs"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"FLOPs: 10907072.0\n"
]
}
],
"source": [
"FLOPs = slim.analysis.flops(train_program)\n",
"print(\"FLOPs: {}\".format(FLOPs))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3.2 剪裁\n",
"\n",
"我们这里对参数名为`conv2_1_sep_weights`和`conv2_2_sep_weights`的卷积层进行剪裁,分别剪掉20%和30%的通道数。\n",
"代码如下所示:"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"pruner = slim.prune.Pruner()\n",
"pruned_program, _, _ = pruner.prune(\n",
" train_program,\n",
" fluid.global_scope(),\n",
" params=[\"conv2_1_sep_weights\", \"conv2_2_sep_weights\"],\n",
" ratios=[0.33] * 2,\n",
" place=fluid.CPUPlace())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"以上操作会修改`train_program`中对应卷积层参数的定义,同时对`fluid.global_scope()`中存储的参数数组进行裁剪。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3.3 计算剪裁之后的FLOPs"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"FLOPs: 10907072.0\n"
]
}
],
"source": [
"FLOPs = paddleslim.analysis.flops(train_program)\n",
"print(\"FLOPs: {}\".format(FLOPs))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. 训练剪裁后的模型\n",
"\n",
"### 4.1 定义输入数据\n",
"\n",
"为了快速执行该示例,我们选取简单的MNIST数据,Paddle框架的`paddle.dataset.mnist`包定义了MNIST数据的下载和读取。\n",
"代码如下:"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"import paddle.dataset.mnist as reader\n",
"train_reader = paddle.batch(\n",
" reader.train(), batch_size=128, drop_last=True)\n",
"train_feeder = fluid.DataFeeder(inputs, fluid.CPUPlace())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4.2 执行训练\n",
"以下代码执行了一个`epoch`的训练:"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.1484375] [0.4921875] [2.6727316]\n",
"[0.125] [0.546875] [2.6547904]\n",
"[0.125] [0.5546875] [2.795205]\n",
"[0.1171875] [0.578125] [2.8561475]\n",
"[0.1875] [0.59375] [2.470603]\n",
"[0.1796875] [0.578125] [2.8031898]\n",
"[0.1484375] [0.6015625] [2.7530417]\n",
"[0.1953125] [0.640625] [2.711596]\n",
"[0.125] [0.59375] [2.8637898]\n",
"[0.1796875] [0.53125] [2.9473038]\n",
"[0.25] [0.671875] [2.3943179]\n",
"[0.25] [0.6953125] [2.632146]\n",
"[0.2578125] [0.7265625] [2.723265]\n",
"[0.359375] [0.765625] [2.4263484]\n",
"[0.3828125] [0.8203125] [2.226284]\n",
"[0.421875] [0.8203125] [1.8042578]\n",
"[0.4765625] [0.890625] [1.6841211]\n",
"[0.53125] [0.8671875] [2.1971617]\n",
"[0.5546875] [0.8984375] [1.5361531]\n",
"[0.53125] [0.890625] [1.7211896]\n",
"[0.5078125] [0.8984375] [1.6586945]\n",
"[0.53125] [0.9140625] [1.8980236]\n",
"[0.546875] [0.9453125] [1.5279069]\n",
"[0.5234375] [0.8828125] [1.7356458]\n",
"[0.6015625] [0.9765625] [1.0375824]\n",
"[0.5546875] [0.921875] [1.639497]\n",
"[0.6015625] [0.9375] [1.5469061]\n",
"[0.578125] [0.96875] [1.3573356]\n",
"[0.65625] [0.9453125] [1.3787829]\n",
"[0.640625] [0.9765625] [0.9946856]\n",
"[0.65625] [0.96875] [1.1651027]\n",
"[0.625] [0.984375] [1.0487883]\n",
"[0.7265625] [0.9609375] [1.2526855]\n",
"[0.7265625] [0.9765625] [1.2954011]\n",
"[0.65625] [0.96875] [1.1181556]\n",
"[0.71875] [0.9765625] [0.97891223]\n",
"[0.640625] [0.9609375] [1.2135172]\n",
"[0.7265625] [0.9921875] [0.8950747]\n",
"[0.7578125] [0.96875] [1.0864108]\n",
"[0.734375] [0.9921875] [0.8392239]\n",
"[0.796875] [0.9609375] [0.7012155]\n",
"[0.7734375] [0.9765625] [0.7409136]\n",
"[0.8046875] [0.984375] [0.6108341]\n",
"[0.796875] [0.9765625] [0.63867176]\n",
"[0.7734375] [0.984375] [0.64099216]\n",
"[0.7578125] [0.9453125] [0.83827704]\n",
"[0.8046875] [0.9921875] [0.5311729]\n",
"[0.8984375] [0.9921875] [0.36445504]\n",
"[0.859375] [0.9921875] [0.40577835]\n",
"[0.8125] [0.9765625] [0.64629185]\n",
"[0.84375] [1.] [0.38400555]\n",
"[0.890625] [0.9765625] [0.45866236]\n",
"[0.8828125] [0.9921875] [0.3711415]\n",
"[0.7578125] [0.9921875] [0.6650479]\n",
"[0.7578125] [0.984375] [0.9030752]\n",
"[0.8671875] [0.9921875] [0.3678714]\n",
"[0.7421875] [0.9765625] [0.7424855]\n",
"[0.7890625] [1.] [0.6212543]\n",
"[0.8359375] [1.] [0.58529043]\n",
"[0.8203125] [0.96875] [0.5860813]\n",
"[0.8671875] [0.9921875] [0.415236]\n",
"[0.8125] [1.] [0.60501564]\n",
"[0.796875] [0.9765625] [0.60677457]\n",
"[0.8515625] [1.] [0.5338207]\n",
"[0.8046875] [0.9921875] [0.54180473]\n",
"[0.875] [0.9921875] [0.7293667]\n",
"[0.84375] [0.9765625] [0.5581689]\n",
"[0.8359375] [1.] [0.50712734]\n",
"[0.8671875] [0.9921875] [0.55217856]\n",
"[0.765625] [0.96875] [0.8076792]\n",
"[0.953125] [1.] [0.17031987]\n",
"[0.890625] [0.9921875] [0.42383268]\n",
"[0.828125] [0.9765625] [0.49300486]\n",
"[0.8671875] [0.96875] [0.57985115]\n",
"[0.8515625] [1.] [0.4901033]\n",
"[0.921875] [1.] [0.34583277]\n",
"[0.8984375] [0.984375] [0.41139168]\n",
"[0.9296875] [1.] [0.20420414]\n",
"[0.921875] [0.984375] [0.24322833]\n",
"[0.921875] [0.9921875] [0.30570173]\n",
"[0.875] [0.9921875] [0.3866225]\n",
"[0.9140625] [0.9921875] [0.20813875]\n",
"[0.9140625] [1.] [0.17933217]\n",
"[0.8984375] [0.9921875] [0.32508463]\n",
"[0.9375] [1.] [0.24799153]\n",
"[0.9140625] [1.] [0.26146784]\n",
"[0.90625] [1.] [0.24672262]\n",
"[0.8828125] [1.] [0.34094217]\n",
"[0.90625] [1.] [0.2964819]\n",
"[0.9296875] [1.] [0.18237087]\n",
"[0.84375] [1.] [0.7182543]\n",
"[0.8671875] [0.984375] [0.508474]\n",
"[0.8828125] [0.9921875] [0.367172]\n",
"[0.9453125] [1.] [0.2366665]\n",
"[0.9375] [1.] [0.12494276]\n",
"[0.8984375] [1.] [0.3395289]\n",
"[0.890625] [0.984375] [0.30877113]\n",
"[0.90625] [1.] [0.29763448]\n",
"[0.8828125] [0.984375] [0.4845504]\n",
"[0.8515625] [1.] [0.45548072]\n",
"[0.8828125] [1.] [0.33331633]\n",
"[0.90625] [1.] [0.4024018]\n",
"[0.890625] [0.984375] [0.73405886]\n",
"[0.9609375] [0.9921875] [0.15409982]\n",
"[0.9140625] [0.984375] [0.37103674]\n",
"[0.953125] [1.] [0.17628372]\n",
"[0.890625] [1.] [0.36522508]\n",
"[0.8828125] [1.] [0.407708]\n",
"[0.9375] [0.984375] [0.25090045]\n",
"[0.890625] [0.984375] [0.35742313]\n",
"[0.921875] [0.9921875] [0.2751101]\n",
"[0.890625] [0.984375] [0.43053097]\n",
"[0.875] [0.9921875] [0.34412643]\n",
"[0.90625] [1.] [0.35595697]\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-21-92f72657bddc>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtrain_reader\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0macc1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0macc5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mexe\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpruned_program\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtrain_feeder\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfeed\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0macc1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0macc5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.5/dist-packages/paddle/fluid/executor.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, program, feed, fetch_list, feed_var_name, fetch_var_name, scope, return_numpy, use_program_cache)\u001b[0m\n\u001b[1;32m 776\u001b[0m \u001b[0mscope\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mscope\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 777\u001b[0m \u001b[0mreturn_numpy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mreturn_numpy\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 778\u001b[0;31m use_program_cache=use_program_cache)\n\u001b[0m\u001b[1;32m 779\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 780\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mEOFException\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.5/dist-packages/paddle/fluid/executor.py\u001b[0m in \u001b[0;36m_run_impl\u001b[0;34m(self, program, feed, fetch_list, feed_var_name, fetch_var_name, scope, return_numpy, use_program_cache)\u001b[0m\n\u001b[1;32m 829\u001b[0m \u001b[0mscope\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mscope\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 830\u001b[0m \u001b[0mreturn_numpy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mreturn_numpy\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 831\u001b[0;31m use_program_cache=use_program_cache)\n\u001b[0m\u001b[1;32m 832\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 833\u001b[0m \u001b[0mprogram\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mscope\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplace\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.5/dist-packages/paddle/fluid/executor.py\u001b[0m in \u001b[0;36m_run_program\u001b[0;34m(self, program, feed, fetch_list, feed_var_name, fetch_var_name, scope, return_numpy, use_program_cache)\u001b[0m\n\u001b[1;32m 903\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0muse_program_cache\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 904\u001b[0m self._default_executor.run(program.desc, scope, 0, True, True,\n\u001b[0;32m--> 905\u001b[0;31m fetch_var_name)\n\u001b[0m\u001b[1;32m 906\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 907\u001b[0m self._default_executor.run_prepared_ctx(ctx, scope, False, False,\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"for data in train_reader():\n",
" acc1, acc5, loss = exe.run(pruned_program, feed=train_feeder.feed(data), fetch_list=outputs)\n",
" print(acc1, acc5, loss)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
# 卷积通道剪裁示例
本示例将演示如何按指定的剪裁率对每个卷积层的通道数进行剪裁。该示例默认会自动下载并使用mnist数据。
当前示例支持以下分类模型:
- MobileNetV1
- MobileNetV2
- ResNet50
- PVANet
## 接口介绍
该示例使用了`paddleslim.Pruner`工具类,用户接口使用介绍请参考:[API文档](https://paddlepaddle.github.io/PaddleSlim/api/prune_api/)
## 确定待裁参数
不同模型的参数命名不同,在剪裁前需要确定待裁卷积层的参数名称。可通过以下方法列出所有参数名:
```python
for param in program.global_block().all_parameters():
print("param name: {}; shape: {}".format(param.name, param.shape))
```
`train.py`脚本中,提供了`get_pruned_params`方法,根据用户设置的选项`--model`确定要裁剪的参数。
## 启动裁剪任务
通过以下命令启动裁剪任务:
```python
export CUDA_VISIBLE_DEVICES=0
python train.py
```
执行`python train.py --help`查看更多选项。
## 注意
1. 在接口`paddle.Pruner.prune`的参数中,`params``ratios`的长度需要一样。
# 图像分类模型通道剪裁-快速开始
该教程以图像分类模型MobileNetV1为例,说明如何快速使用[PaddleSlim的卷积通道剪裁接口]()。
该示例包含以下步骤:
1. 导入依赖
2. 构建模型
3. 剪裁
4. 训练剪裁后的模型
以下章节依次次介绍每个步骤的内容。
## 1. 导入依赖
PaddleSlim依赖Paddle1.7版本,请确认已正确安装Paddle,然后按以下方式导入Paddle和PaddleSlim:
```
import paddle
import paddle.fluid as fluid
import paddleslim as slim
```
## 2. 构建网络
该章节构造一个用于对MNIST数据进行分类的分类模型,选用`MobileNetV1`,并将输入大小设置为`[1, 28, 28]`,输出类别数为10。
为了方便展示示例,我们在`paddleslim.models`下预定义了用于构建分类模型的方法,执行以下代码构建分类模型:
```
exe, train_program, val_program, inputs, outputs =
slim.models.image_classification("MobileNet", [1, 28, 28], 10, use_gpu=False)
```
>注意:paddleslim.models下的API并非PaddleSlim常规API,是为了简化示例而封装预定义的一系列方法,比如:模型结构的定义、Program的构建等。
## 3. 剪裁卷积层通道
### 3.1 计算剪裁之前的FLOPs
```
FLOPs = slim.analysis.flops(train_program)
print("FLOPs: {}".format(FLOPs))
```
### 3.2 剪裁
我们这里对参数名为`conv2_1_sep_weights``conv2_2_sep_weights`的卷积层进行剪裁,分别剪掉20%和30%的通道数。
代码如下所示:
```
pruner = slim.prune.Pruner()
pruned_program, _, _ = pruner.prune(
train_program,
fluid.global_scope(),
params=["conv2_1_sep_weights", "conv2_2_sep_weights"],
ratios=[0.33] * 2,
place=fluid.CPUPlace())
```
以上操作会修改`train_program`中对应卷积层参数的定义,同时对`fluid.global_scope()`中存储的参数数组进行裁剪。
### 3.3 计算剪裁之后的FLOPs
```
FLOPs = paddleslim.analysis.flops(train_program)
print("FLOPs: {}".format(FLOPs))
```
## 4. 训练剪裁后的模型
### 4.1 定义输入数据
为了快速执行该示例,我们选取简单的MNIST数据,Paddle框架的`paddle.dataset.mnist`包定义了MNIST数据的下载和读取。
代码如下:
```
import paddle.dataset.mnist as reader
train_reader = paddle.batch(
reader.train(), batch_size=128, drop_last=True)
train_feeder = fluid.DataFeeder(inputs, fluid.CPUPlace())
```
### 4.2 执行训练
以下代码执行了一个`epoch`的训练:
```
for data in train_reader():
acc1, acc5, loss = exe.run(pruned_program, feed=train_feeder.feed(data), fetch_list=outputs)
print(acc1, acc5, loss)
```
...@@ -4,6 +4,7 @@ nav: ...@@ -4,6 +4,7 @@ nav:
- Home: index.md - Home: index.md
- 模型库: model_zoo.md - 模型库: model_zoo.md
- 教程: - 教程:
- 图像分类模型通道剪裁-快速开始: tutorials/pruning_tutorial.md
- 离线量化: tutorials/quant_post_demo.md - 离线量化: tutorials/quant_post_demo.md
- 量化训练: tutorials/quant_aware_demo.md - 量化训练: tutorials/quant_aware_demo.md
- Embedding量化: tutorials/quant_embedding_demo.md - Embedding量化: tutorials/quant_embedding_demo.md
......
...@@ -11,3 +11,12 @@ ...@@ -11,3 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import
from paddleslim import models
from paddleslim import prune
from paddleslim import nas
from paddleslim import analysis
from paddleslim import dist
from paddleslim import quant
__all__ = ['models', 'prune', 'nas', 'analysis', 'dist', 'quant']
from __future__ import absolute_import
import paddle.fluid as fluid
from ..models import classification_models
__all__ = ["image_classification"]
model_list = classification_models.model_list
def image_classification(model, image_shape, class_num, use_gpu=False):
assert model in model_list
train_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
image = fluid.layers.data(
name='image', shape=image_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
model = classification_models.__dict__[model]()
out = model.net(input=image, class_dim=class_num)
cost = fluid.layers.cross_entropy(input=out, label=label)
avg_cost = fluid.layers.mean(x=cost)
acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
val_program = fluid.default_main_program().clone(for_test=True)
opt = fluid.optimizer.Momentum(0.1, 0.9)
opt.minimize(avg_cost)
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
return exe, train_program, val_program, (image, label), (
acc_top1.name, acc_top5.name, avg_cost.name)
...@@ -11,18 +11,20 @@ ...@@ -11,18 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import
from .pruner import * from .pruner import *
import pruner from ..prune import pruner
from .auto_pruner import * from .auto_pruner import *
import auto_pruner import auto_pruner
from .sensitive_pruner import * from .sensitive_pruner import *
import sensitive_pruner from ..prune import sensitive_pruner
from .sensitive import * from .sensitive import *
import sensitive from ..prune import sensitive
from prune_walker import * from .prune_walker import *
import prune_walker from ..prune import prune_walker
from io import * from io import *
import io from ..prune import io
__all__ = [] __all__ = []
......
...@@ -236,7 +236,7 @@ def get_ratios_by_loss(sensitivities, loss): ...@@ -236,7 +236,7 @@ def get_ratios_by_loss(sensitivities, loss):
ratio = r0 + (loss - l0) * (r1 - r0) / (l1 - l0) ratio = r0 + (loss - l0) * (r1 - r0) / (l1 - l0)
ratios[param] = ratio ratios[param] = ratio
if ratio > 1: if ratio > 1:
print losses, ratio, (r1 - r0) / (l1 - l0), i _logger.info(losses, ratio, (r1 - r0) / (l1 - l0), i)
break break
return ratios return ratios
...@@ -64,7 +64,7 @@ class SensitivePruner(object): ...@@ -64,7 +64,7 @@ class SensitivePruner(object):
exe = fluid.Executor(self._place) exe = fluid.Executor(self._place)
checkpoints = self._checkpoints if checkpoints is None else checkpoints checkpoints = self._checkpoints if checkpoints is None else checkpoints
print("check points: {}".format(checkpoints)) _logger.info("check points: {}".format(checkpoints))
main_program = None main_program = None
eval_program = None eval_program = None
if checkpoints is not None: if checkpoints is not None:
...@@ -87,8 +87,9 @@ class SensitivePruner(object): ...@@ -87,8 +87,9 @@ class SensitivePruner(object):
with fluid.scope_guard(self._scope): with fluid.scope_guard(self._scope):
fluid.io.load_persistables(exe, latest_ck_path, fluid.io.load_persistables(exe, latest_ck_path,
main_program, "__params__") main_program, "__params__")
print("load checkpoint from: {}".format(latest_ck_path)) _logger.info("load checkpoint from: {}".format(latest_ck_path))
print("flops of eval program: {}".format(flops(eval_program))) _logger.info("flops of eval program: {}".format(
flops(eval_program)))
return main_program, eval_program, self._iter return main_program, eval_program, self._iter
def greedy_prune(self, def greedy_prune(self,
...@@ -108,7 +109,7 @@ class SensitivePruner(object): ...@@ -108,7 +109,7 @@ class SensitivePruner(object):
self._eval_func, self._eval_func,
sensitivities_file=sensitivities_file, sensitivities_file=sensitivities_file,
pruned_flops_rate=pruned_flops_rate) pruned_flops_rate=pruned_flops_rate)
print sensitivities _logger.info(sensitivities)
params, ratios = self._greedy_ratio_by_sensitive(sensitivities, topk) params, ratios = self._greedy_ratio_by_sensitive(sensitivities, topk)
_logger.info("Pruning: {} by {}".format(params, ratios)) _logger.info("Pruning: {} by {}".format(params, ratios))
...@@ -152,7 +153,7 @@ class SensitivePruner(object): ...@@ -152,7 +153,7 @@ class SensitivePruner(object):
self._eval_func, self._eval_func,
sensitivities_file=sensitivities_file, sensitivities_file=sensitivities_file,
step_size=0.1) step_size=0.1)
print sensitivities _logger.info(sensitivities)
_, ratios = self.get_ratios_by_sensitive(sensitivities, pruned_flops, _, ratios = self.get_ratios_by_sensitive(sensitivities, pruned_flops,
eval_program) eval_program)
......
...@@ -14,4 +14,4 @@ ...@@ -14,4 +14,4 @@
# limitations under the License. # limitations under the License.
""" PaddleSlim version string """ """ PaddleSlim version string """
__all__ = ["slim_version"] __all__ = ["slim_version"]
slim_version = "0.1" slim_version = "1.0.0"
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册