未验证 提交 fceedb12 编写于 作者: W whs 提交者: GitHub

[cherry-pick/2.1.0] fix some bugs of pruning (#832)

上级 cf6f83d4
FPGMFilterPruner FPGMFilterPruner
================== ==================
.. py:class:: paddleslim.FPGMFilterPruner(model, inputs, sen_file=None) .. py:class:: paddleslim.FPGMFilterPruner(model, inputs, sen_file=None, opt=None)
`源代码 <https://github.com/PaddlePaddle/PaddleSlim/blob/release/2.0.0/paddleslim/dygraph/prune/fpgm_pruner.py>`_ `源代码 <https://github.com/PaddlePaddle/PaddleSlim/blob/release/2.0.0/paddleslim/dygraph/prune/fpgm_pruner.py>`_
用于剪裁卷积层输出通道的的剪裁器。该剪裁器按论文 `Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration <https://arxiv.org/abs/1811.00250>_` 中的统计方法对单个卷积层内的 ``Filters`` 的重要性进行排序,并按指定比例剪裁掉相对不重要的 ``Filters`` 。对 ``Filters`` 的剪裁等价于剪裁卷积层的输出通道数。 用于剪裁卷积层输出通道的的剪裁器。该剪裁器按论文 `Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration <https://arxiv.org/abs/1811.00250>`_ 中的统计方法对单个卷积层内的 ``Filters`` 的重要性进行排序,并按指定比例剪裁掉相对不重要的 ``Filters`` 。对 ``Filters`` 的剪裁等价于剪裁卷积层的输出通道数。
**参数:** **参数:**
...@@ -15,15 +15,40 @@ FPGMFilterPruner ...@@ -15,15 +15,40 @@ FPGMFilterPruner
- **sen_file(str)** - 存储敏感度信息的文件,需要指定为绝对路径。在调用当前剪裁器的 ``sensitive`` 方法时,敏感度信息会以增量的形式追加到文件 ``sen_file`` 中。如果用户不需要敏感度剪裁策略,可以将该选项设置为 ``None`` 。默认为None - **sen_file(str)** - 存储敏感度信息的文件,需要指定为绝对路径。在调用当前剪裁器的 ``sensitive`` 方法时,敏感度信息会以增量的形式追加到文件 ``sen_file`` 中。如果用户不需要敏感度剪裁策略,可以将该选项设置为 ``None`` 。默认为None
- **opt(paddle.optimizer.Optimizer)** - 动态图模型训练时用到的优化器。传入该参数是为了解决上述 ``model(paddle.nn.Layer)`` 不含有优化器,导致不能剪裁到优化器参数(例如 ``Momentum`` 中的 ``velocity`` )的问题。是否传入 ``optimizer`` 参数的逻辑为:若已经初始化了 ``optimizer`` 对象,则传入;否则,在调用 ``pruner.prune_vars()`` 之后初始化 ``optimize`` 。默认为None
**返回:** 一个剪裁器实例。 **返回:** 一个剪裁器实例。
**示例代码:** **示例代码1**
.. code-block:: python
import paddle
from paddle.vision.models import mobilenet_v1
from paddleslim import FPGMFilterPruner
net = mobilenet_v1(pretrained=False)
pruner = FPGMFilterPruner(net, [1, 3, 224, 224])
pruner.prune_var("conv2d_26.w_0", [0], pruned_ratio=0.5)
optimizer = paddle.optimizer.Momentum(
learning_rate=0.1,
parameters=net.parameters())
..
**示例代码2**
.. code-block:: python .. code-block:: python
import paddle
from paddle.vision.models import mobilenet_v1
from paddleslim import FPGMFilterPruner from paddleslim import FPGMFilterPruner
pruner = FPGMFilterPruner() net = mobilenet_v1(pretrained=False)
optimizer = paddle.optimizer.Momentum(
learning_rate=0.1,
parameters=net.parameters())
pruner = FPGMFilterPruner(net, [1, 3, 224, 224], opt=optimizer)
.. ..
**注意:** 上述两段代码展示了如何在 ``pruner`` 中是否调用 ``optimizer`` ,在示例代码1中,初始化 ``optimizer`` 时传入的 ``parameters`` 为剪裁后的 ``net.parameters()`` ,故无需在初始化 ``pruner`` 时传入 ``optimizer`` ;反之在示例代码2中, ``optimizer`` 中的 ``parameter`` 为剪裁前,故需要传入给 ``pruner`` 一并剪裁 ``optimizer`` 中的相关参数。
.. py:method:: prune_var(var_name, pruned_dims, pruned_ratio, apply="impretive") .. py:method:: prune_var(var_name, pruned_dims, pruned_ratio, apply="impretive")
...@@ -120,6 +145,7 @@ FPGMFilterPruner ...@@ -120,6 +145,7 @@ FPGMFilterPruner
0.2: 0.4 0.2: 0.4
} }
} }
..
其中,``weight_0`` 是卷积层权重变量的名称, ``sensitivities['weight_0']`` 是一个字典, key是用 ``float`` 类型数值表示的剪裁率,value是对应剪裁率下整个模型的精度损失比例。 其中,``weight_0`` 是卷积层权重变量的名称, ``sensitivities['weight_0']`` 是一个字典, key是用 ``float`` 类型数值表示的剪裁率,value是对应剪裁率下整个模型的精度损失比例。
...@@ -164,7 +190,7 @@ FPGMFilterPruner ...@@ -164,7 +190,7 @@ FPGMFilterPruner
pruner = FPGMFilterPruner(net, [1, 3, 224, 224]) pruner = FPGMFilterPruner(net, [1, 3, 224, 224])
sen = pruner.sensitive(eval_func=eval_fn, sen_file="./sen.pickle") sen = pruner.sensitive(eval_func=eval_fn, sen_file="./sen.pickle")
print(f"sen: {sen}") print(f"sen: {sen}")
..
.. py:method:: sensitive_prune(pruned_flops, skip_vars=[], align=None) .. py:method:: sensitive_prune(pruned_flops, skip_vars=[], align=None)
...@@ -225,6 +251,6 @@ FPGMFilterPruner ...@@ -225,6 +251,6 @@ FPGMFilterPruner
sen = pruner.sensitive(eval_func=eval_fn, sen_file="./sen.pickle") sen = pruner.sensitive(eval_func=eval_fn, sen_file="./sen.pickle")
plan = pruner.sensitive_prune(0.5, align=8) plan = pruner.sensitive_prune(0.5, align=8)
print(f"plan: {plan}") print(f"plan: {plan}")
..
L1NormFilterPruner L1NormFilterPruner
================== ==================
.. py:class:: paddleslim.L1NormFilterPruner(model, inputs, sen_file=None) .. py:class:: paddleslim.L1NormFilterPruner(model, inputs, sen_file=None, opt=None)
`源代码 <https://github.com/PaddlePaddle/PaddleSlim/blob/release/2.0.0/paddleslim/dygraph/prune/l1norm_pruner.py#L14>`_ `源代码 <https://github.com/PaddlePaddle/PaddleSlim/blob/release/2.0.0/paddleslim/dygraph/prune/l1norm_pruner.py#L14>`_
...@@ -15,15 +15,40 @@ L1NormFilterPruner ...@@ -15,15 +15,40 @@ L1NormFilterPruner
- **sen_file(str)** - 存储敏感度信息的文件,需要指定为绝对路径。在调用当前剪裁器的 ``sensitive`` 方法时,敏感度信息会以增量的形式追加到文件 ``sen_file`` 中。如果用户不需要敏感度剪裁策略,可以将该选项设置为 ``None`` 。默认为None - **sen_file(str)** - 存储敏感度信息的文件,需要指定为绝对路径。在调用当前剪裁器的 ``sensitive`` 方法时,敏感度信息会以增量的形式追加到文件 ``sen_file`` 中。如果用户不需要敏感度剪裁策略,可以将该选项设置为 ``None`` 。默认为None
- **opt(paddle.optimizer.Optimizer)** - 动态图模型训练时用到的优化器。传入该参数是为了解决上述 ``model(paddle.nn.Layer)`` 不含有优化器,导致不能剪裁到优化器参数(例如 ``Momentum`` 中的 ``velocity`` )的问题。是否传入 ``optimizer`` 参数的逻辑为:若已经初始化了 ``optimizer`` 对象,则传入;否则,在调用了 ``pruner.prune_vars()`` 之后初始化 ``optimizer`` 。默认为None
**返回:** 一个剪裁器实例。 **返回:** 一个剪裁器实例。
**示例代码:** **示例代码1**
.. code-block:: python
import paddle
from paddle.vision.models import mobilenet_v1
from paddleslim import L1NormFilterPruner
net = mobilenet_v1(pretrained=False)
pruner = L1NormFilterPruner(net, [1, 3, 224, 224])
pruner.prune_var("conv2d_26.w_0", [0], pruned_ratio=0.5)
optimizer = paddle.optimizer.Momentum(
learning_rate=0.1,
parameters=net.parameters())
..
**示例代码2**
.. code-block:: python .. code-block:: python
import paddle
from paddle.vision.models import mobilenet_v1
from paddleslim import L1NormFilterPruner from paddleslim import L1NormFilterPruner
pruner = L1NormFilterPruner() net = mobilenet_v1(pretrained=False)
optimizer = paddle.optimizer.Momentum(
learning_rate=0.1,
parameters=net.parameters())
pruner = L1NormFilterPruner(net, [1, 3, 224, 224], opt=optimizer)
.. ..
**注意:** 上述两段代码展示了如何在 ``pruner`` 中是否调用 ``optimizer`` ,在示例代码1中,初始化 ``optimizer`` 时传入的 ``parameters`` 为剪裁后的 ``net.parameters()`` ,故无需在初始化 ``pruner`` 时传入 ``optimizer`` ;反之在示例代码2中, ``optimizer`` 中的 ``parameter`` 为剪裁前,故需要传入给 ``pruner`` 一并剪裁 ``optimizer`` 中的相关参数。
.. py:method:: prune_var(var_name, pruned_dims, pruned_ratio, apply="impretive") .. py:method:: prune_var(var_name, pruned_dims, pruned_ratio, apply="impretive")
...@@ -49,6 +74,7 @@ L1NormFilterPruner ...@@ -49,6 +74,7 @@ L1NormFilterPruner
.. code-block:: python .. code-block:: python
import paddle
from paddle.vision.models import mobilenet_v1 from paddle.vision.models import mobilenet_v1
from paddleslim import L1NormFilterPruner from paddleslim import L1NormFilterPruner
net = mobilenet_v1(pretrained=False) net = mobilenet_v1(pretrained=False)
...@@ -56,8 +82,7 @@ L1NormFilterPruner ...@@ -56,8 +82,7 @@ L1NormFilterPruner
plan = pruner.prun_var("conv2d_26.w_0", [0]) plan = pruner.prun_var("conv2d_26.w_0", [0])
print(f"plan: {plan}") print(f"plan: {plan}")
paddle.summary(net, (1, 3, 224, 224)) paddle.summary(net, (1, 3, 224, 224))
..
..
.. py:method:: prune_vars(ratios, axis, apply="impretive") .. py:method:: prune_vars(ratios, axis, apply="impretive")
...@@ -81,6 +106,7 @@ L1NormFilterPruner ...@@ -81,6 +106,7 @@ L1NormFilterPruner
.. code-block:: python .. code-block:: python
import paddle
from paddle.vision.models import mobilenet_v1 from paddle.vision.models import mobilenet_v1
from paddleslim import L1NormFilterPruner from paddleslim import L1NormFilterPruner
net = mobilenet_v1(pretrained=False) net = mobilenet_v1(pretrained=False)
...@@ -88,7 +114,6 @@ L1NormFilterPruner ...@@ -88,7 +114,6 @@ L1NormFilterPruner
plan = pruner.prun_vars({"conv2d_26.w_0": 0.5}, [0]) plan = pruner.prun_vars({"conv2d_26.w_0": 0.5}, [0])
print(f"plan: {plan}") print(f"plan: {plan}")
paddle.summary(net, (1, 3, 224, 224)) paddle.summary(net, (1, 3, 224, 224))
.. ..
.. py:method:: sensitive(eval_func=None, sen_file=None, target_vars=None, skip_vars=[]) .. py:method:: sensitive(eval_func=None, sen_file=None, target_vars=None, skip_vars=[])
...@@ -120,7 +145,7 @@ L1NormFilterPruner ...@@ -120,7 +145,7 @@ L1NormFilterPruner
0.2: 0.4 0.2: 0.4
} }
} }
..
其中,``weight_0`` 是卷积层权重变量的名称, ``sensitivities['weight_0']`` 是一个字典, key是用 ``float`` 类型数值表示的剪裁率,value是对应剪裁率下整个模型的精度损失比例。 其中,``weight_0`` 是卷积层权重变量的名称, ``sensitivities['weight_0']`` 是一个字典, key是用 ``float`` 类型数值表示的剪裁率,value是对应剪裁率下整个模型的精度损失比例。
**示例:** **示例:**
...@@ -129,6 +154,7 @@ L1NormFilterPruner ...@@ -129,6 +154,7 @@ L1NormFilterPruner
.. code-block:: python .. code-block:: python
import paddle
from paddle.vision.models import mobilenet_v1 from paddle.vision.models import mobilenet_v1
from paddleslim import L1NormFilterPruner from paddleslim import L1NormFilterPruner
import paddle.vision.transforms as T import paddle.vision.transforms as T
...@@ -164,7 +190,7 @@ L1NormFilterPruner ...@@ -164,7 +190,7 @@ L1NormFilterPruner
pruner = L1NormFilterPruner(net, [1, 3, 224, 224]) pruner = L1NormFilterPruner(net, [1, 3, 224, 224])
sen = pruner.sensitive(eval_func=eval_fn, sen_file="./sen.pickle") sen = pruner.sensitive(eval_func=eval_fn, sen_file="./sen.pickle")
print(f"sen: {sen}") print(f"sen: {sen}")
..
.. py:method:: sensitive_prune(pruned_flops, skip_vars=[], align=None) .. py:method:: sensitive_prune(pruned_flops, skip_vars=[], align=None)
...@@ -189,6 +215,7 @@ L1NormFilterPruner ...@@ -189,6 +215,7 @@ L1NormFilterPruner
.. code-block:: python .. code-block:: python
import paddle
from paddle.vision.models import mobilenet_v1 from paddle.vision.models import mobilenet_v1
from paddleslim import L1NormFilterPruner from paddleslim import L1NormFilterPruner
import paddle.vision.transforms as T import paddle.vision.transforms as T
...@@ -225,6 +252,4 @@ L1NormFilterPruner ...@@ -225,6 +252,4 @@ L1NormFilterPruner
sen = pruner.sensitive(eval_func=eval_fn, sen_file="./sen.pickle") sen = pruner.sensitive(eval_func=eval_fn, sen_file="./sen.pickle")
plan = pruner.sensitive_prune(0.5, align=8) plan = pruner.sensitive_prune(0.5, align=8)
print(f"plan: {plan}") print(f"plan: {plan}")
..
L2NormFilterPruner L2NormFilterPruner
================== ==================
.. py:class:: paddleslim.L2NormFilterPruner(model, inputs, sen_file=None) .. py:class:: paddleslim.L2NormFilterPruner(model, inputs, sen_file=None, opt=None)
`源代码 <https://github.com/PaddlePaddle/PaddleSlim/blob/release/2.0.0/paddleslim/dygraph/prune/l2norm_pruner.py>`_ `源代码 <https://github.com/PaddlePaddle/PaddleSlim/blob/release/2.0.0/paddleslim/dygraph/prune/l2norm_pruner.py>`_
...@@ -15,15 +15,41 @@ L2NormFilterPruner ...@@ -15,15 +15,41 @@ L2NormFilterPruner
- **sen_file(str)** - 存储敏感度信息的文件,需要指定为绝对路径。在调用当前剪裁器的 ``sensitive`` 方法时,敏感度信息会以增量的形式追加到文件 ``sen_file`` 中。如果用户不需要敏感度剪裁策略,可以将该选项设置为 ``None`` 。默认为None - **sen_file(str)** - 存储敏感度信息的文件,需要指定为绝对路径。在调用当前剪裁器的 ``sensitive`` 方法时,敏感度信息会以增量的形式追加到文件 ``sen_file`` 中。如果用户不需要敏感度剪裁策略,可以将该选项设置为 ``None`` 。默认为None
- **opt(paddle.optimizer.Optimizer)** - 动态图模型训练时用到的优化器。传入该参数是为了解决上述 ``model(paddle.nn.Layer)`` 不含有优化器,导致不能剪裁到优化器参数(例如 ``Momentum`` 中的 ``velocity`` )的问题。是否传入 ``optimizer`` 参数的逻辑为:若已经初始化了 ``optimizer`` 对象,则传入;否则,在调用pruner.prune_vars()之后初始化 ``optimizer`` 。默认为None
**返回:** 一个剪裁器实例。 **返回:** 一个剪裁器实例。
**示例代码:** **示例代码1**
.. code-block:: python
import paddle
from paddle.vision.models import mobilenet_v1
from paddleslim import L2NormFilterPruner
net = mobilenet_v1(pretrained=False)
pruner = L2NormFilterPruner(net, [1, 3, 224, 224])
pruner.prune_var("conv2d_26.w_0", [0], pruned_ratio=0.5)
optimizer = paddle.optimizer.Momentum(
learning_rate=0.1,
parameters=net.parameters())
..
**示例代码2**
.. code-block:: python .. code-block:: python
import paddle
from paddle.vision.models import mobilenet_v1
from paddleslim import L2NormFilterPruner from paddleslim import L2NormFilterPruner
pruner = L2NormFilterPruner() net = mobilenet_v1(pretrained=False)
optimizer = paddle.optimizer.Momentum(
learning_rate=0.1,
parameters=net.parameters())
pruner = L2NormFilterPruner(net, [1, 3, 224, 224], opt=optimizer)
.. ..
**注意:** 上述两段代码展示了如何在 ``pruner`` 中是否调用 ``optimizer`` ,在示例代码1中,初始化 ``optimizer`` 时传入的 ``parameters`` 为剪裁后的 ``net.parameters()`` ,故无需在初始化 ``pruner`` 时传入 ``optimizer`` ;反之在示例代码2中, ``optimizer`` 中的 ``parameter`` 为剪裁前,故需要传入给 ``pruner`` 一并剪裁 ``optimizer`` 中的相关参数。
.. py:method:: prune_var(var_name, pruned_dims, pruned_ratio, apply="impretive") .. py:method:: prune_var(var_name, pruned_dims, pruned_ratio, apply="impretive")
...@@ -49,6 +75,7 @@ L2NormFilterPruner ...@@ -49,6 +75,7 @@ L2NormFilterPruner
.. code-block:: python .. code-block:: python
import paddle
from paddle.vision.models import mobilenet_v1 from paddle.vision.models import mobilenet_v1
from paddleslim import L2NormFilterPruner from paddleslim import L2NormFilterPruner
net = mobilenet_v1(pretrained=False) net = mobilenet_v1(pretrained=False)
...@@ -81,6 +108,7 @@ L2NormFilterPruner ...@@ -81,6 +108,7 @@ L2NormFilterPruner
.. code-block:: python .. code-block:: python
import paddle
from paddle.vision.models import mobilenet_v1 from paddle.vision.models import mobilenet_v1
from paddleslim import L2NormFilterPruner from paddleslim import L2NormFilterPruner
net = mobilenet_v1(pretrained=False) net = mobilenet_v1(pretrained=False)
...@@ -120,6 +148,8 @@ L2NormFilterPruner ...@@ -120,6 +148,8 @@ L2NormFilterPruner
0.2: 0.4 0.2: 0.4
} }
} }
..
其中,``weight_0`` 是卷积层权重变量的名称, ``sensitivities['weight_0']`` 是一个字典, key是用 ``float`` 类型数值表示的剪裁率,value是对应剪裁率下整个模型的精度损失比例。 其中,``weight_0`` 是卷积层权重变量的名称, ``sensitivities['weight_0']`` 是一个字典, key是用 ``float`` 类型数值表示的剪裁率,value是对应剪裁率下整个模型的精度损失比例。
...@@ -129,6 +159,7 @@ L2NormFilterPruner ...@@ -129,6 +159,7 @@ L2NormFilterPruner
.. code-block:: python .. code-block:: python
import paddle
from paddle.vision.models import mobilenet_v1 from paddle.vision.models import mobilenet_v1
from paddleslim import L2NormFilterPruner from paddleslim import L2NormFilterPruner
import paddle.vision.transforms as T import paddle.vision.transforms as T
...@@ -164,7 +195,7 @@ L2NormFilterPruner ...@@ -164,7 +195,7 @@ L2NormFilterPruner
pruner = L2NormFilterPruner(net, [1, 3, 224, 224]) pruner = L2NormFilterPruner(net, [1, 3, 224, 224])
sen = pruner.sensitive(eval_func=eval_fn, sen_file="./sen.pickle") sen = pruner.sensitive(eval_func=eval_fn, sen_file="./sen.pickle")
print(f"sen: {sen}") print(f"sen: {sen}")
..
.. py:method:: sensitive_prune(pruned_flops, skip_vars=[], align=None) .. py:method:: sensitive_prune(pruned_flops, skip_vars=[], align=None)
...@@ -225,6 +256,4 @@ L2NormFilterPruner ...@@ -225,6 +256,4 @@ L2NormFilterPruner
sen = pruner.sensitive(eval_func=eval_fn, sen_file="./sen.pickle") sen = pruner.sensitive(eval_func=eval_fn, sen_file="./sen.pickle")
plan = pruner.sensitive_prune(0.5, align=8) plan = pruner.sensitive_prune(0.5, align=8)
print(f"plan: {plan}") print(f"plan: {plan}")
..
...@@ -73,12 +73,13 @@ FLOPs = paddle.flops(net, input_size=[1, 3, 32, 32], print_detail=True) ...@@ -73,12 +73,13 @@ FLOPs = paddle.flops(net, input_size=[1, 3, 32, 32], print_detail=True)
对网络模型两个不同的网络层按照参数名分别进行比例为50%,60%的裁剪。 对网络模型两个不同的网络层按照参数名分别进行比例为50%,60%的裁剪。
代码如下所示: 代码如下所示:
``` ```python
pruner = L1NormFilterPruner(net, [1, 3, 32, 32]) pruner = L1NormFilterPruner(net, [1, 3, 32, 32], opt=optimizer)
pruner.prune_vars({'conv2d_22.w_0':0.5, 'conv2d_20.w_0':0.6}, axis=0) pruner.prune_vars({'conv2d_22.w_0':0.5, 'conv2d_20.w_0':0.6}, axis=0)
``` ```
以上操作会按照网络结构中不同网路层的冗余程度对网络层进行不同程度的裁剪并修改网络模型结构。 以上操作会按照网络结构中不同网路层的冗余程度对网络层进行不同程度的裁剪并修改网络模型结构。
**注意:** 需要将`optimizer`传入`pruner`中,这是为了保证`optimizer`中的参数可以被剪裁到。例如:`momentum`中的`velocity`。但是如果在`pruner`后定义`optimizer`,则无需传入了,因为初始化`optimizer`时会指定`parameters=net.parameters()`
### 4.3 计算剪裁之后的FLOPs ### 4.3 计算剪裁之后的FLOPs
...@@ -101,7 +102,7 @@ model.evaluate(val_dataset, batch_size=128, verbose=1) ...@@ -101,7 +102,7 @@ model.evaluate(val_dataset, batch_size=128, verbose=1)
对模型进行finetune会有助于模型恢复原有精度。 对模型进行finetune会有助于模型恢复原有精度。
以下代码对裁剪过后的模型进行评估后执行了一个`epoch`的微调,再对微调过后的模型重新进行评估: 以下代码对裁剪过后的模型进行评估后执行了一个`epoch`的微调,再对微调过后的模型重新进行评估:
``` ```python
model.fit(train_dataset, epochs=1, batch_size=128, verbose=1) model.fit(train_dataset, epochs=1, batch_size=128, verbose=1)
model.evaluate(val_dataset, batch_size=128, verbose=1) model.evaluate(val_dataset, batch_size=128, verbose=1)
``` ```
...@@ -79,13 +79,15 @@ PaddleSlim提供了工具类`Pruner`来进行重要性分析和剪裁操作, ...@@ -79,13 +79,15 @@ PaddleSlim提供了工具类`Pruner`来进行重要性分析和剪裁操作,
```python ```python
from paddleslim.dygraph import L1NormFilterPruner from paddleslim.dygraph import L1NormFilterPruner
pruner = L1NormFilterPruner(net, [1, 3, 224, 224]) pruner = L1NormFilterPruner(net, [1, 3, 224, 224], opt=optimizer)
``` ```
**注意:** 需要将`optimizer`传入`pruner`中,这是为了保证`optimizer`中的参数可以被剪裁到。例如:`momentum`中的`velocity`。但是如果在`pruner`后定义`optimizer`,则无需传入了,因为初始化`optimizer`时会指定`parameters=net.parameters()`
如果本地文件系统已有一个存储敏感度信息(见4.1节)的文件,声明`L1NormFilterPruner`对象时,可以通过指定`sen_file`选项加载计算好的敏感度信息,如下: 如果本地文件系统已有一个存储敏感度信息(见4.1节)的文件,声明`L1NormFilterPruner`对象时,可以通过指定`sen_file`选项加载计算好的敏感度信息,如下:
```python ```python
#pruner = L1NormFilterPruner(net, [1, 3, 224, 224]), sen_file="./sen.pickle") #pruner = L1NormFilterPruner(net, [1, 3, 224, 224]), sen_file="./sen.pickle", opt=optimizer)
``` ```
### 4.1 卷积重要性分析 ### 4.1 卷积重要性分析
...@@ -167,13 +169,6 @@ print(f"before fine-tuning: {result}") ...@@ -167,13 +169,6 @@ print(f"before fine-tuning: {result}")
对剪裁后的模型重新训练, 并再测试集上测试精度,如下: 对剪裁后的模型重新训练, 并再测试集上测试精度,如下:
```python ```python
optimizer = paddle.optimizer.Momentum(
learning_rate=0.1,
parameters=net.parameters())
model.prepare(
optimizer,
paddle.nn.CrossEntropyLoss(),
paddle.metric.Accuracy(topk=(1, 5)))
model.fit(train_dataset, epochs=2, batch_size=128, verbose=1) model.fit(train_dataset, epochs=2, batch_size=128, verbose=1)
result = model.evaluate(val_dataset,batch_size=128, log_freq=10) result = model.evaluate(val_dataset,batch_size=128, log_freq=10)
print(f"after fine-tuning: {result}") print(f"after fine-tuning: {result}")
......
...@@ -70,9 +70,9 @@ from paddleslim.dygraph import FilterPruner ...@@ -70,9 +70,9 @@ from paddleslim.dygraph import FilterPruner
class L2NormFilterPruner(FilterPruner): class L2NormFilterPruner(FilterPruner):
def __init__(self, model, input_shape, sen_file=None): def __init__(self, model, input_shape, sen_file=None, opt=None):
super(L2NormFilterPruner, self).__init__( super(L2NormFilterPruner, self).__init__(
model, input_shape, sen_file=sen_file) model, input_shape, sen_file=sen_file, opt=opt)
def cal_mask(self, var_name, pruned_ratio, group): def cal_mask(self, var_name, pruned_ratio, group):
value = group[var_name]['value'] value = group[var_name]['value']
...@@ -148,9 +148,9 @@ from paddleslim.dygraph import FilterPruner ...@@ -148,9 +148,9 @@ from paddleslim.dygraph import FilterPruner
class FPGMFilterPruner(FilterPruner): class FPGMFilterPruner(FilterPruner):
def __init__(self, model, input_shape, sen_file=None): def __init__(self, model, input_shape, sen_file=None, opt=None):
super(FPGMFilterPruner, self).__init__( super(FPGMFilterPruner, self).__init__(
model, input_shape, sen_file=sen_file) model, input_shape, sen_file=sen_file, opt=opt)
def cal_mask(self, var_name, pruned_ratio, group): def cal_mask(self, var_name, pruned_ratio, group):
value = group[var_name]['value'] value = group[var_name]['value']
...@@ -223,7 +223,7 @@ print(result) ...@@ -223,7 +223,7 @@ print(result)
### 5.2 计算敏感度 ### 5.2 计算敏感度
```python ```python
pruner = FPGMFilterPruner(net, [1, 3, 32, 32]) pruner = FPGMFilterPruner(net, [1, 3, 32, 32], opt=optimizer)
def eval_fn(): def eval_fn():
result = model.evaluate( result = model.evaluate(
val_dataset, val_dataset,
...@@ -250,13 +250,6 @@ print(f"before fine-tuning: {result}") ...@@ -250,13 +250,6 @@ print(f"before fine-tuning: {result}")
### 5.4 重训练 ### 5.4 重训练
```python ```python
optimizer = paddle.optimizer.Momentum(
learning_rate=0.1,
parameters=net.parameters())
model.prepare(
optimizer,
paddle.nn.CrossEntropyLoss(),
paddle.metric.Accuracy(topk=(1, 5)))
model.fit(train_dataset, epochs=2, batch_size=128, verbose=1) model.fit(train_dataset, epochs=2, batch_size=128, verbose=1)
result = model.evaluate(val_dataset,batch_size=128, log_freq=10) result = model.evaluate(val_dataset,batch_size=128, log_freq=10)
print(f"after fine-tuning: {result}") print(f"after fine-tuning: {result}")
......
...@@ -53,14 +53,19 @@ class FilterPruner(Pruner): ...@@ -53,14 +53,19 @@ class FilterPruner(Pruner):
sen_file(str, optional): The absolute path of file that stores computed sensitivities. If it is sen_file(str, optional): The absolute path of file that stores computed sensitivities. If it is
set rightly, 'FilterPruner::sensitive' function can not be called anymore set rightly, 'FilterPruner::sensitive' function can not be called anymore
in next step. Default: None. in next step. Default: None.
opt(paddle.optimizer.Optimizer): The model's optimizer. Default: None.
skip_leaves(bool): Whether to skip the last convolution layers.
""" """
def __init__(self, model, inputs, sen_file=None): def __init__(self, model, inputs, sen_file=None, opt=None,
super(FilterPruner, self).__init__(model, inputs) skip_leaves=True):
super(FilterPruner, self).__init__(model, inputs, opt=opt)
self._status = Status(sen_file) self._status = Status(sen_file)
self.skip_leaves = skip_leaves
# sensitive and collections are just used in filter pruning # sensitive and collections are just used in filter pruning
self.collections = DygraphPruningCollections(model, inputs) self.collections = DygraphPruningCollections(
model, inputs, skip_leaves=self.skip_leaves)
# skip vars in: # skip vars in:
# 1. depthwise conv2d layer # 1. depthwise conv2d layer
...@@ -216,7 +221,7 @@ class FilterPruner(Pruner): ...@@ -216,7 +221,7 @@ class FilterPruner(Pruner):
plan = self.prune_vars(ratios, axis=dims) plan = self.prune_vars(ratios, axis=dims)
c_flops = flops(self.model, self.inputs) c_flops = flops(self.model, self.inputs)
c_pruned_flops = (base_flops - c_flops) / base_flops c_pruned_flops = (base_flops - c_flops) / base_flops
plan.restore(self.model) plan.restore(self.model, opt=self.opt)
_logger.debug("Seaching ratios, pruned FLOPs: {}".format( _logger.debug("Seaching ratios, pruned FLOPs: {}".format(
c_pruned_flops)) c_pruned_flops))
key = str(round(c_pruned_flops, 4)) key = str(round(c_pruned_flops, 4))
...@@ -265,7 +270,7 @@ class FilterPruner(Pruner): ...@@ -265,7 +270,7 @@ class FilterPruner(Pruner):
var_name, ratio, loss)) var_name, ratio, loss))
sensitivities[var_name][ratio] = loss sensitivities[var_name][ratio] = loss
self._status.save(status_file) self._status.save(status_file)
plan.restore(model) plan.restore(model, opt=self.opt)
return sensitivities return sensitivities
...@@ -287,7 +292,7 @@ class FilterPruner(Pruner): ...@@ -287,7 +292,7 @@ class FilterPruner(Pruner):
def restore(self): def restore(self):
if self.plan is not None: if self.plan is not None:
self.plan.restore(self.model) self.plan.restore(self.model, opt=self.opt)
def cal_mask(self, pruned_ratio, collection): def cal_mask(self, pruned_ratio, collection):
raise NotImplemented("cal_mask is not implemented") raise NotImplemented("cal_mask is not implemented")
...@@ -347,7 +352,7 @@ class FilterPruner(Pruner): ...@@ -347,7 +352,7 @@ class FilterPruner(Pruner):
if apply == "lazy": if apply == "lazy":
plan.apply(self.model, lazy=True) plan.apply(self.model, lazy=True)
elif apply == "impretive": elif apply == "impretive":
plan.apply(self.model, lazy=False) plan.apply(self.model, lazy=False, opt=self.opt)
return plan return plan
def _transform_mask(self, mask, transform): def _transform_mask(self, mask, transform):
...@@ -365,6 +370,8 @@ class FilterPruner(Pruner): ...@@ -365,6 +370,8 @@ class FilterPruner(Pruner):
stride = transform['stride'] stride = transform['stride']
mask = mask.repeat(stride) if stride > 1 else mask mask = mask.repeat(stride) if stride > 1 else mask
return mask return mask
elif "repeat" in transform and "tile" in transform:
return np.tile(mask.repeat(transform["repeat"]), transform["tile"])
else: else:
return mask return mask
return dst_mask return dst_mask
...@@ -12,8 +12,10 @@ _logger = get_logger(__name__, logging.INFO) ...@@ -12,8 +12,10 @@ _logger = get_logger(__name__, logging.INFO)
class FPGMFilterPruner(FilterPruner): class FPGMFilterPruner(FilterPruner):
def __init__(self, model, inputs, sen_file=None): def __init__(self, model, inputs, sen_file=None, opt=None,
super(FPGMFilterPruner, self).__init__(model, inputs, sen_file=sen_file) skip_leaves=True):
super(FPGMFilterPruner, self).__init__(
model, inputs, sen_file=sen_file, opt=opt, skip_leaves=skip_leaves)
def cal_mask(self, pruned_ratio, collection): def cal_mask(self, pruned_ratio, collection):
var_name = collection.master_name var_name = collection.master_name
......
...@@ -12,9 +12,10 @@ _logger = get_logger(__name__, logging.INFO) ...@@ -12,9 +12,10 @@ _logger = get_logger(__name__, logging.INFO)
class L1NormFilterPruner(FilterPruner): class L1NormFilterPruner(FilterPruner):
def __init__(self, model, inputs, sen_file=None): def __init__(self, model, inputs, sen_file=None, opt=None,
skip_leaves=True):
super(L1NormFilterPruner, self).__init__( super(L1NormFilterPruner, self).__init__(
model, inputs, sen_file=sen_file) model, inputs, sen_file=sen_file, opt=opt, skip_leaves=skip_leaves)
def cal_mask(self, pruned_ratio, collection): def cal_mask(self, pruned_ratio, collection):
var_name = collection.master_name var_name = collection.master_name
......
...@@ -12,9 +12,10 @@ _logger = get_logger(__name__, logging.INFO) ...@@ -12,9 +12,10 @@ _logger = get_logger(__name__, logging.INFO)
class L2NormFilterPruner(FilterPruner): class L2NormFilterPruner(FilterPruner):
def __init__(self, model, inputs, sen_file=None): def __init__(self, model, inputs, sen_file=None, opt=None,
skip_leaves=True):
super(L2NormFilterPruner, self).__init__( super(L2NormFilterPruner, self).__init__(
model, inputs, sen_file=sen_file) model, inputs, sen_file=sen_file, opt=opt, skip_leaves=skip_leaves)
def cal_mask(self, pruned_ratio, collection): def cal_mask(self, pruned_ratio, collection):
var_name = collection.master_name var_name = collection.master_name
......
...@@ -16,16 +16,17 @@ class Pruner(object): ...@@ -16,16 +16,17 @@ class Pruner(object):
Args: Args:
model(paddle.nn.Layer): The target model to be pruned. model(paddle.nn.Layer): The target model to be pruned.
input_shape(list<int>): The input shape of model. It is used to trace the graph of the model. input_shape(list<int>): The input shape of model. It is used to trace the graph of the model.
opt(paddle.optimizer.Optimizer): The model's optimizer. Default: None.
""" """
def __init__(self, model, inputs): def __init__(self, model, inputs, opt=None):
self.model = model self.model = model
self.inputs = inputs self.inputs = inputs
self._var_shapes = {} self._var_shapes = {}
for var in model.parameters(): for var in model.parameters():
self._var_shapes[var.name] = var.shape self._var_shapes[var.name] = var.shape
self.plan = None self.plan = None
self.opt = opt
def status(self, data=None, eval_func=None, status_file=None): def status(self, data=None, eval_func=None, status_file=None):
raise NotImplemented("status is not implemented") raise NotImplemented("status is not implemented")
...@@ -53,6 +54,6 @@ class Pruner(object): ...@@ -53,6 +54,6 @@ class Pruner(object):
if apply == "lazy": if apply == "lazy":
global_plan.apply(self.model, lazy=True) global_plan.apply(self.model, lazy=True)
elif apply == "impretive": elif apply == "impretive":
global_plan.apply(self.model, lazy=False) global_plan.apply(self.model, lazy=False, opt=self.opt)
self.plan = global_plan self.plan = global_plan
return global_plan return global_plan
...@@ -95,11 +95,77 @@ class PruningPlan(): ...@@ -95,11 +95,77 @@ class PruningPlan():
for name, mask in self._masks.items() for name, mask in self._masks.items()
]) + details ]) + details
def apply(self, model, lazy=False): def _prune_opt(self, param_name, dims, bool_mask, opt):
if opt is None:
return
for k, v in opt._accumulators.items():
var_tmp = v.get(param_name)
#NOTE: var_tmp.shape == [1] is used to skip variables like beta1_pow_acc in Adam optimizer. Its shape is [1] and there's no need to prune this one-value variable.
if var_tmp is None or var_tmp.shape == [1]:
if var_tmp is not None: print(var_tmp.name, var_tmp.shape)
continue
t_value = var_tmp.value().get_tensor()
value = np.array(t_value).astype("float32")
pruned_value = np.apply_along_axis(lambda data: data[bool_mask],
dims, value)
p = t_value._place()
if p.is_cpu_place():
place = paddle.CPUPlace()
elif p.is_cuda_pinned_place():
place = paddle.CUDAPinnedPlace()
else:
p = core.Place()
p.set_place(t_value._place())
place = paddle.CUDAPlace(p.gpu_device_id())
t_value.set(pruned_value, place)
def _buffer_opt(self, param_name, sub_layer, opt):
if opt is None:
return
for k, v in opt._accumulators.items():
var_tmp = v.get(param_name)
if var_tmp is None: continue
backup_name = var_tmp.name.replace(".", "_") + "_backup"
if backup_name not in sub_layer._buffers:
sub_layer.register_buffer(
backup_name, paddle.to_tensor(var_tmp.value().get_tensor()))
_logger.debug("Backup values of {} into buffers.".format(
var_tmp.name))
def _restore_opt(self, param_name, sub_layer, opt):
if opt is None:
return
for k, v in opt._accumulators.items():
var_tmp = v.get(param_name)
if var_tmp is None: continue
backup_name = var_tmp.name.replace(".", "_") + "_backup"
if backup_name in sub_layer._buffers:
_logger.debug("Restore values of variable: {}".format(
var_tmp.name))
t_value = var_tmp.value().get_tensor()
t_backup = sub_layer._buffers[backup_name].value().get_tensor()
p = t_value._place()
if p.is_cpu_place():
place = paddle.CPUPlace()
elif p.is_cuda_pinned_place():
place = paddle.CUDAPinnedPlace()
else:
p = core.Place()
p.set_place(t_value._place())
place = paddle.CUDAPlace(p.gpu_device_id())
t_value.set(np.array(t_backup).astype("float32"), place)
del sub_layer._buffers[backup_name]
def apply(self, model, lazy=False, opt=None):
if lazy: if lazy:
self.lazy_apply(model) self.lazy_apply(model)
else: else:
self.imperative_apply(model) self.imperative_apply(model, opt)
def lazy_apply(self, model): def lazy_apply(self, model):
for name, sub_layer in model.named_sublayers(): for name, sub_layer in model.named_sublayers():
...@@ -136,13 +202,13 @@ class PruningPlan(): ...@@ -136,13 +202,13 @@ class PruningPlan():
t_value.set(value * expand_mask, place) t_value.set(value * expand_mask, place)
def imperative_apply(self, model): def imperative_apply(self, model, opt=None):
""" """
Pruning values of variable imperatively. It is valid when pruning Pruning values of variable imperatively. It is valid when pruning
on one dimension. on one dimension.
""" """
for name, sub_layer in model.named_sublayers(): for name, sub_layer in model.named_sublayers(include_self=True):
for param in sub_layer.parameters(include_sublayers=False): for param in sub_layer.parameters(include_sublayers=False):
if param.name in self._masks: if param.name in self._masks:
for _mask in self._masks[param.name]: for _mask in self._masks[param.name]:
...@@ -152,7 +218,6 @@ class PruningPlan(): ...@@ -152,7 +218,6 @@ class PruningPlan():
bool_mask = np.array(mask).astype(bool) bool_mask = np.array(mask).astype(bool)
t_value = param.value().get_tensor() t_value = param.value().get_tensor()
value = np.array(t_value).astype("float32") value = np.array(t_value).astype("float32")
groups = _mask._op.attr('groups') groups = _mask._op.attr('groups')
if dims == 1 and groups is not None and groups > 1 and len( if dims == 1 and groups is not None and groups > 1 and len(
value.shape) == 4: value.shape) == 4:
...@@ -173,8 +238,13 @@ class PruningPlan(): ...@@ -173,8 +238,13 @@ class PruningPlan():
paddle.to_tensor(value)) paddle.to_tensor(value))
_logger.debug("Backup values of {} into buffers.". _logger.debug("Backup values of {} into buffers.".
format(param.name)) format(param.name))
# save optimizer accumulators into layer buffer
self._buffer_opt(param.name, sub_layer, opt)
pruned_value = np.apply_along_axis( pruned_value = np.apply_along_axis(
lambda data: data[bool_mask], dims, value) lambda data: data[bool_mask], dims, value)
self._prune_opt(param.name, dims, bool_mask, opt)
p = t_value._place() p = t_value._place()
if p.is_cpu_place(): if p.is_cpu_place():
place = paddle.CPUPlace() place = paddle.CPUPlace()
...@@ -184,16 +254,17 @@ class PruningPlan(): ...@@ -184,16 +254,17 @@ class PruningPlan():
p = core.Place() p = core.Place()
p.set_place(t_value._place()) p.set_place(t_value._place())
place = paddle.CUDAPlace(p.gpu_device_id()) place = paddle.CUDAPlace(p.gpu_device_id())
t_value.set(pruned_value, place) t_value.set(pruned_value, place)
# for training # for training
if param.trainable: if param.trainable:
param.clear_gradient() param.clear_gradient()
def restore(self, model): def restore(self, model, opt=None):
for name, sub_layer in model.named_sublayers(): for name, sub_layer in model.named_sublayers(include_self=True):
for param in sub_layer.parameters(include_sublayers=False): for param in sub_layer.parameters(include_sublayers=False):
# restore optimizer accumulators from layer buffer
self._restore_opt(param.name, sub_layer, opt)
backup_name = "_".join([param.name.replace(".", "_"), "backup"]) backup_name = "_".join([param.name.replace(".", "_"), "backup"])
if backup_name in sub_layer._buffers: if backup_name in sub_layer._buffers:
_logger.debug("Restore values of variable: {}".format( _logger.debug("Restore values of variable: {}".format(
......
...@@ -17,9 +17,10 @@ class DygraphPruningCollections(PruningCollections): ...@@ -17,9 +17,10 @@ class DygraphPruningCollections(PruningCollections):
Args: Args:
- model(nn.Layer): The dygraph to be parsed. - model(nn.Layer): The dygraph to be parsed.
- inputs(Variable|list|dict): The dummy inputs of target model. It will be used in calling `model.forward(inputs)`. - inputs(Variable|list|dict): The dummy inputs of target model. It will be used in calling `model.forward(inputs)`.
- skip_leaves(bool): Whether to skip the last convolution layers.
""" """
def __init__(self, model, inputs): def __init__(self, model, inputs, skip_leaves=True):
_logger.debug("Parsing model with input: {}".format(inputs)) _logger.debug("Parsing model with input: {}".format(inputs))
# model can be in training mode, because some model contains auxiliary parameters for training. # model can be in training mode, because some model contains auxiliary parameters for training.
program = dygraph2program(model, inputs=inputs) program = dygraph2program(model, inputs=inputs)
...@@ -28,7 +29,8 @@ class DygraphPruningCollections(PruningCollections): ...@@ -28,7 +29,8 @@ class DygraphPruningCollections(PruningCollections):
_param.name for _param in model.parameters() _param.name for _param in model.parameters()
if len(_param.shape) == 4 if len(_param.shape) == 4
] ]
self._collections = self.create_pruning_collections(params, graph) self._collections = self.create_pruning_collections(
params, graph, skip_leaves=skip_leaves)
_logger.info("Found {} collections.".format(len(self._collections))) _logger.info("Found {} collections.".format(len(self._collections)))
_name2values = {} _name2values = {}
......
...@@ -139,7 +139,8 @@ class PruningCollections(object): ...@@ -139,7 +139,8 @@ class PruningCollections(object):
params, params,
graph, graph,
skip_stranger=True, skip_stranger=True,
skip_vars=None): skip_vars=None,
skip_leaves=True):
"""Collect convolution layers of graph into groups. The layers in the same group is relative on pruning operation. """Collect convolution layers of graph into groups. The layers in the same group is relative on pruning operation.
A group is a list of tuple with format (param_name, axis) in which `param_name` is the name of parameter and `axis` is the axis to be pruned on. A group is a list of tuple with format (param_name, axis) in which `param_name` is the name of parameter and `axis` is the axis to be pruned on.
...@@ -164,7 +165,8 @@ class PruningCollections(object): ...@@ -164,7 +165,8 @@ class PruningCollections(object):
params(list): A list of convolution layer's parameter names. It will collect all the groups that contains anyone of these parameters. params(list): A list of convolution layer's parameter names. It will collect all the groups that contains anyone of these parameters.
graph(paddle.static.Program | GraphWrapper): The graph used to search the groups. graph(paddle.static.Program | GraphWrapper): The graph used to search the groups.
skip_stranger(bool): Whether to skip current tensor when visit unregistered operators that not in OPS_UNCHANGE_SHAPE. False means visit all unregistered operators by default worker. Default: True. skip_stranger(bool): Whether to skip current tensor when visit unregistered operators that not in OPS_UNCHANGE_SHAPE. False means visit all unregistered operators by default worker. Default: True.
skip_vars(list<str>): Names of variables that will be skipped. None means skipping all leaves in given graph. '[]' means skipping nothing. Default: None. skip_vars(list<str>): Names of variables that will be skipped. Default: None.
skip_leaves(bool): Whether to skip the last convolution layers.
Returns: Returns:
list<Group>: The groups. list<Group>: The groups.
...@@ -173,12 +175,12 @@ class PruningCollections(object): ...@@ -173,12 +175,12 @@ class PruningCollections(object):
if not isinstance(graph, GraphWrapper): if not isinstance(graph, GraphWrapper):
graph = GraphWrapper(graph) graph = GraphWrapper(graph)
if skip_vars is None: skip_vars = [] if skip_vars is None else skip_vars
skip_vars = self._find_leaves(graph) if skip_leaves:
leaves = self._find_leaves(graph)
skip_vars.extend(leaves)
_logger.warning( _logger.warning(
"Leaves {} will be skipped when parsing graph. You can set skipped variables by option 'skip_vars'.". "Leaves {} will be skipped when parsing graph.".format(leaves))
format(skip_vars))
visited = {} visited = {}
collections = [] collections = []
unsupported_warnings = set() unsupported_warnings = set()
...@@ -234,7 +236,7 @@ class PruningCollections(object): ...@@ -234,7 +236,7 @@ class PruningCollections(object):
class StaticPruningCollections(PruningCollections): class StaticPruningCollections(PruningCollections):
def __init__(self, params, graph, skip_stranger=True): def __init__(self, params, graph, skip_stranger=True, skip_leaves=True):
super(StaticPruningCollections, self).__init__() super(StaticPruningCollections, self).__init__()
self._collections = self.create_pruning_collections( self._collections = self.create_pruning_collections(
params, graph, skip_stranger=skip_stranger) params, graph, skip_stranger=skip_stranger, skip_leaves=skip_leaves)
...@@ -527,77 +527,6 @@ class split(PruneWorker): ...@@ -527,77 +527,6 @@ class split(PruneWorker):
self._visit_and_search(out_var, pruned_axis, transforms) self._visit_and_search(out_var, pruned_axis, transforms)
@PRUNE_WORKER.register
class concat(PruneWorker):
def __init__(self, op, pruned_params, visited, skip_stranger):
super(concat, self).__init__(op, pruned_params, visited, skip_stranger)
def _prune(self, var, pruned_axis, transforms):
axis = self.op.attr("axis")
if var in self.op.outputs("Out"):
self._visit(var, pruned_axis)
start = 0
if axis == pruned_axis:
for _, in_var in enumerate(self.op.inputs("X")):
idx = []
transoform = {
'src_start': start,
'src_end': start + in_var.shape()[pruned_axis],
'target_start': 0,
'target_end': in_var.shape()[pruned_axis],
'target_len': in_var.shape()[pruned_axis],
'stride': 1
}
start += in_var.shape()[pruned_axis]
self._visit(in_var, pruned_axis)
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis,
transforms + [transoform])
else:
for _, in_var in enumerate(self.op.inputs("X")):
self._visit(in_var, pruned_axis)
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, transforms)
elif var in self.op.inputs("X"):
self._visit(var, pruned_axis)
if axis == pruned_axis:
idx = []
target_start = 0
for v in self.op.inputs("X"):
if v.name() != var.name():
target_start += v.shape()[pruned_axis]
else:
break
target_end = target_start + v.shape()[pruned_axis]
out_var = self.op.outputs("Out")[0]
next_ops = out_var.outputs()
transform = {
'src_start': 0,
'src_end': var.shape()[pruned_axis],
'target_start': target_start,
'target_end': target_end,
'target_len': out_var.shape()[pruned_axis],
'stride': 1
}
self._visit(out_var, pruned_axis)
for op in next_ops:
# The output of concat can be visited repeatedly
c_visited = {}
self._prune_op(
op,
out_var,
pruned_axis,
transforms + [transform],
visited=c_visited)
# Add nodes searched from concat into global visited array.
self.visited.update(c_visited)
@PRUNE_WORKER.register @PRUNE_WORKER.register
class depthwise_conv2d(PruneWorker): class depthwise_conv2d(PruneWorker):
def __init__(self, op, pruned_params, visited, skip_stranger): def __init__(self, op, pruned_params, visited, skip_stranger):
...@@ -620,21 +549,21 @@ class depthwise_conv2d(PruneWorker): ...@@ -620,21 +549,21 @@ class depthwise_conv2d(PruneWorker):
pruned_axis) pruned_axis)
# pruning number of filters # pruning number of filters
assert (_filter.shape()[0] % _groups == 0) assert (_filter.shape()[0] % _groups == 0)
stride = _filter.shape()[0] / _groups repeat = int(_filter.shape()[0] / _groups)
self.append_pruned_vars(_filter, 0, transforms + [{ self.append_pruned_vars(_filter, 0, transforms + [{
"stride": stride "repeat": repeat
}]) }])
# kernel_number * groups will be pruned by reducing groups # kernel_number * groups will be pruned by reducing groups
self.append_pruned_vars(_filter, 1, transforms) self.append_pruned_vars(_filter, 1, transforms)
self._visit_and_search(_filter, 0, transforms + [{ self._visit_and_search(_filter, 0, transforms + [{
"stride": stride "repeat": repeat
}]) }])
# It will not pruning number of kernels in depthwise conv2d, # It will not pruning number of kernels in depthwise conv2d,
# so it is not neccesary to search succeed operators. # so it is not neccesary to search succeed operators.
# self._visit_and_search(_filter, 1, transforms) # self._visit_and_search(_filter, 1, transforms)
self._visit(_filter, 1) self._visit(_filter, 1)
self._visit_and_search(_out, channel_axis, transforms + [{ self._visit_and_search(_out, channel_axis, transforms + [{
"stride": stride "repeat": repeat
}]) }])
elif var == _filter: elif var == _filter:
assert pruned_axis == 0, "The filter of depthwise conv2d can only be pruned at axis 0." assert pruned_axis == 0, "The filter of depthwise conv2d can only be pruned at axis 0."
...@@ -659,20 +588,98 @@ class mul(PruneWorker): ...@@ -659,20 +588,98 @@ class mul(PruneWorker):
def __init__(self, op, pruned_params, visited, skip_stranger): def __init__(self, op, pruned_params, visited, skip_stranger):
super(mul, self).__init__(op, pruned_params, visited, skip_stranger) super(mul, self).__init__(op, pruned_params, visited, skip_stranger)
def _prune(self, var, pruned_axis, pruned_idx): def _prune(self, var, pruned_axis, trans):
if var in self.op.inputs("X"): x_num_col_dims = self.op.attr("x_num_col_dims")
assert pruned_axis == 1, "The Input of conv2d can only be pruned at axis 1, but got {}".format( y_num_col_dims = self.op.attr("y_num_col_dims")
pruned_axis) x = self.op.inputs("X")[0]
idx = [] y = self.op.inputs("Y")[0]
feature_map_size = var.shape()[2] * var.shape()[3] out = self.op.outputs("Out")[0]
range_idx = np.array(range(feature_map_size)) x_shape = x.shape()
for i in pruned_idx: y_shape = y.shape()
idx += list(range_idx + i * feature_map_size) if var == x:
param_var = self.op.inputs("Y")[0] if y_num_col_dims > 1 and pruned_axis >= x_num_col_dims:
self.append_pruned_vars(param_var, 0, idx) raise UnsupportOpError(
"Unsupport pruning x of mul when y_num_col_dims > 1 and pruned_axis >= x_num_col_dims"
)
tile = 1
repeat = 1
if pruned_axis < x_num_col_dims:
for i in range(0, pruned_axis):
tile *= x_shape[i]
for i in range(pruned_axis + 1, x_num_col_dims):
repeat *= x_shape[i]
self.append_pruned_vars(out, 0, trans + [{
"tile": tile,
"repeat": repeat
}])
self._visit_and_search(out, 0, trans + [{
"tile": tile,
"repeat": repeat
}])
else:
for i in range(x_num_col_dims, pruned_axis):
tile *= x_shape[i]
for i in range(pruned_axis + 1, len(x_shape)):
repeat *= x_shape[i]
self.append_pruned_vars(y, 0, trans + [{
"tile": tile,
"repeat": repeat
}])
self._visit_and_search(y, 0, trans + [{
"tile": tile,
"repeat": repeat
}])
elif var == y:
if (pruned_axis < y_num_col_dims) and (
1 < len(x_shape) - x_num_col_dims):
raise UnsupportOpError(
"Unsupport pruning y of mul when pruned_axis < y_num_col_dims and 1 < len(x_shape) - x_num_col_dims."
)
tile = 1
repeat = 1
if pruned_axis >= y_num_col_dims:
for i in range(y_num_col_dims, pruned_axis):
tile *= y_shape[i]
for i in range(pruned_axis + 1, len(y_shape)):
repeat *= y_shape[i]
self.append_pruned_vars(out, 1, trans + [{
"tile": tile,
"repeat": repeat
}])
self._visit_and_search(out, 1, trans + [{
"tile": tile,
"repeat": repeat
}])
else:
for i in range(0, pruned_axis):
tile *= y_shape[i]
for i in range(pruned_axis + 1, y_num_col_dims):
repeat *= y_shape[i]
self.append_pruned_vars(x,
len(x_shape) - 1, trans + [{
"tile": tile,
"repeat": repeat
}])
self._visit_and_search(x,
len(x_shape) - 1, trans + [{
"tile": tile,
"repeat": repeat
}])
elif var == out:
if (pruned_axis == 0 and x_num_col_dims != 1) or (
pruned_axis == 1 and (len(y_shape) - y_num_col_dims) != 1):
raise UnsupportOpError(
"Unsupport pruning out of mul when pruned_axis={}; x_num_col_dims: {}; y_num_col_dims: {}; y_shape: {}.".
format(pruned_axis, x_num_col_dims, y_num_col_dims,
y_shape))
for op in param_var.outputs(): if pruned_axis == 0:
self._prune_op(op, param_var, 0, pruned_idx) self.append_pruned_vars(x, 0, trans)
self._visit_and_search(x, 0, trans)
elif pruned_axis == 1:
self.append_pruned_vars(y, len(y_shape) - 1, trans)
self._visit_and_search(y, len(y_shape) - 1, trans)
@PRUNE_WORKER.register @PRUNE_WORKER.register
...@@ -684,16 +691,54 @@ class matmul(PruneWorker): ...@@ -684,16 +691,54 @@ class matmul(PruneWorker):
x = self.op.inputs("X")[0] x = self.op.inputs("X")[0]
y = self.op.inputs("Y")[0] y = self.op.inputs("Y")[0]
out = self.op.outputs("Out")[0] out = self.op.outputs("Out")[0]
if var == x and pruned_axis == 1: x_shape_len = len(x.shape())
self.append_pruned_vars(y, 0, pruned_idx) y_shape_len = len(y.shape())
self._visit_and_search(y, 0, pruned_idx) mappings = []
if x_shape_len == 1 and y_shape_len == 1:
mappings = [(0, 0, 0)]
elif x_shape_len == 1 and y_shape_len == 2:
mappings = [(0, 0, -1), (-1, 1, 0)]
elif x_shape_len == 2 and y_shape_len == 2:
mappings = [(0, -1, 0), (1, 0, -1), (-1, 1, 1)]
elif x_shape_len == 3 and y_shape_len == 1:
mappings = [(1, -1, 1), (2, 0, -1)]
elif x_shape_len == 2 and y_shape_len == 3:
mappings = [(0, -1, 1), (1, 1, -1), (-1, 2, 2)]
elif x_shape_len >= 3 and y_shape_len >= 3:
mappings = [(x_shape_len - 2, -1, x_shape_len - 2),
(x_shape_len - 1, x_shape_len - 2, -1),
(-1, x_shape_len - 1, x_shape_len - 1)]
if var == x:
for x_i, y_i, out_i in mappings:
if pruned_axis == x_i:
if y_i != -1:
self.append_pruned_vars(y, y_i, pruned_idx)
self._visit_and_search(y, y_i, pruned_idx)
if out_i != -1:
#self.append_pruned_vars(out, out_i, pruned_idx)
self._visit_and_search(out, out_i, pruned_idx)
break
if var == y:
for x_i, y_i, out_i in mappings:
if pruned_axis == y_i:
if x_i != -1:
self.append_pruned_vars(x, x_i, pruned_idx)
self._visit_and_search(x, x_i, pruned_idx)
if out_i != -1:
#self.append_pruned_vars(out, out_i, pruned_idx)
self._visit_and_search(out, out_i, pruned_idx)
break
if var == out: if var == out:
if pruned_axis == 0: for x_i, y_i, out_i in mappings:
self.append_pruned_vars(x, 0, pruned_idx) if pruned_axis == out_i:
self._visit_and_search(x, 0, pruned_idx) if x_i != -1:
elif pruned_axis == 1: self.append_pruned_vars(x, x_i, pruned_idx)
self.append_pruned_vars(y, 1, pruned_idx) self._visit_and_search(x, x_i, pruned_idx)
self._visit_and_search(y, 1, pruned_idx) if y_i != -1:
self.append_pruned_vars(y, y_i, pruned_idx)
self._visit_and_search(y, y_i, pruned_idx)
break
@PRUNE_WORKER.register @PRUNE_WORKER.register
...@@ -859,3 +904,20 @@ class unsqueeze2(PruneWorker): ...@@ -859,3 +904,20 @@ class unsqueeze2(PruneWorker):
squeeze_num += 1 squeeze_num += 1
pruned_axis -= squeeze_num pruned_axis -= squeeze_num
self._visit_and_search(in_var, pruned_axis, transforms) self._visit_and_search(in_var, pruned_axis, transforms)
@PRUNE_WORKER.register
class average_accumulates(PruneWorker):
def __init__(self, op, pruned_params, visited, skip_stranger):
super(average_accumulates, self).__init__(op, pruned_params, visited,
skip_stranger)
def _prune(self, var, pruned_axis, transforms):
in_var = self.op.inputs("param")[0]
out_var_1 = self.op.outputs("out_sum_1")[0]
out_var_2 = self.op.outputs("out_sum_2")[0]
out_var_3 = self.op.outputs("out_sum_3")[0]
if in_var == var:
self.append_pruned_vars(out_var_1, pruned_axis, transforms)
self.append_pruned_vars(out_var_2, pruned_axis, transforms)
self.append_pruned_vars(out_var_3, pruned_axis, transforms)
...@@ -169,20 +169,25 @@ class Pruner(): ...@@ -169,20 +169,25 @@ class Pruner():
for name, axis, pruned_idx, transforms in items: for name, axis, pruned_idx, transforms in items:
src = pruned_idx src = pruned_idx
for trans in transforms: for trans in transforms:
if 'src_start' not in trans:
continue
src_start = trans['src_start']
src_end = trans['src_end']
src_len = src_end - src_start
target_start = trans['target_start']
target_end = trans['target_end']
starts = np.array(range(target_start, target_end, src_len))
target = [] target = []
for idx in src: if 'src_start' in trans:
if idx >= src_start and idx < src_end: src_start = trans['src_start']
idx -= src_start src_end = trans['src_end']
target.extend(list(idx + starts)) src_len = src_end - src_start
target_start = trans['target_start']
target_end = trans['target_end']
starts = np.array(range(target_start, target_end, src_len))
for idx in src:
if idx >= src_start and idx < src_end:
idx -= src_start
target.extend(list(idx + starts))
elif "repeat" in trans:
repeat = trans['repeat']
for idx in src:
idx = idx * repeat
target.extend(range(idx, idx + repeat))
src = target src = target
ret.append((name, axis, src)) ret.append((name, axis, src))
return ret return ret
......
...@@ -72,11 +72,11 @@ class TestFilterPruner(unittest.TestCase): ...@@ -72,11 +72,11 @@ class TestFilterPruner(unittest.TestCase):
paddle.metric.Accuracy(topk=(1, 5))) paddle.metric.Accuracy(topk=(1, 5)))
model.fit(self.train_dataset, epochs=1, batch_size=128, verbose=1) model.fit(self.train_dataset, epochs=1, batch_size=128, verbose=1)
pruners = [] pruners = []
pruner = L1NormFilterPruner(net, [1, 1, 28, 28]) pruner = L1NormFilterPruner(net, [1, 1, 28, 28], opt=optimizer)
pruners.append(pruner) pruners.append(pruner)
pruner = FPGMFilterPruner(net, [1, 1, 28, 28]) pruner = FPGMFilterPruner(net, [1, 1, 28, 28], opt=optimizer)
pruners.append(pruner) pruners.append(pruner)
pruner = L2NormFilterPruner(net, [1, 1, 28, 28]) pruner = L2NormFilterPruner(net, [1, 1, 28, 28], opt=optimizer)
pruners.append(pruner) pruners.append(pruner)
def eval_fn(): def eval_fn():
...@@ -90,6 +90,10 @@ class TestFilterPruner(unittest.TestCase): ...@@ -90,6 +90,10 @@ class TestFilterPruner(unittest.TestCase):
eval_func=eval_fn, eval_func=eval_fn,
sen_file=sen_file, sen_file=sen_file,
target_vars=self._param_names) target_vars=self._param_names)
model.fit(self.train_dataset,
epochs=1,
batch_size=128,
verbose=1)
base_acc = eval_fn() base_acc = eval_fn()
plan = pruner.sensitive_prune(0.01) plan = pruner.sensitive_prune(0.01)
pruner.restore() pruner.restore()
...@@ -129,44 +133,68 @@ class TestPruningGroupConv2d(unittest.TestCase): ...@@ -129,44 +133,68 @@ class TestPruningGroupConv2d(unittest.TestCase):
for param in net.parameters(): for param in net.parameters():
if param.name not in shapes: if param.name not in shapes:
shapes[param.name] = param.shape shapes[param.name] = param.shape
assert (shapes[param.name] == param.shape) self.assertTrue(shapes[param.name] == param.shape)
pruner.restore() pruner.restore()
#class TestStrideTransform(unittest.TestCase): from paddle.fluid import ParamAttr
# def __init__(self, methodName='runTest'):
# super(TestStrideTransform, self).__init__(methodName)
# class MulNet(paddle.nn.Layer):
# def runTest(self): """
# with fluid.unique_name.guard(): [3, 36] X conv(x)
# """
# net = paddle.vision.models.mobilenet_v1()
# ratios = {} def __init__(self):
# for param in net.parameters(): super(MulNet, self).__init__()
# if len(param.shape) == 4: self.conv_a = paddle.nn.Conv2D(6, 6, 1)
# ratios[param.name] = 0.5 self.b = self.create_parameter(shape=[3, 36], attr=ParamAttr(name="b"))
# pruners = []
# pruner = L1NormFilterPruner(net, [1, 3, 128, 128]) def forward(self, x):
# pruners.append(pruner) conv_a = self.conv_a(x)
# pruner = FPGMFilterPruner(net, [1, 3, 128, 128]) return paddle.fluid.layers.mul(self.b,
# pruners.append(pruner) conv_a,
# pruner = L2NormFilterPruner(net, [1, 3, 128, 128]) x_num_col_dims=1,
# pruners.append(pruner) y_num_col_dims=3)
#
# shapes = {}
# for pruner in pruners: class TestPruningMul(unittest.TestCase):
# plan = pruner.prune_vars(ratios, 0) def __init__(self, methodName='runTest'):
# for param in net.parameters(): super(TestPruningMul, self).__init__(methodName)
# if param.name not in shapes:
# shapes[param.name] = param.shape def runTest(self):
# assert(shapes[param.name] == param.shape) with fluid.unique_name.guard():
# pruner.restore() net = MulNet()
ratios = {}
ratios['conv2d_0.w_0'] = 0.5
pruners = []
pruner = L1NormFilterPruner(net, [2, 6, 3, 3], skip_leaves=False)
pruners.append(pruner)
pruner = FPGMFilterPruner(net, [2, 6, 3, 3], skip_leaves=False)
pruners.append(pruner)
pruner = L2NormFilterPruner(net, [2, 6, 3, 3], skip_leaves=False)
pruners.append(pruner)
shapes = {
'b': [3, 18],
'conv2d_0.w_0': [3, 6, 1, 1],
'conv2d_0.b_0': [3]
}
for pruner in pruners:
plan = pruner.prune_vars(ratios, 0)
for param in net.parameters():
if param.name not in shapes:
shapes[param.name] = param.shape
self.assertTrue(shapes[param.name] == param.shape)
pruner.restore()
def add_cases(suite): def add_cases(suite):
# suite.addTest(TestStatus()) suite.addTest(TestStatus())
# suite.addTest(TestFilterPruner(param_names=["conv2d_0.w_0"])) suite.addTest(TestFilterPruner(param_names=["conv2d_0.w_0"]))
suite.addTest(TestPruningGroupConv2d()) suite.addTest(TestPruningGroupConv2d())
suite.addTest(TestPruningMul())
def load_tests(loader, standard_tests, pattern): def load_tests(loader, standard_tests, pattern):
......
...@@ -15,6 +15,7 @@ import sys ...@@ -15,6 +15,7 @@ import sys
sys.path.append("../") sys.path.append("../")
import unittest import unittest
from static_case import StaticCase from static_case import StaticCase
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddleslim.prune import Pruner from paddleslim.prune import Pruner
from static_case import StaticCase from static_case import StaticCase
...@@ -103,5 +104,85 @@ class TestPrune(StaticCase): ...@@ -103,5 +104,85 @@ class TestPrune(StaticCase):
self.assertTrue(shapes[param.name] == param.shape) self.assertTrue(shapes[param.name] == param.shape)
class TestSplit(StaticCase):
def test_split(self):
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
input = fluid.data(name="image", shape=[None, 3, 16, 16])
conv1 = conv_bn_layer(input, 8, 3, "conv1")
conv2 = conv_bn_layer(input, 4, 3, "conv2")
split_0, split_1 = paddle.split(conv1, 2, axis=1)
add = split_0 + conv2
out = conv_bn_layer(add, 4, 3, "conv3")
out1 = conv_bn_layer(split_1, 4, 4, "conv4")
shapes = {}
for param in main_program.global_block().all_parameters():
shapes[param.name] = param.shape
place = fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.Scope()
exe.run(startup_program, scope=scope)
pruner = Pruner()
# test backward search of concat
pruned_program, _, _ = pruner.prune(
main_program,
scope,
params=["conv2_weights"],
ratios=[0.5],
place=place,
lazy=False,
only_graph=True,
param_backup=None,
param_shape_backup=None)
shapes = {
"conv1_weights": (6, 3, 3, 3),
"conv2_weights": (2, 3, 3, 3),
"conv3_weights": (4, 2, 3, 3),
"conv4_weights": (4, 4, 3, 3),
}
for param in pruned_program.global_block().all_parameters():
if "weights" in param.name and "conv2d" in param.name:
self.assertTrue(shapes[param.name] == param.shape)
class TestMul(StaticCase):
def test_mul(self):
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
input = fluid.data(name="image", shape=[None, 3, 16, 16])
conv1 = conv_bn_layer(input, 8, 3, "conv1")
fc_0 = paddle.fluid.layers.fc(conv1, size=10)
fc_1 = paddle.fluid.layers.fc(fc_0, size=10)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.Scope()
exe.run(startup_program, scope=scope)
pruner = Pruner()
# test backward search of concat
pruned_program, _, _ = pruner.prune(
main_program,
scope,
params=["conv1_weights"],
ratios=[0.5],
place=place,
lazy=False,
only_graph=True,
param_backup=None,
param_shape_backup=None)
shapes = {
"conv1_weights": (4, 3, 3, 3),
"fc_0.w_0": (1024, 10),
"fc_1.w_0": (10, 10)
}
for param in pruned_program.global_block().all_parameters():
if param.name in shapes.keys():
self.assertTrue(shapes[param.name] == param.shape)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -324,6 +324,7 @@ class TestPruneWorker(unittest.TestCase): ...@@ -324,6 +324,7 @@ class TestPruneWorker(unittest.TestCase):
if var.name() not in ret: if var.name() not in ret:
ret[var.name()] = [] ret[var.name()] = []
ret[var.name()].append(axis) ret[var.name()].append(axis)
print(f"excepted: {_ret}; but get {ret}")
self.assertTrue(ret == _ret) self.assertTrue(ret == _ret)
...@@ -372,38 +373,44 @@ class TestElementwiseMul(TestPruneWorker): ...@@ -372,38 +373,44 @@ class TestElementwiseMul(TestPruneWorker):
class TestActivation(TestPruneWorker): class TestActivation(TestPruneWorker):
def __init__(self, methodName="test_prune", def __init__(self,
op=paddle.nn.functional.sigmoid): methodName="check",
op=paddle.nn.functional.sigmoid,
**kwargs):
super(TestActivation, self).__init__(methodName) super(TestActivation, self).__init__(methodName)
self.act = op self.act = op
self.kwargs = kwargs
def define_layer(self, input): def define_layer(self, input):
conv1 = paddle.static.nn.conv2d( conv1 = paddle.static.nn.conv2d(
input, 3, 3, name="conv1", bias_attr=False) input, 3, 3, name="conv1", bias_attr=False)
self.input = conv1 self.input = conv1
tmp = self.act(conv1) tmp = self.act(conv1, **self.kwargs)
self.output = tmp self.output = tmp
conv2 = paddle.static.nn.conv2d( conv2 = paddle.static.nn.conv2d(
tmp, 3, 3, name="conv2", bias_attr=False) tmp, 3, 3, name="conv2", bias_attr=False)
def set_cases(self): def set_cases(self):
self.cases.append((self.in_var, 1, {'conv2.w_0': [1]})) self.cases.append((self.in_var, 1, {'conv2.w_0': [1]}))
self.cases.append((self.out_var, 1, { self.cases.append((self.out_var, 1, {'conv1.w_0': [0], }))
'conv1.w_0': [0],
'conv2.w_0': [1]
}))
def test_prune(self): def check(self):
self.check_in_out() self.check_in_out()
suite = unittest.TestSuite() act_suite = unittest.TestSuite()
suite.addTest(TestActivation(op=paddle.fluid.layers.resize_bilinear)) act_suite.addTest(
suite.addTest(TestActivation(op=paddle.fluid.layers.resize_nearest)) TestActivation(
suite.addTest(TestActivation(op=paddle.floor)) op=paddle.fluid.layers.resize_bilinear, scale=2.))
suite.addTest(TestActivation(op=paddle.scale)) act_suite.addTest(
suite.addTest( TestActivation(
TestActivation(op=paddle.fluid.layers.nn.uniform_random_batch_size_like)) op=paddle.fluid.layers.resize_nearest, scale=2.))
act_suite.addTest(TestActivation(op=paddle.floor))
act_suite.addTest(TestActivation(op=paddle.scale))
act_suite.addTest(
TestActivation(
op=paddle.fluid.layers.nn.uniform_random_batch_size_like,
shape=[8, 8, 16, 16]))
class TestDepthwiseConv2d(TestPruneWorker): class TestDepthwiseConv2d(TestPruneWorker):
...@@ -432,43 +439,161 @@ class TestDepthwiseConv2d(TestPruneWorker): ...@@ -432,43 +439,161 @@ class TestDepthwiseConv2d(TestPruneWorker):
class TestMul(TestPruneWorker): class TestMul(TestPruneWorker):
def __init__(self, methodName="test_prune"): def __init__(self,
methodName="check",
x_num_col_dims=1,
y_num_col_dims=1,
ret=[]):
super(TestMul, self).__init__(methodName) super(TestMul, self).__init__(methodName)
self.x_num_col_dims = x_num_col_dims
self.y_num_col_dims = y_num_col_dims
self.ret = ret
def define_layer(self, input): def define_layer(self, input):
x = fluid.data(name="x", shape=[1, 4, 3, 3]) x = fluid.data(name="x", shape=[1, 1, 1, 1])
y = fluid.data(name="y", shape=[36, 7]) y = fluid.data(name="y", shape=[1, 1, 1, 1])
self.input = x self.input = x
out = paddle.fluid.layers.mul(x, y) self.y = y
out = paddle.fluid.layers.mul(x,
y,
x_num_col_dims=self.x_num_col_dims,
y_num_col_dims=self.y_num_col_dims)
self.output = out self.output = out
def set_cases(self): def set_cases(self):
self.cases.append((self.in_var, 1, {'y': [0]})) y = self.graph.var(self.y.name)
x = self.in_var
def test_prune(self): out = self.out_var
self.cases.append((x, 0, self.ret[0]))
self.cases.append((x, 1, self.ret[1]))
self.cases.append((x, 2, self.ret[2]))
self.cases.append((x, 3, self.ret[3]))
self.cases.append((y, 0, self.ret[4]))
self.cases.append((y, 1, self.ret[5]))
self.cases.append((y, 2, self.ret[6]))
self.cases.append((y, 3, self.ret[7]))
self.cases.append((out, 0, self.ret[8]))
self.cases.append((out, 1, self.ret[9]))
def check(self):
self.check_in_out() self.check_in_out()
mul_suite = unittest.TestSuite()
ret = [{
'mul_0.tmp_0': [0]
}] + [{
'y': [0]
}] * 3 + [{}] + [{
'mul_0.tmp_0': [1]
}] * 3 + [{
'x': [0]
}, {}]
mul_suite.addTest(TestMul(x_num_col_dims=1, y_num_col_dims=1, ret=ret))
ret = [{
'mul_0.tmp_0': [0]
}] * 2 + [{}] * 4 + [{
'mul_0.tmp_0': [1]
}] * 2 + [{}] * 2
mul_suite.addTest(TestMul(x_num_col_dims=2, y_num_col_dims=2, ret=ret))
ret = [{
'mul_0.tmp_0': [0]
}] * 3 + [{}] + [{
'x': [3]
}] * 3 + [{
'mul_0.tmp_0': [1]
}] + [{}, {
'y': [3]
}]
mul_suite.addTest(TestMul(x_num_col_dims=3, y_num_col_dims=3, ret=ret))
class TestMatmul(TestPruneWorker): class TestMatmul(TestPruneWorker):
def __init__(self, methodName="test_prune"): def __init__(self, methodName="test_prune"):
super(TestMatmul, self).__init__(methodName) super(TestMatmul, self).__init__(methodName)
self.x_shape = [6, 8]
self.y_shape = [8, 7]
def define_layer(self, input): def define_layer(self, input):
x = fluid.data(name="x", shape=[6, 8]) x = fluid.data(name="x", shape=self.x_shape)
y = fluid.data(name="y", shape=[8, 7]) y = fluid.data(name="y", shape=self.y_shape)
self.input = x self.input = x
self.y = y
out = paddle.matmul(x, y) out = paddle.matmul(x, y)
self.output = out self.output = out
def set_cases(self): def set_cases(self):
self.y_var = self.graph.var(self.y.name)
self.cases.append((self.in_var, 1, {'y': [0]})) self.cases.append((self.in_var, 1, {'y': [0]}))
self.cases.append((self.out_var, 0, {'x': [0]})) self.cases.append((self.y_var, 0, {'x': [1]}))
self.cases.append((self.out_var, 1, {'y': [1]})) self.cases.append((self.out_var, 1, {'y': [1]}))
def test_prune(self): def test_prune(self):
self.check_in_out() self.check_in_out()
class TestMatmulCase2(TestMatmul):
def __init__(self, methodName="test_prune"):
super(TestMatmulCase2, self).__init__(methodName)
self.x_shape = [8]
self.y_shape = [7]
def set_cases(self):
self.cases.append((self.in_var, 0, {'y': [0]}))
self.cases.append((self.out_var, 0, {'x': [0], 'y': [0]}))
class TestMatmulCase3(TestMatmul):
def __init__(self, methodName="test_prune"):
super(TestMatmulCase3, self).__init__(methodName)
self.x_shape = [7]
self.y_shape = [7, 8]
def set_cases(self):
self.cases.append((self.in_var, 0, {'y': [0]}))
self.cases.append((self.out_var, 0, {'y': [1]}))
class TestMatmulCase4(TestMatmul):
def __init__(self, methodName="test_prune"):
super(TestMatmulCase4, self).__init__(methodName)
self.x_shape = [8, 7, 7]
self.y_shape = [7]
def set_cases(self):
self.cases.append((self.in_var, 1, {}))
self.cases.append((self.in_var, 2, {'y': [0]}))
self.cases.append((self.out_var, 1, {'x': [1]}))
class TestMatmulCase5(TestMatmul):
def __init__(self, methodName="test_prune"):
super(TestMatmulCase5, self).__init__(methodName)
self.x_shape = [7, 7]
self.y_shape = [7, 8, 9]
def set_cases(self):
self.cases.append((self.in_var, 0, {}))
self.cases.append((self.in_var, 1, {'y': [1]}))
self.cases.append((self.out_var, 1, {'x': [0]}))
self.cases.append((self.out_var, 2, {'y': [2]}))
class TestMatmulCase6(TestMatmul):
def __init__(self, methodName="test_prune"):
super(TestMatmulCase6, self).__init__(methodName)
self.x_shape = [7, 7, 7]
self.y_shape = [7, 7, 9]
def set_cases(self):
self.cases.append((self.in_var, 1, {}))
self.cases.append((self.in_var, 2, {'y': [1]}))
self.cases.append((self.out_var, 1, {'x': [1]}))
self.cases.append((self.out_var, 2, {'y': [2]}))
class TestSplit(TestPruneWorker): class TestSplit(TestPruneWorker):
def define_layer(self, input): def define_layer(self, input):
self.input = input self.input = input
...@@ -528,6 +653,33 @@ class TestAdam(TestPruneWorker): ...@@ -528,6 +653,33 @@ class TestAdam(TestPruneWorker):
self.check_in_out() self.check_in_out()
class TestAverageAccumulates(TestPruneWorker):
def define_layer(self, input):
self.input = input
conv1 = paddle.static.nn.conv2d(
input, 3, 8, name="conv1", bias_attr=False)
self.output = conv1
out = paddle.mean(conv1)
opt = paddle.optimizer.Adam()
opt.minimize(out)
model_average = fluid.optimizer.ModelAverage(
0.15, min_average_window=10000, max_average_window=12500)
def set_cases(self):
weight_var = self.graph.var('conv1.w_0')
self.cases.append((weight_var, 0, {
'conv1.w_0': [0],
'conv1.w_0_moment1_0': [0],
'conv1.w_0_moment2_0': [0],
'conv1.w_0_sum_1_0': [0],
'conv1.w_0_sum_2_0': [0],
'conv1.w_0_sum_3_0': [0]
}))
def test_prune(self):
self.check_in_out()
class TestAffineChannel(TestPruneWorker): class TestAffineChannel(TestPruneWorker):
def __init__(self, methodName="test_prune"): def __init__(self, methodName="test_prune"):
super(TestAffineChannel, self).__init__(methodName) super(TestAffineChannel, self).__init__(methodName)
...@@ -555,4 +707,7 @@ class TestAffineChannel(TestPruneWorker): ...@@ -555,4 +707,7 @@ class TestAffineChannel(TestPruneWorker):
if __name__ == '__main__': if __name__ == '__main__':
runner = unittest.TextTestRunner(verbosity=2)
runner.run(mul_suite)
runner.run(act_suite)
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册