diff --git a/paddleslim/prune/prune_worker.py b/paddleslim/prune/prune_worker.py
index 25703c66db44af5363df511fd59080ac922cc317..c9c69f622e93b04804c133b8c1fa5612e186f5f6 100644
--- a/paddleslim/prune/prune_worker.py
+++ b/paddleslim/prune/prune_worker.py
@@ -522,7 +522,6 @@ class depthwise_conv2d(PruneWorker):
         channel_axis = 1
         if data_format == "NHWC":
             channel_axis = 3
-
         if var == _in_var:
             assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}".format(
                 pruned_axis)
@@ -533,7 +532,6 @@ class depthwise_conv2d(PruneWorker):
                 "repeat": repeat
             }])
             # kernel_number * groups will be pruned by reducing groups
-            self.append_pruned_vars(_filter, 1, transforms)
             self._visit_and_search(_filter, 0, transforms + [{
                 "repeat": repeat
             }])
@@ -546,14 +544,13 @@ class depthwise_conv2d(PruneWorker):
             }])
         elif var == _filter:
             assert pruned_axis == 0, "The filter of depthwise conv2d can only be pruned at axis 0."
-            self.append_pruned_vars(_filter, 1, transforms)
+            self.append_pruned_vars(_filter, 0, transforms)
             self._visit_and_search(_in_var, channel_axis, transforms)
             self._visit_and_search(_out, channel_axis, transforms)
         elif var == _out:
             assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}".format(
                 pruned_axis)
             self.append_pruned_vars(_filter, 0, transforms)
-            self.append_pruned_vars(_filter, 1, transforms)
             self._visit_and_search(_filter, 0, transforms)
             # It will not pruning number of kernels in depthwise conv2d,
             # so it is not neccesary to search succeed operators.
diff --git a/paddleslim/prune/pruner.py b/paddleslim/prune/pruner.py
index 1926c8cbb03fdbd567e70f6601b4d006ded9869c..4c58c2e1de5445ccfac5e91261186a5e797e4252 100644
--- a/paddleslim/prune/pruner.py
+++ b/paddleslim/prune/pruner.py
@@ -117,7 +117,7 @@ class Pruner():
                 _groups = 1
                 if not lazy:
                     # update groups of conv2d
-                    if pruned_axis == 1:
+                    if pruned_axis == 1 or pruned_axis == 0:
                         for op in param.outputs():
                             if op.type() in [
                                     "conv2d", "conv2d_grad", "depthwise_conv2d",
@@ -132,7 +132,7 @@ class Pruner():
                                    f"change groups of {op.type()}({param.name()}) from {op.attr('groups')} to {new_groups};"
                                )
                                op.set_attr("groups", new_groups)
-                if _groups == 1:
+
                     origin_shape = copy.deepcopy(param.shape())
                     if param_shape_backup is not None:
                         param_shape_backup[param.name()] = origin_shape
diff --git a/tests/test_prune_walker.py b/tests/test_prune_walker.py
index e0375c8b6038bfd9a2db1b16f0cf7d03cc98be8a..f2f2d2fc1aa403ab15d23281da2a26a62f6a2d42 100644
--- a/tests/test_prune_walker.py
+++ b/tests/test_prune_walker.py
@@ -473,9 +473,9 @@ class TestDepthwiseConv2d(TestPruneWorker):
 
     def set_cases(self):
         weight_var = self.graph.var('conv1.w_0')
-        self.cases.append((self.in_var, 1, {'conv1.w_0': [0, 1]}))
-        self.cases.append((self.out_var, 1, {'conv1.w_0': [0, 1]}))
-        self.cases.append((weight_var, 0, {'conv1.w_0': [1]}))
+        self.cases.append((self.in_var, 1, {'conv1.w_0': [0]}))
+        self.cases.append((self.out_var, 1, {'conv1.w_0': [0]}))
+        self.cases.append((weight_var, 0, {'conv1.w_0': [0]}))
 
     def test_prune(self):
         self.check_in_out()
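Note: the sketch below is a minimal, manual sanity check of the behavior this patch targets, not part of the change itself. It assumes paddle 2.x static mode and PaddleSlim's Pruner.prune() API; the parameter name "conv1.w_0" (matching the test above), the input shape, and the 0.5 ratio are illustrative choices.

# Minimal sanity check for depthwise_conv2d pruning after this patch.
# Assumptions: paddle 2.x static-graph API; PaddleSlim's Pruner.prune();
# names, shapes, and the 0.5 ratio are illustrative.
import paddle
from paddleslim.prune import Pruner

paddle.enable_static()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    image = paddle.static.data(
        name="image", shape=[None, 8, 16, 16], dtype="float32")
    # groups == input channels == num_filters, plus use_cudnn=False,
    # makes the framework emit a depthwise_conv2d op.
    conv = paddle.static.nn.conv2d(
        image,
        num_filters=8,
        filter_size=3,
        groups=8,
        use_cudnn=False,
        param_attr=paddle.ParamAttr(name="conv1.w_0"))

place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)  # initialize conv1.w_0 so pruning criteria have values

pruner = Pruner()
# With this patch, the depthwise filter should be pruned only along axis 0,
# and the op's `groups` attribute reduced to match the remaining channels.
pruned_prog, _, _ = pruner.prune(
    main_prog,
    paddle.static.global_scope(),
    params=["conv1.w_0"],
    ratios=[0.5],
    place=place,
    only_graph=True)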