未验证 提交 87a3ffd8 编写于 作者: W whs 提交者: GitHub

Fix pruning tutorials for 2.0 API (#579)

上级 31f08bdb
...@@ -18,6 +18,7 @@ PaddleSlim依赖Paddle1.7版本,请确认已正确安装Paddle,然后按以 ...@@ -18,6 +18,7 @@ PaddleSlim依赖Paddle1.7版本,请确认已正确安装Paddle,然后按以
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddleslim as slim import paddleslim as slim
paddle.enable_static()
``` ```
## 2. 构建网络 ## 2. 构建网络
...@@ -61,7 +62,7 @@ pruned_program, _, _ = pruner.prune( ...@@ -61,7 +62,7 @@ pruned_program, _, _ = pruner.prune(
### 3.3 计算剪裁之后的FLOPs ### 3.3 计算剪裁之后的FLOPs
``` ```
FLOPs = paddleslim.analysis.flops(pruned_program) FLOPs = slim.analysis.flops(pruned_program)
print("FLOPs: {}".format(FLOPs)) print("FLOPs: {}".format(FLOPs))
``` ```
...@@ -84,6 +85,6 @@ train_feeder = fluid.DataFeeder(inputs, fluid.CPUPlace()) ...@@ -84,6 +85,6 @@ train_feeder = fluid.DataFeeder(inputs, fluid.CPUPlace())
``` ```
for data in train_reader(): for data in train_reader():
acc1, acc5, loss = exe.run(pruned_program, feed=train_feeder.feed(data), fetch_list=outputs) acc1, acc5, loss, _ = exe.run(pruned_program, feed=train_feeder.feed(data), fetch_list=outputs)
print(acc1, acc5, loss) print(acc1, acc5, loss)
``` ```
...@@ -23,6 +23,7 @@ PaddleSlim依赖Paddle1.7版本,请确认已正确安装Paddle,然后按以 ...@@ -23,6 +23,7 @@ PaddleSlim依赖Paddle1.7版本,请确认已正确安装Paddle,然后按以
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddleslim as slim import paddleslim as slim
paddle.enable_static()
``` ```
## 2. 构建网络 ## 2. 构建网络
...@@ -62,7 +63,7 @@ def test(program): ...@@ -62,7 +63,7 @@ def test(program):
acc_top1_ns = [] acc_top1_ns = []
acc_top5_ns = [] acc_top5_ns = []
for data in test_reader(): for data in test_reader():
acc_top1_n, acc_top5_n, _ = exe.run( acc_top1_n, acc_top5_n, _, _ = exe.run(
program, program,
feed=data_feeder.feed(data), feed=data_feeder.feed(data),
fetch_list=outputs) fetch_list=outputs)
...@@ -258,7 +259,7 @@ test(pruned_val_program) ...@@ -258,7 +259,7 @@ test(pruned_val_program)
```python ```python
for data in train_reader(): for data in train_reader():
acc1, acc5, loss = exe.run(pruned_program, feed=data_feeder.feed(data), fetch_list=outputs) acc1, acc5, loss, _ = exe.run(pruned_program, feed=data_feeder.feed(data), fetch_list=outputs)
print(np.mean(acc1), np.mean(acc5), np.mean(loss)) print(np.mean(acc1), np.mean(acc5), np.mean(loss))
``` ```
......
# 图像分类模型通道剪裁-敏感度分析
该教程以图像分类模型MobileNetV1为例,说明如何快速使用[PaddleSlim的敏感度分析接口](https://paddlepaddle.github.io/PaddleSlim/api/prune_api/#sensitivity)
该示例包含以下步骤:
1. 导入依赖
2. 构建模型
3. 定义输入数据
4. 定义模型评估方法
5. 训练模型
6. 获取待分析卷积参数名称
7. 分析敏感度
8. 剪裁模型
以下章节依次次介绍每个步骤的内容。
## 1. 导入依赖
PaddleSlim依赖Paddle1.7版本,请确认已正确安装Paddle,然后按以下方式导入Paddle和PaddleSlim:
```
import paddle
import paddle.fluid as fluid
import paddleslim as slim
```
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册