Unverified commit 23cc74de authored by minghaoBD, committed by GitHub

[Unstructured prune] Support block sparse training (#1000)

Parent d096d3ab
@@ -33,7 +33,7 @@ paddleslim>=2.2.0
By default, sparsification is performed according to the absolute values of the parameters, and normalization-layer parameters are not sparsified. Developers who want to change this logic can proceed as follows:
- Developers can define their own unstructured sparsity strategy (currently, the parameters with small absolute values are pruned) by overriding `UnstructuredPruner.mask_parameters()` and `UnstructuredPruner.update_threshold()` in `paddleslim.dygraph.prune.unstructured_pruner.py`.
-- Developers can pass a custom `skip_params_func` when initializing `UnstructuredPruner` to define which parameters are excluded from pruning. Example `skip_params_func` code follows (path: `paddleslim.dygraph.prune.unstructured_pruner._get_skip_params()`). By default, all normalization-layer parameters are excluded from pruning.
+- Developers can pass a custom `skip_params_func` when initializing `UnstructuredPruner` to define which parameters are excluded from pruning. Example `skip_params_func` code follows (path: `paddleslim.dygraph.prune.unstructured_pruner._get_skip_params()`). By default, all normalization-layer parameters and `bias` are excluded from pruning.
```python
NORMS_ALL = [ 'BatchNorm', 'GroupNorm', 'LayerNorm', 'SpectralNorm', 'BatchNorm1D',
@@ -55,6 +55,8 @@ def _get_skip_params(model):
    for _, sub_layer in model.named_sublayers():
        if type(sub_layer).__name__.split('.')[-1] in NORMS_ALL:
            skip_params.add(sub_layer.full_name())
+        for param in sub_layer.parameters(include_sublayers=False):
+            if len(param.shape) == 1: skip_params.add(param.name)
    return skip_params
```
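To show how a custom `skip_params_func` plugs in, here is a minimal sketch (an editor's illustration, not code from this commit). The depthwise-conv rule is a hypothetical user policy; the 1-D rule mirrors the built-in `_get_skip_params` above:

```python
from paddle.vision.models import mobilenet_v1
from paddleslim import UnstructuredPruner

def my_skip_params(model):
    skip_params = set()
    for _, sub_layer in model.named_sublayers():
        for param in sub_layer.parameters(include_sublayers=False):
            # same rule as the built-in _get_skip_params: skip 1-D
            # (bias-like and normalization) parameters
            if len(param.shape) == 1:
                skip_params.add(param.name)
            # hypothetical extra rule: keep depthwise kernels dense
            elif len(param.shape) == 4 and param.shape[1] == 1:
                skip_params.add(param.name)
    return skip_params

model = mobilenet_v1(num_classes=10, pretrained=False)
pruner = UnstructuredPruner(
    model, mode='ratio', ratio=0.55, skip_params_func=my_skip_params)
```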
......
@@ -35,6 +35,7 @@ add_arg('pruning_mode', str, 'ratio', "the pruning mod
add_arg('threshold', float, 0.01, "The threshold to set zeros. Default: 0.01")
add_arg('num_epochs', int, 120, "The number of total epochs. Default: 120")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
+parser.add_argument('--sparse_block', nargs='+', type=int, default=[1, 1], help="There must be two integers inside this array. The array defines the shape of the block, the values within which are either sparsified to all zeros or kept original. [1, 1] means unstructured pruning. Default: [1, 1]")
add_arg('data', str, "imagenet", "Which data to use. 'cifar10' or 'imagenet'. Default: imagenet")
add_arg('log_period', int, 100, "Log period in batches. Default: 100")
add_arg('test_period', int, 5, "Test period in epoches. Default: 5")
@@ -100,13 +101,15 @@ def create_unstructured_pruner(model, args, configs=None):
            ratio=args.ratio,
            threshold=args.threshold,
            prune_params_type=args.prune_params_type,
-            local_sparsity=args.local_sparsity)
+            local_sparsity=args.local_sparsity,
+            sparse_block=args.sparse_block)
    else:
        return GMPUnstructuredPruner(
            model,
            ratio=args.ratio,
            prune_params_type=args.prune_params_type,
            local_sparsity=args.local_sparsity,
+            sparse_block=args.sparse_block,
            configs=configs)
......
@@ -42,7 +42,7 @@ tar -xf MobileNetV1_pretrained.tar
By default, sparsification is performed according to the absolute values of the parameters, and normalization-layer parameters are not sparsified. Developers who want to change this logic can proceed as follows:
- You can define your own unstructured sparsity strategy (currently, the parameters with small absolute values are pruned) by overriding `UnstructuredPruner.update_threshold()` in `paddleslim.prune.unstructured_pruner.py`.
-- You can pass a custom `skip_params_func` when initializing `UnstructuredPruner` to define which parameters are excluded from pruning. Example `skip_params_func` code follows (path: `paddleslim.prune.unstructured_pruner._get_skip_params()`). By default, all normalization-layer parameters are excluded from pruning.
+- You can pass a custom `skip_params_func` when initializing `UnstructuredPruner` to define which parameters are excluded from pruning. Example `skip_params_func` code follows (path: `paddleslim.prune.unstructured_pruner._get_skip_params()`). By default, all normalization-layer parameters and `bias` are excluded from pruning.
```python
def _get_skip_params(program):
@@ -61,6 +61,9 @@ def _get_skip_params(program):
        if 'norm' in op.type() and 'grad' not in op.type():
            for input in op.all_inputs():
                skip_params.add(input.name())
+    for param in program.all_parameters():
+        if len(param.shape) == 1:
+            skip_params.add(param.name)
    return skip_params
```
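To make the block-sparse criterion introduced by this commit concrete, below is a self-contained numpy sketch (an editor's illustration, not repository code). It mirrors the scheme of the new `cal_mxn_avg_matrix` helper further down in this diff: each block is replaced by its signed mean, and the threshold comparison then zeroes or keeps whole blocks:

```python
import numpy as np

w = np.array([[0.9, 1.1, 0.1, 0.2],
              [1.0, 0.8, 0.0, 0.1],
              [0.1, 0.0, 0.7, 0.9],
              [0.2, 0.1, 1.2, 0.8]])

# replace every 2x2 block by its mean, as cal_mxn_avg_matrix does
block_avg = np.zeros_like(w)
for r in range(0, w.shape[0], 2):
    for c in range(0, w.shape[1], 2):
        block_avg[r:r + 2, c:c + 2] = np.mean(w[r:r + 2, c:c + 2])

# threshold-mode masking: whole blocks are kept or zeroed together
mask = (np.abs(block_avg) >= 0.5).astype(w.dtype)
print(w * mask)  # the two low-magnitude 2x2 blocks become all zeros
```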
......
@@ -39,6 +39,7 @@ add_arg('pruning_mode', str, 'ratio', "the pruning mod
add_arg('ratio', float, 0.55, "The ratio to set zeros, the smaller portion will be zeros. Default: 0.55")
add_arg('num_epochs', int, 120, "The number of total epochs. Default: 120")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
+parser.add_argument('--sparse_block', nargs='+', type=int, default=[1, 1], help="There must be two integers inside this array. The array defines the shape of the block, the values within which are either sparsified to all zeros or kept original. [1, 1] means unstructured pruning. Default: [1, 1]")
add_arg('data', str, "imagenet", "Which data to use. 'mnist', 'cifar10' or 'imagenet'. Default: imagenet")
add_arg('log_period', int, 100, "Log period in batches. Default: 100")
add_arg('test_period', int, 5, "Test period in epoches. Default: 5")
@@ -102,7 +103,8 @@ def create_unstructured_pruner(train_program, args, place, configs):
            threshold=args.threshold,
            prune_params_type=args.prune_params_type,
            place=place,
-            local_sparsity=args.local_sparsity)
+            local_sparsity=args.local_sparsity,
+            sparse_block=args.sparse_block)
    else:
        return GMPUnstructuredPruner(
            train_program,
@@ -110,6 +112,7 @@ def create_unstructured_pruner(train_program, args, place, configs):
            prune_params_type=args.prune_params_type,
            place=place,
            local_sparsity=args.local_sparsity,
+            sparse_block=args.sparse_block,
            configs=configs)
@@ -312,7 +315,6 @@ def compress(args):
                fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name])
            # GMP pruner step 2: step() to update ratios and other internal states of the pruner.
            pruner.step()
-            train_run_cost += time.time() - train_start
            total_samples += args.batch_size
            loss_n = np.mean(loss_n)
......
@@ -4,7 +4,7 @@
UnstructuredPruner
----------

-.. py:class:: paddleslim.UnstructuredPruner(model, mode, threshold=0.01, ratio=0.55, prune_params_type=None, skip_params_func=None, local_sparsity=False)
+.. py:class:: paddleslim.UnstructuredPruner(model, mode, threshold=0.01, ratio=0.55, prune_params_type=None, skip_params_func=None, local_sparsity=False, sparse_block=[1,1])

`Source code <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/dygraph/prune/unstructured_pruner.py>`_

@@ -17,7 +17,7 @@ UnstructuredPruner
- **ratio(float)** - The target sparsity ratio; it takes effect only when mode=='ratio'.
- **threshold(float)** - The target sparsity threshold; it takes effect only when mode=='threshold'.
- **prune_params_type(String)** - Specifies which types of parameters are sparsified. Only None and "conv1x1_only" are currently supported: the latter sparsifies 1x1 convolutions only, while the former sparsifies all parameters except those of normalization layers.
-- **skip_params_func(function)** - A pointer to a function that defines which parameters should not be pruned. The default (None) means that all normalization-layer parameters are excluded from pruning. Example code:
+- **skip_params_func(function)** - A pointer to a function that defines which parameters should not be pruned. The default (None) means that all normalization-layer parameters and bias are excluded from pruning. Example code:

.. code-block:: python

@@ -28,7 +28,7 @@ UnstructuredPruner
  def _get_skip_params(model):
      """
      This function is used to check whether the given model's layers are valid to be pruned.
-     Usually, the convolutions are to be pruned while we skip the normalization-related parameters.
+     Usually, the convolutions are to be pruned while we skip the normalization-related parameters and bias.
      Developers could replace this function by passing their own when initializing the UnstructuredPruner instance.

      Args:
@@ -40,11 +40,14 @@ UnstructuredPruner
      for _, sub_layer in model.named_sublayers():
          if type(sub_layer).__name__.split('.')[-1] in NORMS_ALL:
              skip_params.add(sub_layer.full_name())
+         for param in sub_layer.parameters(include_sublayers=False):
+             if len(param.shape) == 1: skip_params.add(param.name)
      return skip_params

..

- **local_sparsity(bool)** - The scope over which the pruning ratio applies: when local_sparsity is enabled, every pruned parameter matrix reaches sparsity 'ratio'; when disabled, only the overall model sparsity is guaranteed to reach 'ratio', and the sparsity of individual parameter matrices may differ.
+- **sparse_block(Array<Integer>)** - An array of two positive integers that defines the block size used during sparsification: the parameters inside each sparse_block[0] x sparse_block[1] block are treated as a whole and are either all set to 0 or all kept unchanged. The default [1,1] means unstructured pruning.

**Returns:** An instance of the UnstructuredPruner class.
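A minimal dygraph usage sketch (an editor's addition, not part of this diff), assuming `paddle.vision.models.mobilenet_v1` as a stand-in model; `sparse_block=[1, 4]` groups every four horizontally adjacent weights:

.. code-block:: python

    from paddle.vision.models import mobilenet_v1
    from paddleslim import UnstructuredPruner

    model = mobilenet_v1(num_classes=10, pretrained=False)
    pruner = UnstructuredPruner(
        model,
        mode='ratio',
        ratio=0.55,
        local_sparsity=True,
        sparse_block=[1, 4])
    pruner.step()           # recompute thresholds and masks
    pruner.update_params()  # write the zeros back into the parameters
    print(UnstructuredPruner.total_sparse(model))  # roughly 0.55

..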
@@ -245,6 +248,7 @@ GMPUnstructuredPruner
- **prune_params_type(str)** - Specifies which types of parameters are sparsified. Only None and "conv1x1_only" are currently supported: the latter sparsifies 1x1 convolutions only, while the former sparsifies all parameters except those of normalization layers.
- **skip_params_func(function)** - A pointer to a function that defines which parameters should not be pruned. The default (None) means that all normalization-layer parameters are excluded from pruning.
- **local_sparsity(bool)** - The scope over which the pruning ratio applies: when local_sparsity is enabled, every pruned parameter matrix reaches sparsity 'ratio'; when disabled, only the overall model sparsity is guaranteed to reach 'ratio', and the sparsity of individual parameter matrices may differ.
+- **sparse_block(Array<Integer>)** - An array of two positive integers that defines the block size used during sparsification: the parameters inside each sparse_block[0] x sparse_block[1] block are treated as a whole and are either all set to 0 or all kept unchanged. The default [1,1] means unstructured pruning.
- **configs(Dict)** - Extra training hyperparameters passed in to guide the GMP training process. Each parameter is described below:

.. code-block:: python
......
@@ -4,7 +4,7 @@
UnstructuredPruner
----------

-.. py:class:: paddleslim.prune.UnstructuredPruner(program, mode, ratio=0.55, threshold=1e-2, scope=None, place=None, prune_params_type, skip_params_func=None, local_sparsity=False)
+.. py:class:: paddleslim.prune.UnstructuredPruner(program, mode, ratio=0.55, threshold=1e-2, scope=None, place=None, prune_params_type, skip_params_func=None, local_sparsity=False, sparse_block=[1,1])

`Source code <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/prune/unstructured_pruner.py>`_

@@ -19,14 +19,14 @@ UnstructuredPruner
- **scope(paddle.static.Scope)** - A paddle.static.Scope object storing the values of all variables; the default (None) means paddle.static.global_scope.
- **place(CPUPlace|CUDAPlace)** - The device on which the model runs, of type CPUPlace or CUDAPlace; the default (None) means CPUPlace.
- **prune_params_type(String)** - Specifies which types of parameters are sparsified. Only None and "conv1x1_only" are currently supported: the latter sparsifies 1x1 convolutions only, while the former sparsifies all parameters except those of normalization layers.
-- **skip_params_func(function)** - A pointer to a function that defines which parameters should not be pruned. The default (None) means that all normalization-layer parameters are excluded from pruning.
+- **skip_params_func(function)** - A pointer to a function that defines which parameters should not be pruned. The default (None) means that all normalization-layer parameters and bias are excluded from pruning.

.. code-block:: python

  def _get_skip_params(program):
      """
      The function is used to get a set of all the skipped parameters when performing pruning.
-     By default, the normalization-related ones will not be pruned.
+     By default, the normalization-related ones and bias will not be pruned.
      Developers could replace it by passing their own function when initializing the UnstructuredPruner instance.

      Args:
      - program(paddle.static.Program): the current model.
@@ -39,11 +39,15 @@ UnstructuredPruner
          if 'norm' in op.type() and 'grad' not in op.type():
              for input in op.all_inputs():
                  skip_params.add(input.name())
+     for param in program.all_parameters():
+         if len(param.shape) == 1:
+             skip_params.add(param.name)
      return skip_params

..

- **local_sparsity(bool)** - The scope over which the pruning ratio applies: when local_sparsity is enabled, every pruned parameter matrix reaches sparsity 'ratio'; when disabled, only the overall model sparsity is guaranteed to reach 'ratio', and the sparsity of individual parameter matrices may differ.
+- **sparse_block(Array<Integer>)** - An array of two positive integers that defines the block size used during sparsification: the parameters inside each sparse_block[0] x sparse_block[1] block are treated as a whole and are either all set to 0 or all kept unchanged. The default [1,1] means unstructured pruning.

**Returns:** An instance of the UnstructuredPruner class.
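A compact static-graph usage sketch (an editor's addition, not part of this diff); the tiny fc network is only a stand-in for a real program:

.. code-block:: python

    import paddle
    from paddleslim.prune import UnstructuredPruner

    paddle.enable_static()
    main_prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        x = paddle.static.data(name='x', shape=[None, 16], dtype='float32')
        y = paddle.static.nn.fc(x, size=4)

    place = paddle.CPUPlace()
    exe = paddle.static.Executor(place)
    exe.run(startup_prog)

    # 4x1 blocks: four vertically adjacent weights are zeroed or kept together
    pruner = UnstructuredPruner(
        main_prog, 'ratio', ratio=0.5, place=place, sparse_block=[4, 1])
    pruner.step()           # compute thresholds on block-averaged magnitudes
    pruner.update_params()  # zero out the pruned blocks in place
    print(UnstructuredPruner.total_sparse(main_prog))

..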
@@ -339,6 +343,8 @@ GMPUnstructuredPruner
- **prune_params_type(String)** - Specifies which types of parameters are sparsified. Only None and "conv1x1_only" are currently supported: the latter sparsifies 1x1 convolutions only, while the former sparsifies all parameters except those of normalization layers.
- **skip_params_func(function)** - A pointer to a function that defines which parameters should not be pruned. The default (None) means that all normalization-layer parameters are excluded from pruning.
- **local_sparsity(bool)** - The scope over which the pruning ratio applies: when local_sparsity is enabled, every pruned parameter matrix reaches sparsity 'ratio'; when disabled, only the overall model sparsity is guaranteed to reach 'ratio', and the sparsity of individual parameter matrices may differ.
+- **sparse_block(Array<Integer>)** - An array of two positive integers that defines the block size used during sparsification: the parameters inside each sparse_block[0] x sparse_block[1] block are treated as a whole and are either all set to 0 or all kept unchanged. The default [1,1] means unstructured pruning.
- **configs(Dict)** - Extra training hyperparameters passed in to guide the GMP training process. Details below:

.. code-block:: python
......
@@ -2,6 +2,8 @@ import numpy as np
import paddle
import logging
from paddleslim.common import get_logger
+from paddleslim.prune.unstructured_pruner_utils import *
+import copy

__all__ = ["UnstructuredPruner", "GMPUnstructuredPruner"]
@@ -25,6 +27,7 @@ class UnstructuredPruner():
    - prune_params_type(str): The argument to control which type of ops will be pruned. Currently we only support None (all but norms) or conv1x1_only as input. It acts as a straightforward call to conv1x1 pruning. Default: None
    - skip_params_func(function): The function used to select the parameters which should be skipped when performing pruning. Default: normalization-related params.
    - local_sparsity(bool): Whether to enable local sparsity. Local sparsity means all the weight matrices have the same sparsity. And the global sparsity only ensures the whole model's sparsity is equal to the passed-in 'ratio'. Default: False
+    - sparse_block(Array<Integer>): There must be two integers inside this array. The array defines the shape of the block, the values within which are either sparsified to all zeros or kept original. [1, 1] means unstructured pruning. Default: [1,1]
    """

    def __init__(self,
@@ -34,18 +37,28 @@ class UnstructuredPruner():
                 ratio=0.55,
                 prune_params_type=None,
                 skip_params_func=None,
-                 local_sparsity=False):
+                 local_sparsity=False,
+                 sparse_block=[1, 1]):
        assert mode in ('ratio', 'threshold'
                        ), "mode must be selected from 'ratio' and 'threshold'"
        assert prune_params_type is None or prune_params_type == 'conv1x1_only', "prune_params_type only supports None or conv1x1_only for now."
        if local_sparsity:
            assert mode == 'ratio', "We don't support local_sparsity==True and mode=='threshold' at the same time, please change the inputs accordingly."
+        assert len(
+            sparse_block
+        ) == 2 and sparse_block[0] > 0 and sparse_block[1] > 0 and isinstance(
+            sparse_block[0], int
+        ) and isinstance(
+            sparse_block[1], int
+        ), "Please make sure you provide a valid sparse block, in which there are two positive integers."

        self.model = model
        self.mode = mode
        self.threshold = threshold
        self.ratio = ratio
        self.local_sparsity = local_sparsity
        self.thresholds = {}
+        self.sparse_block = sparse_block

        # Priority: passed-in skip_params_func > prune_params_type (conv1x1_only) > built-in _get_skip_params
        if skip_params_func is not None:
@@ -97,6 +110,17 @@ class UnstructuredPruner():
                continue
            t_param = param.value().get_tensor()
            v_param = np.array(t_param)
+            if (self.sparse_block[0] * self.sparse_block[1] / v_param.size
+                    >= BLOCK_SPARSE_ACCURATE_THRESHOLD):
+                print(
+                    "Your sparse block size {} might be too large for the param {} with shape {}, the sparsity of this param might not be precise. Please decrease your sparse block size if possible. Currently, sparse_block[0] ({}) X sparse_block[1] ({}) / weight_count ({}) >= {}".
+                    format(self.sparse_block, param, v_param.shape,
+                           self.sparse_block[0], self.sparse_block[1],
+                           v_param.size, BLOCK_SPARSE_ACCURATE_THRESHOLD))
+            v_param = cal_mxn_avg_matrix(
+                v_param, m=self.sparse_block[0], n=self.sparse_block[1])
            if self.local_sparsity:
                flatten_v_param = v_param.flatten()
                cur_length = flatten_v_param.size
@@ -119,11 +143,13 @@ class UnstructuredPruner():
            if param.name in self.skip_params:
                continue
            mask = self.masks.get(param.name)
+            v_param = np.array(param.value().get_tensor())
+            v_param_avg = cal_mxn_avg_matrix(
+                v_param, m=self.sparse_block[0], n=self.sparse_block[1])
            if self.local_sparsity:
-                bool_tmp = (
-                    paddle.abs(param) >= self.thresholds[param.name])
+                bool_tmp = (abs(v_param_avg) >= self.thresholds[param.name])
            else:
-                bool_tmp = (paddle.abs(param) >= self.threshold)
+                bool_tmp = (abs(v_param_avg) >= self.threshold)
            paddle.assign(bool_tmp, output=mask)

    def set_static_masks(self):
@@ -234,7 +260,7 @@ class UnstructuredPruner():
    def _get_skip_params(self, model):
        """
        This function is used to check whether the given model's layers are valid to be pruned.
-        Usually, the convolutions are to be pruned while we skip the normalization-related parameters.
+        Usually, the convolutions are to be pruned while we skip the normalization-related parameters and bias.
        Developers could replace this function by passing their own when initializing the UnstructuredPruner instance.

        Args:
@@ -246,6 +272,9 @@ class UnstructuredPruner():
        for _, sub_layer in model.named_sublayers():
            if type(sub_layer).__name__.split('.')[-1] in NORMS_ALL:
                skip_params.add(sub_layer.full_name())
+            # exclude bias whose shape is like (n,)
+            for param in sub_layer.parameters(include_sublayers=False):
+                if len(param.shape) == 1: skip_params.add(param.name)
        return skip_params

    def _get_skip_params_conv1x1(self, model):
@@ -254,6 +283,8 @@ class UnstructuredPruner():
            if type(sub_layer).__name__.split('.')[-1] in NORMS_ALL:
                skip_params.add(sub_layer.full_name())
            for param in sub_layer.parameters(include_sublayers=False):
+                # exclude bias whose shape is like (n,)
+                if len(param.shape) == 1: skip_params.add(param.name)
                cond = len(param.shape) == 4 and param.shape[
                    2] == 1 and param.shape[3] == 1
                if not cond: skip_params.add(param.name)
@@ -275,6 +306,7 @@ class GMPUnstructuredPruner(UnstructuredPruner):
    - prune_params_type(str): The argument to control which type of ops will be pruned. Currently we only support None (all but norms) or conv1x1_only as input. It acts as a straightforward call to conv1x1 pruning. Default: None
    - skip_params_func(function): The function used to select the parameters which should be skipped when performing pruning. Default: normalization-related params.
    - local_sparsity(bool): Whether to enable local sparsity. Local sparsity means all the weight matrices have the same sparsity. And the global sparsity only ensures the whole model's sparsity is equal to the passed-in 'ratio'. Default: False
+    - sparse_block(Array<Integer>): There must be two integers inside this array. The array defines a block, the values within which are either sparsified to all zeros or kept original. [1, 1] means unstructured pruning. Default: [1,1]
    - configs(Dict): The dictionary contains all the configs for GMP pruner. Default: None

    .. code-block:: python
@@ -296,12 +328,13 @@ class GMPUnstructuredPruner(UnstructuredPruner):
                 prune_params_type=None,
                 skip_params_func=None,
                 local_sparsity=False,
+                 sparse_block=[1, 1],
                 configs=None):
        assert configs is not None, "Configs must be passed in for GMP pruner."
        super(GMPUnstructuredPruner, self).__init__(
            model, 'ratio', 0.0, ratio, prune_params_type, skip_params_func,
-            local_sparsity)
+            local_sparsity, sparse_block)
        self.stable_iterations = configs.get('stable_iterations')
        self.pruning_iterations = configs.get('pruning_iterations')
        self.tunning_iterations = configs.get('tunning_iterations')
......
@@ -29,6 +29,8 @@ from .collections import *
from ..prune import collections
from .unstructured_pruner import *
from ..prune import unstructured_pruner
+from .unstructured_pruner_utils import *
+from ..prune import unstructured_pruner_utils
from .idx_selector import *
from ..prune import idx_selector

__all__ = []
@@ -40,5 +42,6 @@ __all__ += prune_worker.__all__
__all__ += prune_io.__all__
__all__ += criterion.__all__
__all__ += unstructured_pruner.__all__
+__all__ += unstructured_pruner_utils.__all__
__all__ += idx_selector.__all__
__all__ += collections.__all__
import numpy as np
from ..common import get_logger
from ..core import GraphWrapper
+from paddleslim.prune.unstructured_pruner_utils import *
import paddle
import copy
@@ -19,8 +20,9 @@ class UnstructuredPruner():
    - scope(paddle.static.Scope): The scope storing values of all variables. None means paddle.static.global_scope. Default: None.
    - place(CPUPlace | CUDAPlace): The device place used to execute model. None means CPUPlace. Default: None.
    - prune_params_type(str): The argument to control which type of ops will be pruned. Currently we only support None (all but norms) or conv1x1_only as input. It acts as a straightforward call to conv1x1 pruning. Default: None
-    - skip_params_func(function): The function used to select the parameters which should be skipped when performing pruning. Default: normalization-related params. Default: None
+    - skip_params_func(function): The function used to select the parameters which should be skipped when performing pruning. Default: normalization-related params and bias. Default: None
    - local_sparsity(bool): Whether to enable local sparsity. Local sparsity means all the weight matrices have the same sparsity. And the global sparsity only ensures the whole model's sparsity is equal to the passed-in 'ratio'. Default: False
+    - sparse_block(Array<Integer>): There must be two integers inside this array. The array defines the shape of the block, the values within which are either sparsified to all zeros or kept original. [1, 1] means unstructured pruning. Default: [1,1]
    """

    def __init__(self,
@@ -32,18 +34,28 @@ class UnstructuredPruner():
                 place=None,
                 prune_params_type=None,
                 skip_params_func=None,
-                 local_sparsity=False):
+                 local_sparsity=False,
+                 sparse_block=[1, 1]):
        self.mode = mode
        self.ratio = ratio
        self.threshold = threshold
        self.local_sparsity = local_sparsity
        self.thresholds = {}
+        self.sparse_block = sparse_block
        assert self.mode in [
            'ratio', 'threshold'
        ], "mode must be selected from 'ratio' and 'threshold'"
        assert prune_params_type is None or prune_params_type == 'conv1x1_only', "prune_params_type only supports None or conv1x1_only for now."
        if self.local_sparsity:
            assert self.mode == 'ratio', "We don't support local_sparsity==True and mode=='threshold' at the same time, please change the inputs accordingly."
+        assert len(
+            sparse_block
+        ) == 2 and sparse_block[0] > 0 and sparse_block[1] > 0 and isinstance(
+            sparse_block[0], int
+        ) and isinstance(
+            sparse_block[1], int
+        ), "Please make sure you provide a valid sparse block, in which there are two positive integers."
        self.scope = paddle.static.global_scope() if scope == None else scope
        self.place = paddle.static.cpu_places()[0] if place is None else place
@@ -161,6 +173,15 @@ class UnstructuredPruner():
                continue
            t_param = self.scope.find_var(param).get_tensor()
            v_param = np.array(t_param)
+            if (self.sparse_block[0] * self.sparse_block[1] / v_param.size >=
+                    BLOCK_SPARSE_ACCURATE_THRESHOLD):
+                print(
+                    "Your sparse block size {} might be too large for the param {} with shape {}, the sparsity of this param might not be precise. Please decrease your sparse block size if possible. Currently, sparse_block[0] ({}) X sparse_block[1] ({}) / weight_count ({}) >= {}".
+                    format(self.sparse_block, param, v_param.shape,
+                           self.sparse_block[0], self.sparse_block[1],
+                           v_param.size, BLOCK_SPARSE_ACCURATE_THRESHOLD))
+            v_param = cal_mxn_avg_matrix(
+                v_param, m=self.sparse_block[0], n=self.sparse_block[1])
            if self.local_sparsity:
                cur_threshold = self._partition_sort(v_param.flatten())
                self.thresholds[param] = cur_threshold
@@ -188,10 +209,12 @@ class UnstructuredPruner():
            t_param = self.scope.find_var(param).get_tensor()
            t_mask = self.scope.find_var(mask_name).get_tensor()
            v_param = np.array(t_param)
+            v_param_avg = cal_mxn_avg_matrix(
+                v_param, m=self.sparse_block[0], n=self.sparse_block[1])
            if self.local_sparsity:
-                v_param[np.abs(v_param) < self.thresholds[param]] = 0
+                v_param[np.abs(v_param_avg) < self.thresholds[param]] = 0
            else:
-                v_param[np.abs(v_param) < self.threshold] = 0
+                v_param[np.abs(v_param_avg) < self.threshold] = 0
            v_mask = (v_param != 0).astype(v_param.dtype)
            t_mask.set(v_mask, self.place)
@@ -265,6 +288,10 @@ class UnstructuredPruner():
            if 'norm' in op.type() and 'grad' not in op.type():
                for input in op.all_inputs():
                    skip_params.add(input.name())
+        # exclude bias whose shape is like (n,)
+        for param in program.all_parameters():
+            if len(param.shape) == 1:
+                skip_params.add(param.name)
        return skip_params

    def _get_skip_params_conv1x1(self, program):
@@ -275,6 +302,9 @@ class UnstructuredPruner():
                for input in op.all_inputs():
                    skip_params.add(input.name())
        for param in program.all_parameters():
+            # exclude bias whose shape is like (n,)
+            if len(param.shape) == 1:
+                skip_params.add(param.name)
            if not (len(param.shape) == 4 and param.shape[2] == 1 and
                    param.shape[3] == 1):
                skip_params.add(param.name)
@@ -322,6 +352,7 @@ class GMPUnstructuredPruner(UnstructuredPruner):
    - prune_params_type(str): The argument to control which type of ops will be pruned. Currently we only support None (all but norms) or conv1x1_only as input. It acts as a straightforward call to conv1x1 pruning. Default: None
    - skip_params_func(function): The function used to select the parameters which should be skipped when performing pruning. Default: normalization-related params. Default: None
    - local_sparsity(bool): Whether to enable local sparsity. Local sparsity means all the weight matrices have the same sparsity. And the global sparsity only ensures the whole model's sparsity is equal to the passed-in 'ratio'. Default: False
+    - sparse_block(Array<Integer>): There must be two integers inside this array. The array defines the shape of the block, the values within which are either sparsified to all zeros or kept original. [1, 1] means unstructured pruning. Default: [1,1]
    - configs(Dict): The dictionary contains all the configs for GMP pruner. Default: None. The detailed description is as below:

    .. code-block:: python
@@ -344,12 +375,13 @@ class GMPUnstructuredPruner(UnstructuredPruner):
                 prune_params_type=None,
                 skip_params_func=None,
                 local_sparsity=False,
+                 sparse_block=[1, 1],
                 configs=None):
        assert configs is not None, "Please pass in a valid config dictionary."
        super(GMPUnstructuredPruner, self).__init__(
            program, 'ratio', ratio, 0.0, scope, place, prune_params_type,
-            skip_params_func, local_sparsity)
+            skip_params_func, local_sparsity, sparse_block)
        self.stable_iterations = configs.get('stable_iterations')
        self.pruning_iterations = configs.get('pruning_iterations')
        self.tunning_iterations = configs.get('tunning_iterations')
......
+import numpy as np
+import copy
+
+__all__ = ["BLOCK_SPARSE_ACCURATE_THRESHOLD", "cal_mxn_avg_matrix"]
+
+BLOCK_SPARSE_ACCURATE_THRESHOLD = 0.05
+
+
+def cal_mxn_avg_matrix(mat, m=1, n=1):
+    if m == 1 and n == 1: return copy.deepcopy(mat)
+    avg_mat = np.zeros_like(mat)
+    rows = len(mat) // m + 1
+    cols = len(mat[0]) // n + 1
+    for row in range(rows):
+        for col in range(cols):
+            avg_mat[m * row:m * row + m, n * col:n * col + n] = np.mean(mat[
+                m * row:m * row + m, n * col:n * col + n])
+    return avg_mat
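For orientation, a quick check of what the new helper computes (an editor's illustration, not part of the commit): every m x n tile is replaced by its mean, so the pruner's magnitude comparison acts on whole tiles. One observable quirk: because the loop bounds are `len(mat) // m + 1`, a dimension that is an exact multiple of the block size makes the last iteration read an empty slice, which triggers a harmless numpy "Mean of empty slice" warning while the empty assignment is a no-op.

```python
import numpy as np
from paddleslim.prune.unstructured_pruner_utils import cal_mxn_avg_matrix

mat = np.array([[1., 2., 3.],
                [4., 5., 6.],
                [7., 8., 9.]])
# 2x2 tiles; the 3x3 shape is deliberately not a multiple of the tile size
print(cal_mxn_avg_matrix(mat, m=2, n=2))
# [[3.  3.  4.5]
#  [3.  3.  4.5]
#  [7.5 7.5 9. ]]
```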
@@ -16,6 +16,7 @@ class TestUnstructuredPruner(unittest.TestCase):
    def _gen_model(self):
        self.net = mobilenet_v1(num_classes=10, pretrained=False)
        self.net_conv1x1 = mobilenet_v1(num_classes=10, pretrained=False)
+        self.net_mxn = mobilenet_v1(num_classes=10, pretrained=False)
        self.pruner = UnstructuredPruner(
            self.net, mode='ratio', ratio=0.55, local_sparsity=True)
        self.pruner_conv1x1 = UnstructuredPruner(
@@ -24,6 +25,12 @@ class TestUnstructuredPruner(unittest.TestCase):
            ratio=0.55,
            prune_params_type='conv1x1_only',
            local_sparsity=False)
+        self.pruner_mxn = UnstructuredPruner(
+            self.net_mxn,
+            mode='ratio',
+            ratio=0.55,
+            local_sparsity=True,
+            sparse_block=[2, 1])

    def test_prune(self):
        ori_sparsity = UnstructuredPruner.total_sparse(self.net)
@@ -69,6 +76,12 @@ class TestUnstructuredPruner(unittest.TestCase):
        cur_sparsity = UnstructuredPruner.total_sparse_conv1x1(self.net_conv1x1)
        self.assertTrue(abs(cur_sparsity - 0.55) < 0.01)

+    def test_block_prune_mxn(self):
+        self.pruner_mxn.step()
+        self.pruner_mxn.update_params()
+        cur_sparsity = UnstructuredPruner.total_sparse(self.net_mxn)
+        self.assertTrue(abs(cur_sparsity - 0.55) < 0.01)

if __name__ == "__main__":
    unittest.main()
@@ -64,7 +64,7 @@ class TestStaticMasks(unittest.TestCase):
        sparsity_2 = UnstructuredPruner.total_sparse(net)
        print(sparsity_0, sparsity_1, sparsity_2)
        self.assertEqual(sparsity_0, 1.0)
-        self.assertEqual(sparsity_2, 1.0)
+        self.assertLess(abs(sparsity_2 - 1), 0.001)
        self.assertLess(sparsity_1, 1.0)
......
@@ -54,6 +54,14 @@ class TestUnstructuredPruner(StaticCase):
            place=place,
            prune_params_type='conv1x1_only',
            local_sparsity=False)
+        self.pruner_mxn = UnstructuredPruner(
+            self.main_program,
+            'ratio',
+            scope=self.scope,
+            place=place,
+            sparse_block=[5, 5],
+            prune_params_type='conv1x1_only',
+            local_sparsity=True)

    def test_unstructured_prune(self):
        for param in self.main_program.global_block().all_parameters():
@@ -102,6 +110,18 @@ class TestUnstructuredPruner(StaticCase):
        self.assertTrue(
            self.pruner.skip_params < self.pruner_conv1x1.skip_params)

+    def test_block_pruner_mxn(self):
+        ori_sparsity = UnstructuredPruner.total_sparse_conv1x1(
+            self.main_program)
+        self.pruner_mxn.ratio = 0.50
+        self.pruner_mxn.step()
+        self.pruner_mxn.update_params()
+        cur_sparsity = UnstructuredPruner.total_sparse_conv1x1(
+            self.main_program)
+        print('original sparsity: {}.'.format(ori_sparsity))
+        print('current sparsity: {}.'.format(cur_sparsity))
+        self.assertGreater(cur_sparsity, ori_sparsity)

    def test_sparsity_conv1x1(self):
        ori_sparsity = UnstructuredPruner.total_sparse_conv1x1(
            self.main_program)
......
@@ -59,7 +59,7 @@ class TestStaticMasks(StaticCase):
        sparsity_2 = UnstructuredPruner.total_sparse(main_program)
        print(sparsity_0, sparsity_1, sparsity_2)
        self.assertEqual(sparsity_0, 1.0)
-        self.assertEqual(sparsity_2, 1.0)
+        self.assertLess(abs(sparsity_2 - 1), 0.001)
        self.assertLess(sparsity_1, 1.0)
......