From 8b156124638caf851e728f49cf46afb9ebf7ed4e Mon Sep 17 00:00:00 2001 From: zhouzj <41366441+zzjjay@users.noreply.github.com> Date: Tue, 2 Aug 2022 10:14:03 +0800 Subject: [PATCH] Fix the bug of dygraph pruner. (#1320) --- paddleslim/dygraph/prune/pruning_plan.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/paddleslim/dygraph/prune/pruning_plan.py b/paddleslim/dygraph/prune/pruning_plan.py index d9cd8e4a..cd669ffd 100644 --- a/paddleslim/dygraph/prune/pruning_plan.py +++ b/paddleslim/dygraph/prune/pruning_plan.py @@ -220,7 +220,7 @@ class PruningPlan(): t_value = param.value().get_tensor() value = np.array(t_value).astype("float32") groups = _mask._op.attr('groups') - if dims == 1 and groups is not None and groups > 1 and len( + if groups is not None and groups > 1 and len( value.shape) == 4: filter_size = value.shape[1] except_num = np.sum(bool_mask) @@ -230,7 +230,6 @@ class PruningPlan(): sub_layer._groups = new_groups _logger.info("change groups from {} to {} for {}.". format(groups, new_groups, param.name)) - continue # The name of buffer can not contains "." backup_name = param.name.replace(".", "_") + "_backup" -- GitLab