From e935dd854e27d544ced6e2b2ba5ad33f9977c716 Mon Sep 17 00:00:00 2001
From: whs
Date: Thu, 20 May 2021 13:56:59 +0800
Subject: [PATCH] Fix code in some docs (#763)

* Fix code in docs

* Fix code in docs
---
 .../quick_start/distillation_tutorial_en.md   |  3 +-
 docs/en/quick_start/nas_tutorial_en.md        |  2 ++
 ...cation_sensitivity_analysis_tutorial_en.md | 33 +++++++++--------
 .../dygraph/pruners/fpgm_filter_pruner.rst    | 12 +++++--
 .../dygraph/pruners/l1norm_filter_pruner.rst  | 15 ++++----
 .../dygraph/pruners/l2norm_filter_pruner.rst  | 17 ++++-----
 .../api_cn/static/common/analysis_api.rst     |  7 ++--
 .../static/dist/single_distiller_api.rst      |  4 ++-
 docs/zh_cn/api_cn/static/nas/nas_api.rst      | 14 ++++----
 docs/zh_cn/api_cn/static/prune/prune_api.rst  | 16 ++++++---
 .../static/distillation_tutorial.md           |  2 +-
 .../pruning/dygraph/filter_pruning.md         |  2 +-
 .../dygraph/self_defined_filter_pruning.md    |  4 +--
 ...ification_sensitivity_analysis_tutorial.md | 36 +++++++++----------
 .../quant/static/embedding_quant_tutorial.md  | 32 ++++++++++++++---
 15 files changed, 124 insertions(+), 75 deletions(-)

diff --git a/docs/en/quick_start/distillation_tutorial_en.md b/docs/en/quick_start/distillation_tutorial_en.md
index 8da2dcf2..45d3b2ea 100755
--- a/docs/en/quick_start/distillation_tutorial_en.md
+++ b/docs/en/quick_start/distillation_tutorial_en.md
@@ -13,10 +13,11 @@ by a demo of MobileNetV1 model on MNIST dataset. This tutorial following workflo
 
 PaddleSlim dependents on Paddle1.7. Please ensure that you have installed paddle correctly. Import Paddle and PaddleSlim as below:
 
-```
+```python
 import paddle
 import paddle.fluid as fluid
 import paddleslim as slim
+paddle.enable_static()
 ```
 
 ## 2. Define student_program and teacher_program
diff --git a/docs/en/quick_start/nas_tutorial_en.md b/docs/en/quick_start/nas_tutorial_en.md
index 51065bf9..3d15b92b 100644
--- a/docs/en/quick_start/nas_tutorial_en.md
+++ b/docs/en/quick_start/nas_tutorial_en.md
@@ -30,6 +30,8 @@ import numpy as np
 ```
 
 ## 2. initial SANAS instance
+
+Please set an unused port when building an instance of SANAS.
 ```python
 sanas = slim.nas.SANAS(configs=[('MobileNetV2Space')], server_addr=("", 8337), save_checkpoint=None)
 ```
diff --git a/docs/en/tutorials/image_classification_sensitivity_analysis_tutorial_en.md b/docs/en/tutorials/image_classification_sensitivity_analysis_tutorial_en.md
index 1df6817d..262dd15e 100644
--- a/docs/en/tutorials/image_classification_sensitivity_analysis_tutorial_en.md
+++ b/docs/en/tutorials/image_classification_sensitivity_analysis_tutorial_en.md
@@ -21,6 +21,7 @@ PaddleSlim dependents on Paddle1.7. Please ensure that you have installed paddle
 import paddle
 import paddle.fluid as fluid
 import paddleslim as slim
+paddle.enable_static()
 ```
 
 ## 2. Build model
@@ -62,7 +63,7 @@ def test(program):
     acc_top1_ns = []
     acc_top5_ns = []
     for data in test_reader():
-        acc_top1_n, acc_top5_n, _ = exe.run(
+        acc_top1_n, acc_top5_n, _, _ = exe.run(
             program,
             feed=data_feeder.feed(data),
             fetch_list=outputs)
@@ -82,7 +83,7 @@ Training model as below:
 
 ```python
 for data in train_reader():
-    acc1, acc5, loss = exe.run(train_program, feed=data_feeder.feed(data), fetch_list=outputs)
+    acc1, acc5, loss, _ = exe.run(train_program, feed=data_feeder.feed(data), fetch_list=outputs)
     print(np.mean(acc1), np.mean(acc5), np.mean(loss))
 ```
 
@@ -206,19 +207,22 @@ ratios = slim.prune.get_ratios_by_loss(s_0, loss)
 print(ratios)
 ```
 
-### 8.2 Pruning training network
+### 8.2 Pruning test network
+
+Note: The `only_graph` option should be set to True when pruning the test network. See the [Pruner API](https://paddlepaddle.github.io/PaddleSlim/api/prune_api/#pruner) for details.
 
 
 ```python
 pruner = slim.prune.Pruner()
-print("FLOPs before pruning: {}".format(slim.analysis.flops(train_program)))
-pruned_program, _, _ = pruner.prune(
-    train_program,
+print("FLOPs before pruning: {}".format(slim.analysis.flops(val_program)))
+pruned_val_program, _, _ = pruner.prune(
+    val_program,
     fluid.global_scope(),
     params=ratios.keys(),
     ratios=ratios.values(),
-    place=place)
-print("FLOPs after pruning: {}".format(slim.analysis.flops(pruned_program)))
+    place=place,
+    only_graph=True)
+print("FLOPs after pruning: {}".format(slim.analysis.flops(pruned_val_program)))
 ```
 
 ### 8.3 Pruning test network
@@ -228,15 +232,14 @@ Note:The `only_graph` should be set to True while pruning test network. [Prune
 
 ```python
 pruner = slim.prune.Pruner()
-print("FLOPs before pruning: {}".format(slim.analysis.flops(val_program)))
-pruned_val_program, _, _ = pruner.prune(
-    val_program,
+print("FLOPs before pruning: {}".format(slim.analysis.flops(train_program)))
+pruned_program, _, _ = pruner.prune(
+    train_program,
     fluid.global_scope(),
     params=ratios.keys(),
     ratios=ratios.values(),
-    place=place,
-    only_graph=True)
-print("FLOPs after pruning: {}".format(slim.analysis.flops(pruned_val_program)))
+    place=place)
+print("FLOPs after pruning: {}".format(slim.analysis.flops(pruned_program)))
 ```
 
 Get accuracy of pruned model on test dataset:
@@ -252,7 +255,7 @@ Training pruned model:
 
 ```python
 for data in train_reader():
-    acc1, acc5, loss = exe.run(pruned_program, feed=data_feeder.feed(data), fetch_list=outputs)
+    acc1, acc5, loss, _ = exe.run(pruned_program, feed=data_feeder.feed(data), fetch_list=outputs)
     print(np.mean(acc1), np.mean(acc5), np.mean(loss))
 ```
 
diff --git a/docs/zh_cn/api_cn/dygraph/pruners/fpgm_filter_pruner.rst b/docs/zh_cn/api_cn/dygraph/pruners/fpgm_filter_pruner.rst
index 083d2f47..4759b16e 100644
--- a/docs/zh_cn/api_cn/dygraph/pruners/fpgm_filter_pruner.rst
+++ b/docs/zh_cn/api_cn/dygraph/pruners/fpgm_filter_pruner.rst
@@ -21,8 +21,10 @@ FPGMFilterPruner
 
 .. code-block:: python
 
+            from paddle.vision.models import mobilenet_v1
             from paddleslim import FPGMFilterPruner
-            pruner = FPGMFilterPruner()
+            net = mobilenet_v1(pretrained=False)
+            pruner = FPGMFilterPruner(net, [1, 3, 224, 224])
 
 ..
   .. py:method:: prune_var(var_name, pruned_dims, pruned_ratio, apply="impretive")
@@ -49,11 +51,12 @@ FPGMFilterPruner
 
   .. code-block:: python
 
+            import paddle
             from paddle.vision.models import mobilenet_v1
             from paddleslim import FPGMFilterPruner
             net = mobilenet_v1(pretrained=False)
             pruner = FPGMFilterPruner(net, [1, 3, 224, 224])
-            plan = pruner.prun_var("conv2d_26.w_0", [0])
+            plan = pruner.prune_var("conv2d_26.w_0", [0], pruned_ratio=0.5)
             print(f"plan: {plan}")
             paddle.summary(net, (1, 3, 224, 224))
 
@@ -81,11 +84,12 @@ FPGMFilterPruner
 
  .. code-block:: python
 
+            import paddle
             from paddle.vision.models import mobilenet_v1
             from paddleslim import FPGMFilterPruner
             net = mobilenet_v1(pretrained=False)
             pruner = FPGMFilterPruner(net, [1, 3, 224, 224])
-            plan = pruner.prun_vars({"conv2d_26.w_0": 0.5}, [0])
+            plan = pruner.prune_vars({"conv2d_26.w_0": 0.5}, [0])
             print(f"plan: {plan}")
             paddle.summary(net, (1, 3, 224, 224))
 
@@ -129,6 +133,7 @@ FPGMFilterPruner
 
  .. code-block:: python
 
+            import paddle
             from paddle.vision.models import mobilenet_v1
             from paddleslim import FPGMFilterPruner
             import paddle.vision.transforms as T
@@ -189,6 +194,7 @@ FPGMFilterPruner
 
  .. code-block:: python
 
+            import paddle
             from paddle.vision.models import mobilenet_v1
             from paddleslim import FPGMFilterPruner
             import paddle.vision.transforms as T
diff --git a/docs/zh_cn/api_cn/dygraph/pruners/l1norm_filter_pruner.rst b/docs/zh_cn/api_cn/dygraph/pruners/l1norm_filter_pruner.rst
index cf3a05fc..e848fd1d 100644
--- a/docs/zh_cn/api_cn/dygraph/pruners/l1norm_filter_pruner.rst
+++ b/docs/zh_cn/api_cn/dygraph/pruners/l1norm_filter_pruner.rst
@@ -20,9 +20,10 @@ L1NormFilterPruner
 **示例代码:**
 
 .. code-block:: python
-
-            from paddleslim import L1NormFilterPruner
-            pruner = L1NormFilterPruner()
+            from paddle.vision.models import mobilenet_v1
+            from paddleslim import L1NormFilterPruner
+            net = mobilenet_v1(pretrained=False)
+            pruner = L1NormFilterPruner(net, [1, 3, 224, 224])
 
 ..
  .. py:method:: prune_var(var_name, pruned_dims, pruned_ratio, apply="impretive")
@@ -48,7 +49,7 @@ L1NormFilterPruner
 点击 `AIStudio <>`_ 执行以下示例代码。
 
 .. code-block:: python
-
+            import paddle
             from paddle.vision.models import mobilenet_v1
             from paddleslim import L1NormFilterPruner
             net = mobilenet_v1(pretrained=False)
@@ -80,7 +81,7 @@ L1NormFilterPruner
 点击 `AIStudio <>`_ 执行以下示例代码。
 
 .. code-block:: python
-
+            import paddle
             from paddle.vision.models import mobilenet_v1
             from paddleslim import L1NormFilterPruner
             net = mobilenet_v1(pretrained=False)
@@ -128,7 +129,7 @@ L1NormFilterPruner
 点击 `AIStudio <>`_ 执行以下示例代码。
 
 .. code-block:: python
-
+            import paddle
             from paddle.vision.models import mobilenet_v1
             from paddleslim import L1NormFilterPruner
             import paddle.vision.transforms as T
@@ -188,7 +189,7 @@ L1NormFilterPruner
 点击 `AIStudio <>`_ 执行以下示例代码。
 
 .. code-block:: python
-
+            import paddle
             from paddle.vision.models import mobilenet_v1
             from paddleslim import L1NormFilterPruner
             import paddle.vision.transforms as T
diff --git a/docs/zh_cn/api_cn/dygraph/pruners/l2norm_filter_pruner.rst b/docs/zh_cn/api_cn/dygraph/pruners/l2norm_filter_pruner.rst
index 0400fdd3..d5527a40 100644
--- a/docs/zh_cn/api_cn/dygraph/pruners/l2norm_filter_pruner.rst
+++ b/docs/zh_cn/api_cn/dygraph/pruners/l2norm_filter_pruner.rst
@@ -20,9 +20,10 @@ L2NormFilterPruner
 **示例代码:**
 
 .. code-block:: python
-
+            from paddle.vision.models import mobilenet_v1
             from paddleslim import L2NormFilterPruner
-            pruner = L2NormFilterPruner()
+            net = mobilenet_v1(pretrained=False)
+            pruner = L2NormFilterPruner(net, [1, 3, 224, 224])
 
 ..
  .. py:method:: prune_var(var_name, pruned_dims, pruned_ratio, apply="impretive")
@@ -48,12 +49,12 @@ L2NormFilterPruner
 
 点击 `AIStudio <>`_ 执行以下示例代码。
 
 .. code-block:: python
-
+            import paddle
             from paddle.vision.models import mobilenet_v1
             from paddleslim import L2NormFilterPruner
             net = mobilenet_v1(pretrained=False)
             pruner = L2NormFilterPruner(net, [1, 3, 224, 224])
-            plan = pruner.prun_var("conv2d_26.w_0", [0])
+            plan = pruner.prune_var("conv2d_26.w_0", [0], pruned_ratio=0.5)
             print(f"plan: {plan}")
             paddle.summary(net, (1, 3, 224, 224))
@@ -80,12 +81,12 @@ L2NormFilterPruner
 
 点击 `AIStudio <>`_ 执行以下示例代码。
 
 .. code-block:: python
-
+            import paddle
             from paddle.vision.models import mobilenet_v1
             from paddleslim import L2NormFilterPruner
             net = mobilenet_v1(pretrained=False)
             pruner = L2NormFilterPruner(net, [1, 3, 224, 224])
-            plan = pruner.prun_vars({"conv2d_26.w_0": 0.5}, [0])
+            plan = pruner.prune_vars({"conv2d_26.w_0": 0.5}, [0])
             print(f"plan: {plan}")
             paddle.summary(net, (1, 3, 224, 224))
@@ -128,7 +129,7 @@ L2NormFilterPruner
 
 点击 `AIStudio <>`_ 执行以下示例代码。
 
 .. code-block:: python
-
+            import paddle
             from paddle.vision.models import mobilenet_v1
             from paddleslim import L2NormFilterPruner
             import paddle.vision.transforms as T
@@ -188,7 +189,7 @@ L2NormFilterPruner
 
 点击 `AIStudio <>`_ 执行以下示例代码。
 
 .. code-block:: python
-
+            import paddle
             from paddle.vision.models import mobilenet_v1
             from paddleslim import L2NormFilterPruner
             import paddle.vision.transforms as T
diff --git a/docs/zh_cn/api_cn/static/common/analysis_api.rst b/docs/zh_cn/api_cn/static/common/analysis_api.rst
index 0f0bebf4..6cd92f99 100644
--- a/docs/zh_cn/api_cn/static/common/analysis_api.rst
+++ b/docs/zh_cn/api_cn/static/common/analysis_api.rst
@@ -27,11 +27,11 @@ FLOPs
 **示例:**
 
 .. code-block:: python
-
+  import paddle
   import paddle.fluid as fluid
   from paddle.fluid.param_attr import ParamAttr
   from paddleslim.analysis import flops
-
+  paddle.enable_static()
   def conv_bn_layer(input,
                     num_filters,
                     filter_size,
@@ -103,10 +103,11 @@ model_size
 
 .. code-block:: python
 
+  import paddle
   import paddle.fluid as fluid
   from paddle.fluid.param_attr import ParamAttr
   from paddleslim.analysis import model_size
-
+  paddle.enable_static()
   def conv_layer(input,
                  num_filters,
                  filter_size,
diff --git a/docs/zh_cn/api_cn/static/dist/single_distiller_api.rst b/docs/zh_cn/api_cn/static/dist/single_distiller_api.rst
index 05d7c108..80489556 100644
--- a/docs/zh_cn/api_cn/static/dist/single_distiller_api.rst
+++ b/docs/zh_cn/api_cn/static/dist/single_distiller_api.rst
@@ -26,6 +26,7 @@ merge
 **使用示例:**
 
 .. code-block:: python
+
    import paddle
    import paddle.fluid as fluid
    import paddleslim.dist as dist
@@ -121,7 +122,8 @@ l2_loss
 **使用示例:**
 
 .. code-block:: python
-   import paddle
+
+   import paddle
    import paddle.fluid as fluid
    import paddleslim.dist as dist
    paddle.enable_static()
diff --git a/docs/zh_cn/api_cn/static/nas/nas_api.rst b/docs/zh_cn/api_cn/static/nas/nas_api.rst
index 39f5a300..e3f95ee6 100644
--- a/docs/zh_cn/api_cn/static/nas/nas_api.rst
+++ b/docs/zh_cn/api_cn/static/nas/nas_api.rst
@@ -49,7 +49,7 @@ SANAS(Simulated Annealing Neural Architecture Search)是基于模拟退火
   from paddleslim.nas import SANAS
   config = [('MobileNetV2Space')]
   paddle.enable_static()
-  sanas = SANAS(configs=config, server_addr=("", 8881))
+  sanas = SANAS(configs=config, server_addr=("", 8821))
 
 .. note::
 
@@ -88,7 +88,7 @@ SANAS(Simulated Annealing Neural Architecture Search)是基于模拟退火
   from paddleslim.nas import SANAS
   config = [('MobileNetV2Space')]
   paddle.enable_static()
-  sanas = SANAS(configs=config, server_addr=("", 8882))
+  sanas = SANAS(configs=config, server_addr=("", 8822))
   input = paddle.static.data(name='input', shape=[None, 3, 32, 32], dtype='float32')
   archs = sanas.next_archs()
   for arch in archs:
@@ -142,7 +142,7 @@ SANAS(Simulated Annealing Neural Architecture Search)是基于模拟退火
   from paddleslim.nas import SANAS
   config = [('MobileNetV2Space')]
   paddle.enable_static()
-  sanas = SANAS(configs=config, server_addr=("", 8884))
+  sanas = SANAS(configs=config, server_addr=("", 8823))
   input = paddle.static.data(name='input', shape=[None, 3, 32, 32], dtype='float32')
   tokens = ([0] * 25)
   archs = sanas.tokens2arch(tokens)[0]
@@ -233,7 +233,7 @@ RLNAS (Reinforcement Learning Neural Architecture Search)是基于强化学习
 
   config = [('MobileNetV2Space')]
   paddle.enable_static()
-  rlnas = RLNAS(key='lstm', configs=config, server_addr=("", 8886))
+  rlnas = RLNAS(key='lstm', configs=config, server_addr=("", 8824))
 
 .. py:method:: next_archs(obs=None)
@@ -255,7 +255,7 @@ RLNAS (Reinforcement Learning Neural Architecture Search)是基于强化学习
   from paddleslim.nas import RLNAS
   config = [('MobileNetV2Space')]
   paddle.enable_static()
-  rlnas = RLNAS(key='lstm', configs=config, server_addr=("", 8887))
+  rlnas = RLNAS(key='lstm', configs=config, server_addr=("", 8825))
   input = paddle.static.data(name='input', shape=[None, 3, 32, 32], dtype='float32')
   archs = rlnas.next_archs(1)[0]
   for arch in archs:
@@ -307,7 +307,7 @@ RLNAS (Reinforcement Learning Neural Architecture Search)是基于强化学习
   from paddleslim.nas import RLNAS
   config = [('MobileNetV2Space')]
   paddle.enable_static()
-  rlnas = RLNAS(key='lstm', configs=config, server_addr=("", 8889))
+  rlnas = RLNAS(key='lstm', configs=config, server_addr=("", 8826))
   archs = rlnas.final_archs(1)
   print(archs)
@@ -330,7 +330,7 @@ RLNAS (Reinforcement Learning Neural Architecture Search)是基于强化学习
   from paddleslim.nas import RLNAS
   config = [('MobileNetV2Space')]
   paddle.enable_static()
-  rlnas = RLNAS(key='lstm', configs=config, server_addr=("", 8891))
+  rlnas = RLNAS(key='lstm', configs=config, server_addr=("", 8827))
   input = paddle.static.data(name='input', shape=[None, 3, 32, 32], dtype='float32')
   tokens = ([0] * 25)
   archs = rlnas.tokens2arch(tokens)[0]
diff --git a/docs/zh_cn/api_cn/static/prune/prune_api.rst b/docs/zh_cn/api_cn/static/prune/prune_api.rst
index e06d16c5..037e482b 100644
--- a/docs/zh_cn/api_cn/static/prune/prune_api.rst
+++ b/docs/zh_cn/api_cn/static/prune/prune_api.rst
@@ -38,7 +38,15 @@ Pruner
 - **params(list)** - 需要被裁剪的卷积层的参数的名称列表。可以通过以下方式查看模型中所有参数的名称:
 
 .. code-block:: python
-
+
+  import paddle
+  paddle.enable_static()
+  program = paddle.static.Program()
+  with paddle.static.program_guard(main_program=program):
+      net = paddle.vision.models.mobilenet_v1()
+      data = paddle.static.data(name="data", shape=[1,3,32,32])
+      net(data)
+
   for block in program.blocks:
       for param in block.all_parameters():
           print("param: {}; shape: {}".format(param.name, param.shape))
@@ -68,11 +76,11 @@ Pruner
 
 点击 `AIStudio `_ 执行以下示例代码。
 
 .. code-block:: python
-
+  import paddle
   import paddle.fluid as fluid
   from paddle.fluid.param_attr import ParamAttr
   from paddleslim.prune import Pruner
-
+  paddle.enable_static()
   def conv_bn_layer(input,
                     num_filters,
                     filter_size,
@@ -209,7 +217,7 @@ sensitivity
   from paddle.fluid.param_attr import ParamAttr
   from paddleslim.prune import sensitivity
   import paddle.dataset.mnist as reader
-
+  paddle.enable_static()
   def conv_bn_layer(input,
                     num_filters,
                     filter_size,
diff --git a/docs/zh_cn/quick_start/static/distillation_tutorial.md b/docs/zh_cn/quick_start/static/distillation_tutorial.md
index 140a4695..b895cfe9 100755
--- a/docs/zh_cn/quick_start/static/distillation_tutorial.md
+++ b/docs/zh_cn/quick_start/static/distillation_tutorial.md
@@ -15,7 +15,7 @@
 
 PaddleSlim依赖Paddle2.0版本,请确认已正确安装Paddle,然后按以下方式导入Paddle和PaddleSlim:
 
-```
+```python
 import paddle
 import numpy as np
 import paddleslim as slim
diff --git a/docs/zh_cn/tutorials/pruning/dygraph/filter_pruning.md b/docs/zh_cn/tutorials/pruning/dygraph/filter_pruning.md
index 146424eb..624d68d2 100644
--- a/docs/zh_cn/tutorials/pruning/dygraph/filter_pruning.md
+++ b/docs/zh_cn/tutorials/pruning/dygraph/filter_pruning.md
@@ -7,6 +7,7 @@
 PaddlePaddle提供的`vision`模块提供了一些构建好的分类模型结构,并提供在`ImageNet`数据集上的预训练模型。为了简化教程,我们不再重新定义网络结构,而是直接从`vision`模块导入模型结构。代码如下所示,我们导入`MobileNetV1`模型,并查看模型的结构信息。
 
 ```python
+from __future__ import print_function
 import paddle
 from paddle.vision.models import mobilenet_v1
 net = mobilenet_v1(pretrained=False)
@@ -30,7 +31,6 @@ val_dataset = paddle.vision.datasets.Cifar10(mode="test", backend="cv2",transfor
 我们可以通过以下代码查看训练集和测试集的样本数量,并尝试取出训练集中的第一个样本,观察其图片的`shape`和对应的`label`。
 
 ```python
-from __future__ import print_function
 print(f'train samples count: {len(train_dataset)}')
 print(f'val samples count: {len(val_dataset)}')
 for data in train_dataset:
diff --git a/docs/zh_cn/tutorials/pruning/dygraph/self_defined_filter_pruning.md b/docs/zh_cn/tutorials/pruning/dygraph/self_defined_filter_pruning.md
index 6ae61005..d52853e9 100644
--- a/docs/zh_cn/tutorials/pruning/dygraph/self_defined_filter_pruning.md
+++ b/docs/zh_cn/tutorials/pruning/dygraph/self_defined_filter_pruning.md
@@ -22,7 +22,7 @@ def cal_mask(self, var_name, pruned_ratio, group):
 
 如图1-1所示,在给定模型中有两个卷积层,第一个卷积层有3个`filters`,第二个卷积层有2个`filters`。如果删除第一个卷积绿色的`filter`,第一个卷积的输出特征图的通道数也会减1,同时需要删掉第二个卷积层绿色的`kernels`。如上所述的两个卷积共同组成一个group,表示如下:
 
-```python
+```
 group = {
    "conv_1.weight":{
       "pruned_dims": [0],
@@ -94,7 +94,7 @@ class L2NormFilterPruner(FilterPruner):
 
 如上述代码所示,我们重载了`FilterPruner`基类的`cal_mask`方法,并在`L1NormFilterPruner`代码基础上,修改了计算通道重要性的语句,将其修改为了计算L2Norm的逻辑:
 
-```python
+```
 scores = np.sqrt(np.sum(np.square(value), axis=tuple(reduce_dims)))
 ```
diff --git a/docs/zh_cn/tutorials/pruning/static/image_classification_sensitivity_analysis_tutorial.md b/docs/zh_cn/tutorials/pruning/static/image_classification_sensitivity_analysis_tutorial.md
index 6703f903..1daf8743 100644
--- a/docs/zh_cn/tutorials/pruning/static/image_classification_sensitivity_analysis_tutorial.md
+++ b/docs/zh_cn/tutorials/pruning/static/image_classification_sensitivity_analysis_tutorial.md
@@ -83,7 +83,7 @@ def test(program):
 
 ```python
 for data in train_reader():
-    acc1, acc5, loss = exe.run(train_program, feed=data_feeder.feed(data), fetch_list=outputs)
+    acc1, acc5, loss, _ = exe.run(train_program, feed=data_feeder.feed(data), fetch_list=outputs)
     print(np.mean(acc1), np.mean(acc5), np.mean(loss))
 ```
 
@@ -213,24 +213,9 @@ ratios = slim.prune.get_ratios_by_loss(s_0, loss)
 print(ratios)
 ```
 
-### 8.2 剪裁训练网络
+### 8.2 剪裁测试网络
 
-
-```python
-pruner = slim.prune.Pruner()
-print("FLOPs before pruning: {}".format(slim.analysis.flops(train_program)))
-pruned_program, _, _ = pruner.prune(
-    train_program,
-    fluid.global_scope(),
-    params=ratios.keys(),
-    ratios=ratios.values(),
-    place=place)
-print("FLOPs after pruning: {}".format(slim.analysis.flops(pruned_program)))
-```
-
-### 8.3 剪裁测试网络
-
->注意:对测试网络进行剪裁时,需要将`only_graph`设置为True,具体原因请参考[Pruner API文档](https://paddleslim.readthedocs.io/zh_CN/latest/api_cn/static/prune/prune_api.html#paddleslim.prune.Pruner)
+>注意:剪裁测试网络要在剪裁训练网络之前进行。对测试网络进行剪裁时,需要将`only_graph`设置为True,具体原因请参考[Pruner API文档](https://paddleslim.readthedocs.io/zh_CN/latest/api_cn/static/prune/prune_api.html#paddleslim.prune.Pruner)
 
 
 ```python
@@ -252,6 +237,21 @@ print("FLOPs after pruning: {}".format(slim.analysis.flops(pruned_val_program)))
 test(pruned_val_program)
 ```
 
+### 8.3 剪裁训练网络
+
+```python
+pruner = slim.prune.Pruner()
+print("FLOPs before pruning: {}".format(slim.analysis.flops(train_program)))
+pruned_program, _, _ = pruner.prune(
+    train_program,
+    fluid.global_scope(),
+    params=ratios.keys(),
+    ratios=ratios.values(),
+    place=place)
+print("FLOPs after pruning: {}".format(slim.analysis.flops(pruned_program)))
+```
+
+
 ### 8.4 训练剪裁后的模型
 
 对剪裁后的模型在训练集上训练一个`epoch`:
diff --git a/docs/zh_cn/tutorials/quant/static/embedding_quant_tutorial.md b/docs/zh_cn/tutorials/quant/static/embedding_quant_tutorial.md
index 03731e09..d6c45f82 100755
--- a/docs/zh_cn/tutorials/quant/static/embedding_quant_tutorial.md
+++ b/docs/zh_cn/tutorials/quant/static/embedding_quant_tutorial.md
@@ -8,10 +8,34 @@ Embedding量化仅能减少模型参数的体积,并不能显著提升模型
 在预测时调用paddleslim `quant_embedding`接口,主要实现代码如下:
 
 ```python
-import paddleslim
-place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
-exe = paddle.static.Executor(place)
-main_program = paddleslim.quant.quant_embedding(main_program, place, config)
+import paddle
+import paddle.fluid as fluid
+import paddleslim.quant as quant
+paddle.enable_static()
+train_program = fluid.Program()
+with fluid.program_guard(train_program):
+    input_word = fluid.data(name="input_word", shape=[None, 1], dtype='int64')
+    input_emb = fluid.embedding(
+        input=input_word,
+        is_sparse=False,
+        size=[100, 128],
+        param_attr=fluid.ParamAttr(name='emb',
+                                   initializer=fluid.initializer.Uniform(-0.005, 0.005)))
+
+infer_program = train_program.clone(for_test=True)
+
+use_gpu = True
+place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
+exe = fluid.Executor(place)
+exe.run(fluid.default_startup_program())
+
+config = {
+    'quantize_op_types': ['lookup_table'],
+    'lookup_table': {
+        'quantize_type': 'abs_max'
+    }
+}
+quant_program = quant.quant_embedding(infer_program, place, config)
 ```
 
 详细代码与例程请参考:[Embedding量化](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/quant/quant_embedding)
--
GitLab