From 7e3e14ae96bd45efb5ffa1676f6d01f7e569f30c Mon Sep 17 00:00:00 2001
From: zhouzj <41366441+zzjjay@users.noreply.github.com>
Date: Fri, 29 Jul 2022 10:02:04 +0800
Subject: [PATCH] Fix the bug of pruning dw_conv. (#1311)

---
 paddleslim/prune/prune_worker.py | 5 +----
 paddleslim/prune/pruner.py       | 4 ++--
 tests/test_prune_walker.py      | 6 +++---
 3 files changed, 6 insertions(+), 9 deletions(-)

diff --git a/paddleslim/prune/prune_worker.py b/paddleslim/prune/prune_worker.py
index 25703c66..c9c69f62 100644
--- a/paddleslim/prune/prune_worker.py
+++ b/paddleslim/prune/prune_worker.py
@@ -522,7 +522,6 @@ class depthwise_conv2d(PruneWorker):
         channel_axis = 1
         if data_format == "NHWC":
             channel_axis = 3
-
         if var == _in_var:
             assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}".format(
                 pruned_axis)
@@ -533,7 +532,6 @@ class depthwise_conv2d(PruneWorker):
                 "repeat": repeat
             }])
             # kernel_number * groups will be pruned by reducing groups
-            self.append_pruned_vars(_filter, 1, transforms)
             self._visit_and_search(_filter, 0, transforms + [{
                 "repeat": repeat
             }])
@@ -546,14 +544,13 @@ class depthwise_conv2d(PruneWorker):
             }])
         elif var == _filter:
             assert pruned_axis == 0, "The filter of depthwise conv2d can only be pruned at axis 0."
-            self.append_pruned_vars(_filter, 1, transforms)
+            self.append_pruned_vars(_filter, 0, transforms)
             self._visit_and_search(_in_var, channel_axis, transforms)
             self._visit_and_search(_out, channel_axis, transforms)
         elif var == _out:
             assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}".format(
                 pruned_axis)
             self.append_pruned_vars(_filter, 0, transforms)
-            self.append_pruned_vars(_filter, 1, transforms)
             self._visit_and_search(_filter, 0, transforms)
             # It will not pruning number of kernels in depthwise conv2d,
             # so it is not neccesary to search succeed operators.
diff --git a/paddleslim/prune/pruner.py b/paddleslim/prune/pruner.py
index 1926c8cb..4c58c2e1 100644
--- a/paddleslim/prune/pruner.py
+++ b/paddleslim/prune/pruner.py
@@ -117,7 +117,7 @@ class Pruner():
             _groups = 1
             if not lazy:
                 # update groups of conv2d
-                if pruned_axis == 1:
+                if pruned_axis == 1 or pruned_axis == 0:
                     for op in param.outputs():
                         if op.type() in [
                                 "conv2d", "conv2d_grad", "depthwise_conv2d",
@@ -132,7 +132,7 @@ class Pruner():
                             f"change groups of {op.type()}({param.name()}) from {op.attr('groups')} to {new_groups};"
                         )
                         op.set_attr("groups", new_groups)
-            if _groups == 1:
+
                 origin_shape = copy.deepcopy(param.shape())
                 if param_shape_backup is not None:
                     param_shape_backup[param.name()] = origin_shape
diff --git a/tests/test_prune_walker.py b/tests/test_prune_walker.py
index e0375c8b..f2f2d2fc 100644
--- a/tests/test_prune_walker.py
+++ b/tests/test_prune_walker.py
@@ -473,9 +473,9 @@ class TestDepthwiseConv2d(TestPruneWorker):
 
     def set_cases(self):
         weight_var = self.graph.var('conv1.w_0')
-        self.cases.append((self.in_var, 1, {'conv1.w_0': [0, 1]}))
-        self.cases.append((self.out_var, 1, {'conv1.w_0': [0, 1]}))
-        self.cases.append((weight_var, 0, {'conv1.w_0': [1]}))
+        self.cases.append((self.in_var, 1, {'conv1.w_0': [0]}))
+        self.cases.append((self.out_var, 1, {'conv1.w_0': [0]}))
+        self.cases.append((weight_var, 0, {'conv1.w_0': [0]}))
 
     def test_prune(self):
         self.check_in_out()
--
GitLab