未验证 提交 4f99291e 编写于 作者: M minghaoBD 提交者: GitHub

fix no module named paddle.nn.norm error after paddle API update, test=develop (#729)

上级 8cc4e44c
......@@ -36,6 +36,10 @@ paddleslim>=2.1.0
- 开发可以在初始化`UnstructuredPruner`时,传入自定义的`skip_params_func`,来定义哪些参数不参与剪裁。`skip_params_func`示例代码如下(路径:`paddleslim.dygraph.prune.unstructured_pruner._get_skip_params())`。默认为所有的归一化层的参数不参与剪裁。
```python
NORMS_ALL = [ 'BatchNorm', 'GroupNorm', 'LayerNorm', 'SpectralNorm', 'BatchNorm1D',
'BatchNorm2D', 'BatchNorm3D', 'InstanceNorm1D', 'InstanceNorm2D',
'InstanceNorm3D', 'SyncBatchNorm', 'LocalResponseNorm' ]
def _get_skip_params(model):
"""
This function is used to check whether the given model's layers are valid to be pruned.
......@@ -49,7 +53,7 @@ def _get_skip_params(model):
"""
skip_params = set()
for _, sub_layer in model.named_sublayers():
if type(sub_layer).__name__.split('.')[-1] in paddle.nn.norm.__all__:
if type(sub_layer).__name__.split('.')[-1] in NORMS_ALL:
skip_params.add(sub_layer.full_name())
return skip_params
```
......
......@@ -7,6 +7,12 @@ __all__ = ["UnstructuredPruner"]
_logger = get_logger(__name__, level=logging.INFO)
NORMS_ALL = [
'BatchNorm', 'GroupNorm', 'LayerNorm', 'SpectralNorm', 'BatchNorm1D',
'BatchNorm2D', 'BatchNorm3D', 'InstanceNorm1D', 'InstanceNorm2D',
'InstanceNorm3D', 'SyncBatchNorm', 'LocalResponseNorm'
]
class UnstructuredPruner():
"""
......@@ -164,8 +170,7 @@ class UnstructuredPruner():
"""
skip_params = set()
for _, sub_layer in model.named_sublayers():
if type(sub_layer).__name__.split('.')[
-1] in paddle.nn.norm.__all__:
if type(sub_layer).__name__.split('.')[-1] in NORMS_ALL:
skip_params.add(sub_layer.full_name())
return skip_params
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册