diff --git a/demo/dygraph/unstructured_pruning/README.md b/demo/dygraph/unstructured_pruning/README.md index 9af494666060e2ab42834b8103a4b77dc34e34b5..4e7375617936c01a8da68c629536ff5f940cdad7 100644 --- a/demo/dygraph/unstructured_pruning/README.md +++ b/demo/dygraph/unstructured_pruning/README.md @@ -36,6 +36,10 @@ paddleslim>=2.1.0 - 开发可以在初始化`UnstructuredPruner`时,传入自定义的`skip_params_func`,来定义哪些参数不参与剪裁。`skip_params_func`示例代码如下(路径:`paddleslim.dygraph.prune.unstructured_pruner._get_skip_params())`。默认为所有的归一化层的参数不参与剪裁。 ```python +NORMS_ALL = [ 'BatchNorm', 'GroupNorm', 'LayerNorm', 'SpectralNorm', 'BatchNorm1D', + 'BatchNorm2D', 'BatchNorm3D', 'InstanceNorm1D', 'InstanceNorm2D', + 'InstanceNorm3D', 'SyncBatchNorm', 'LocalResponseNorm' ] + def _get_skip_params(model): """ This function is used to check whether the given model's layers are valid to be pruned. @@ -49,7 +53,7 @@ def _get_skip_params(model): """ skip_params = set() for _, sub_layer in model.named_sublayers(): - if type(sub_layer).__name__.split('.')[-1] in paddle.nn.norm.__all__: + if type(sub_layer).__name__.split('.')[-1] in NORMS_ALL: skip_params.add(sub_layer.full_name()) return skip_params ``` diff --git a/paddleslim/dygraph/prune/unstructured_pruner.py b/paddleslim/dygraph/prune/unstructured_pruner.py index c5234c792a237698a657f8e7656e77940f817c74..3583ad57d2ed733d7512691fb631fe0d84dd30a4 100644 --- a/paddleslim/dygraph/prune/unstructured_pruner.py +++ b/paddleslim/dygraph/prune/unstructured_pruner.py @@ -7,6 +7,12 @@ __all__ = ["UnstructuredPruner"] _logger = get_logger(__name__, level=logging.INFO) +NORMS_ALL = [ + 'BatchNorm', 'GroupNorm', 'LayerNorm', 'SpectralNorm', 'BatchNorm1D', + 'BatchNorm2D', 'BatchNorm3D', 'InstanceNorm1D', 'InstanceNorm2D', + 'InstanceNorm3D', 'SyncBatchNorm', 'LocalResponseNorm' +] + class UnstructuredPruner(): """ @@ -164,8 +170,7 @@ class UnstructuredPruner(): """ skip_params = set() for _, sub_layer in model.named_sublayers(): - if type(sub_layer).__name__.split('.')[ - -1] in paddle.nn.norm.__all__: + if type(sub_layer).__name__.split('.')[-1] in NORMS_ALL: skip_params.add(sub_layer.full_name()) return skip_params