From 4f99291e30e47da5ba1434831ae22dc52b613e62 Mon Sep 17 00:00:00 2001 From: minghaoBD <79566150+minghaoBD@users.noreply.github.com> Date: Thu, 29 Apr 2021 09:42:58 +0800 Subject: [PATCH] fix no module named paddle.nn.norm error after paddle API update, test=develop (#729) --- demo/dygraph/unstructured_pruning/README.md | 6 +++++- paddleslim/dygraph/prune/unstructured_pruner.py | 9 +++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/demo/dygraph/unstructured_pruning/README.md b/demo/dygraph/unstructured_pruning/README.md index 9af49466..4e737561 100644 --- a/demo/dygraph/unstructured_pruning/README.md +++ b/demo/dygraph/unstructured_pruning/README.md @@ -36,6 +36,10 @@ paddleslim>=2.1.0 - 开发可以在初始化`UnstructuredPruner`时,传入自定义的`skip_params_func`,来定义哪些参数不参与剪裁。`skip_params_func`示例代码如下(路径:`paddleslim.dygraph.prune.unstructured_pruner._get_skip_params())`。默认为所有的归一化层的参数不参与剪裁。 ```python +NORMS_ALL = [ 'BatchNorm', 'GroupNorm', 'LayerNorm', 'SpectralNorm', 'BatchNorm1D', + 'BatchNorm2D', 'BatchNorm3D', 'InstanceNorm1D', 'InstanceNorm2D', + 'InstanceNorm3D', 'SyncBatchNorm', 'LocalResponseNorm' ] + def _get_skip_params(model): """ This function is used to check whether the given model's layers are valid to be pruned. @@ -49,7 +53,7 @@ def _get_skip_params(model): """ skip_params = set() for _, sub_layer in model.named_sublayers(): - if type(sub_layer).__name__.split('.')[-1] in paddle.nn.norm.__all__: + if type(sub_layer).__name__.split('.')[-1] in NORMS_ALL: skip_params.add(sub_layer.full_name()) return skip_params ``` diff --git a/paddleslim/dygraph/prune/unstructured_pruner.py b/paddleslim/dygraph/prune/unstructured_pruner.py index c5234c79..3583ad57 100644 --- a/paddleslim/dygraph/prune/unstructured_pruner.py +++ b/paddleslim/dygraph/prune/unstructured_pruner.py @@ -7,6 +7,12 @@ __all__ = ["UnstructuredPruner"] _logger = get_logger(__name__, level=logging.INFO) +NORMS_ALL = [ + 'BatchNorm', 'GroupNorm', 'LayerNorm', 'SpectralNorm', 'BatchNorm1D', + 'BatchNorm2D', 'BatchNorm3D', 'InstanceNorm1D', 'InstanceNorm2D', + 'InstanceNorm3D', 'SyncBatchNorm', 'LocalResponseNorm' +] + class UnstructuredPruner(): """ @@ -164,8 +170,7 @@ class UnstructuredPruner(): """ skip_params = set() for _, sub_layer in model.named_sublayers(): - if type(sub_layer).__name__.split('.')[ - -1] in paddle.nn.norm.__all__: + if type(sub_layer).__name__.split('.')[-1] in NORMS_ALL: skip_params.add(sub_layer.full_name()) return skip_params -- GitLab