未验证 提交 203ae656 编写于 作者: M minghaoBD 提交者: GitHub

[Unstructured_prune]Fix a bug in skip params function (#931)

* fix a bug in skip params func
上级 4623dc4d
...@@ -93,6 +93,8 @@ class UnstructuredPruner(): ...@@ -93,6 +93,8 @@ class UnstructuredPruner():
if not self._should_prune_layer(sub_layer): if not self._should_prune_layer(sub_layer):
continue continue
for param in sub_layer.parameters(include_sublayers=False): for param in sub_layer.parameters(include_sublayers=False):
if param.name in self.skip_params:
continue
t_param = param.value().get_tensor() t_param = param.value().get_tensor()
v_param = np.array(t_param) v_param = np.array(t_param)
if self.local_sparsity: if self.local_sparsity:
...@@ -111,8 +113,11 @@ class UnstructuredPruner(): ...@@ -111,8 +113,11 @@ class UnstructuredPruner():
def _update_masks(self): def _update_masks(self):
for name, sub_layer in self.model.named_sublayers(): for name, sub_layer in self.model.named_sublayers():
if not self._should_prune_layer(sub_layer): continue if not self._should_prune_layer(sub_layer):
continue
for param in sub_layer.parameters(include_sublayers=False): for param in sub_layer.parameters(include_sublayers=False):
if param.name in self.skip_params:
continue
mask = self.masks.get(param.name) mask = self.masks.get(param.name)
if self.local_sparsity: if self.local_sparsity:
bool_tmp = ( bool_tmp = (
...@@ -243,7 +248,7 @@ class UnstructuredPruner(): ...@@ -243,7 +248,7 @@ class UnstructuredPruner():
for param in sub_layer.parameters(include_sublayers=False): for param in sub_layer.parameters(include_sublayers=False):
cond = len(param.shape) == 4 and param.shape[ cond = len(param.shape) == 4 and param.shape[
2] == 1 and param.shape[3] == 1 2] == 1 and param.shape[3] == 1
if not cond: skip_params.add(sub_layer.full_name()) if not cond: skip_params.add(param.name)
return skip_params return skip_params
def _should_prune_layer(self, layer): def _should_prune_layer(self, layer):
......
...@@ -254,7 +254,6 @@ class UnstructuredPruner(): ...@@ -254,7 +254,6 @@ class UnstructuredPruner():
if 'norm' in op.type() and 'grad' not in op.type(): if 'norm' in op.type() and 'grad' not in op.type():
for input in op.all_inputs(): for input in op.all_inputs():
skip_params.add(input.name()) skip_params.add(input.name())
print(skip_params)
return skip_params return skip_params
def _get_skip_params_conv1x1(self, program): def _get_skip_params_conv1x1(self, program):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册