未验证 提交 4f99291e 编写于 作者: M minghaoBD 提交者: GitHub

fix no module named paddle.nn.norm error after paddle API update, test=develop (#729)

上级 8cc4e44c
...@@ -36,6 +36,10 @@ paddleslim>=2.1.0 ...@@ -36,6 +36,10 @@ paddleslim>=2.1.0
- 开发可以在初始化`UnstructuredPruner`时,传入自定义的`skip_params_func`,来定义哪些参数不参与剪裁。`skip_params_func`示例代码如下(路径:`paddleslim.dygraph.prune.unstructured_pruner._get_skip_params())`。默认为所有的归一化层的参数不参与剪裁。 - 开发可以在初始化`UnstructuredPruner`时,传入自定义的`skip_params_func`,来定义哪些参数不参与剪裁。`skip_params_func`示例代码如下(路径:`paddleslim.dygraph.prune.unstructured_pruner._get_skip_params())`。默认为所有的归一化层的参数不参与剪裁。
```python ```python
NORMS_ALL = [ 'BatchNorm', 'GroupNorm', 'LayerNorm', 'SpectralNorm', 'BatchNorm1D',
'BatchNorm2D', 'BatchNorm3D', 'InstanceNorm1D', 'InstanceNorm2D',
'InstanceNorm3D', 'SyncBatchNorm', 'LocalResponseNorm' ]
def _get_skip_params(model): def _get_skip_params(model):
""" """
This function is used to check whether the given model's layers are valid to be pruned. This function is used to check whether the given model's layers are valid to be pruned.
...@@ -49,7 +53,7 @@ def _get_skip_params(model): ...@@ -49,7 +53,7 @@ def _get_skip_params(model):
""" """
skip_params = set() skip_params = set()
for _, sub_layer in model.named_sublayers(): for _, sub_layer in model.named_sublayers():
if type(sub_layer).__name__.split('.')[-1] in paddle.nn.norm.__all__: if type(sub_layer).__name__.split('.')[-1] in NORMS_ALL:
skip_params.add(sub_layer.full_name()) skip_params.add(sub_layer.full_name())
return skip_params return skip_params
``` ```
......
...@@ -7,6 +7,12 @@ __all__ = ["UnstructuredPruner"] ...@@ -7,6 +7,12 @@ __all__ = ["UnstructuredPruner"]
_logger = get_logger(__name__, level=logging.INFO) _logger = get_logger(__name__, level=logging.INFO)
NORMS_ALL = [
'BatchNorm', 'GroupNorm', 'LayerNorm', 'SpectralNorm', 'BatchNorm1D',
'BatchNorm2D', 'BatchNorm3D', 'InstanceNorm1D', 'InstanceNorm2D',
'InstanceNorm3D', 'SyncBatchNorm', 'LocalResponseNorm'
]
class UnstructuredPruner(): class UnstructuredPruner():
""" """
...@@ -164,8 +170,7 @@ class UnstructuredPruner(): ...@@ -164,8 +170,7 @@ class UnstructuredPruner():
""" """
skip_params = set() skip_params = set()
for _, sub_layer in model.named_sublayers(): for _, sub_layer in model.named_sublayers():
if type(sub_layer).__name__.split('.')[ if type(sub_layer).__name__.split('.')[-1] in NORMS_ALL:
-1] in paddle.nn.norm.__all__:
skip_params.add(sub_layer.full_name()) skip_params.add(sub_layer.full_name())
return skip_params return skip_params
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册