未验证 提交 7c96fd43 编写于 作者: M minghaoBD 提交者: GitHub

[cherry-pick][Unstructured_prune]Fix a bug in skip params function (#931) (#955)

* fix a bug in skip params func
上级 0bec56fe
......@@ -87,6 +87,8 @@ class UnstructuredPruner():
if not self._should_prune_layer(sub_layer):
continue
for param in sub_layer.parameters(include_sublayers=False):
if param.name in self.skip_params:
continue
t_param = param.value().get_tensor()
v_param = np.array(t_param)
params_flatten.append(v_param.flatten())
......@@ -97,8 +99,11 @@ class UnstructuredPruner():
def _update_masks(self):
for name, sub_layer in self.model.named_sublayers():
if not self._should_prune_layer(sub_layer): continue
if not self._should_prune_layer(sub_layer):
continue
for param in sub_layer.parameters(include_sublayers=False):
if param.name in self.skip_params:
continue
mask = self.masks.get(param.name)
bool_tmp = (paddle.abs(param) >= self.threshold)
paddle.assign(bool_tmp, output=mask)
......@@ -225,7 +230,7 @@ class UnstructuredPruner():
for param in sub_layer.parameters(include_sublayers=False):
cond = len(param.shape) == 4 and param.shape[
2] == 1 and param.shape[3] == 1
if not cond: skip_params.add(sub_layer.full_name())
if not cond: skip_params.add(param.name)
return skip_params
def _should_prune_layer(self, layer):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册