未验证 提交 8b156124 编写于 作者: Z zhouzj 提交者: GitHub

Fix the bug of dygraph pruner. (#1320)

上级 de25946d
...@@ -220,7 +220,7 @@ class PruningPlan(): ...@@ -220,7 +220,7 @@ class PruningPlan():
t_value = param.value().get_tensor() t_value = param.value().get_tensor()
value = np.array(t_value).astype("float32") value = np.array(t_value).astype("float32")
groups = _mask._op.attr('groups') groups = _mask._op.attr('groups')
if dims == 1 and groups is not None and groups > 1 and len( if groups is not None and groups > 1 and len(
value.shape) == 4: value.shape) == 4:
filter_size = value.shape[1] filter_size = value.shape[1]
except_num = np.sum(bool_mask) except_num = np.sum(bool_mask)
...@@ -230,7 +230,6 @@ class PruningPlan(): ...@@ -230,7 +230,6 @@ class PruningPlan():
sub_layer._groups = new_groups sub_layer._groups = new_groups
_logger.info("change groups from {} to {} for {}.". _logger.info("change groups from {} to {} for {}.".
format(groups, new_groups, param.name)) format(groups, new_groups, param.name))
continue
# The name of buffer can not contains "." # The name of buffer can not contains "."
backup_name = param.name.replace(".", "_") + "_backup" backup_name = param.name.replace(".", "_") + "_backup"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册