未验证 提交 8b156124 编写于 作者: Z zhouzj 提交者: GitHub

Fix the bug of dygraph pruner. (#1320)

上级 de25946d
......@@ -220,7 +220,7 @@ class PruningPlan():
t_value = param.value().get_tensor()
value = np.array(t_value).astype("float32")
groups = _mask._op.attr('groups')
if dims == 1 and groups is not None and groups > 1 and len(
if groups is not None and groups > 1 and len(
value.shape) == 4:
filter_size = value.shape[1]
except_num = np.sum(bool_mask)
......@@ -230,7 +230,6 @@ class PruningPlan():
sub_layer._groups = new_groups
_logger.info("change groups from {} to {} for {}.".
format(groups, new_groups, param.name))
continue
# The name of buffer can not contains "."
backup_name = param.name.replace(".", "_") + "_backup"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册