From 9e14508cf98df177646053b5d0397ef7c90c79f6 Mon Sep 17 00:00:00 2001 From: zhouzj <41366441+zzjjay@users.noreply.github.com> Date: Tue, 17 Jan 2023 18:44:27 +0800 Subject: [PATCH] fix the bug of flatten op in pruning. (#1639) --- paddleslim/prune/prune_worker.py | 45 +++++++++++++++-------------- tests/dygraph/test_filter_pruner.py | 7 ++--- 2 files changed, 27 insertions(+), 25 deletions(-) diff --git a/paddleslim/prune/prune_worker.py b/paddleslim/prune/prune_worker.py index a2db9ee9..0b28e2b3 100644 --- a/paddleslim/prune/prune_worker.py +++ b/paddleslim/prune/prune_worker.py @@ -89,8 +89,8 @@ class PruneWorker(object): transforms(list): The transforms applied the the current variable/mask. """ if var.name() in self.skip_vars: - raise UnsupportOpError("Variable {} was skipped.".format(var.name( - ))) + raise UnsupportOpError( + "Variable {} was skipped.".format(var.name())) if self._visit(var, pruned_axis): self._prune(var, pruned_axis, transforms) @@ -109,8 +109,8 @@ class PruneWorker(object): def _visit_and_search(self, var, axis, transforms): self._visit(var, axis) if var.name() in self.skip_vars: - raise UnsupportOpError("Variable {} was skipped.".format(var.name( - ))) + raise UnsupportOpError( + "Variable {} was skipped.".format(var.name())) pre_ops = var.inputs() for op in pre_ops: self._prune_op(op, var, axis, transforms) @@ -127,8 +127,8 @@ class PruneWorker(object): if visited is not None: self.visited = visited if op.type() in self.ops_unsupported: - raise UnsupportOpError("Unsupported operator named {}".format( - op.type())) + raise UnsupportOpError( + "Unsupported operator named {}".format(op.type())) cls = PRUNE_WORKER.get(op.type()) if cls is None: if op.type() in SKIPPED_OPS: @@ -136,8 +136,8 @@ class PruneWorker(object): if op.type() in OPS_UNCHANGE_SHAPE or not self.skip_stranger: cls = PRUNE_WORKER.get("default_worker") else: - raise UnsupportOpError("Unsupported operator named {}".format( - op.type())) + raise UnsupportOpError( + "Unsupported operator named {}".format(op.type())) _logger.debug("\nfrom: {}\nto: {}\npruned_axis: {}; var: {}\ntrans: {}". format(self.op, op, pruned_axis, var.name(), transforms)) @@ -662,12 +662,13 @@ class depthwise_conv2d(PruneWorker): "repeat": repeat }]) # It will not pruning number of kernels in depthwise conv2d, - # so it is not neccesary to search succeed operators. + # so it is not neccesary to search succeed operators. # self._visit_and_search(_filter, 1, transforms) self._visit(_filter, 1) - self._visit_and_search(_out, channel_axis, transforms + [{ - "repeat": repeat - }]) + self._visit_and_search(_out, channel_axis, + transforms + [{ + "repeat": repeat + }]) elif var == _filter: assert pruned_axis == 0, "The filter of depthwise conv2d can only be pruned at axis 0." self.append_pruned_vars(_filter, 0, transforms) @@ -679,7 +680,7 @@ class depthwise_conv2d(PruneWorker): self.append_pruned_vars(_filter, 0, transforms) self._visit_and_search(_filter, 0, transforms) # It will not pruning number of kernels in depthwise conv2d, - # so it is not neccesary to search succeed operators. + # so it is not neccesary to search succeed operators. # self._visit_and_search(_filter, 1, transforms) self._visit(_filter, 1) self._visit_and_search(_in_var, channel_axis, transforms) @@ -733,8 +734,9 @@ class mul(PruneWorker): }]) elif var == y: if (pruned_axis < y_num_col_dims) and ( - 1 < len(x_shape) - x_num_col_dims) and max(x_shape[ - x_num_col_dims:]) != np.prod(y_shape[:y_num_col_dims]): + 1 < len(x_shape) - x_num_col_dims + ) and max(x_shape[x_num_col_dims:]) != np.prod( + y_shape[:y_num_col_dims]): raise UnsupportOpError( "Unsupport pruning y of mul when pruned_axis < y_num_col_dims and 1 < len(x_shape) - x_num_col_dims." ) @@ -763,8 +765,8 @@ class mul(PruneWorker): tile *= y_shape[i] for i in range(pruned_axis + 1, y_num_col_dims): repeat *= y_shape[i] - new_pruned_axis = int(np.argmax(x_shape[ - x_num_col_dims:])) + x_num_col_dims + new_pruned_axis = int( + np.argmax(x_shape[x_num_col_dims:])) + x_num_col_dims self.append_pruned_vars( x, # len(x_shape) - 1, trans + [{ @@ -825,8 +827,8 @@ class matmul(PruneWorker): mappings = [(1, 1, 1)] elif x_shape_len >= 3 and y_shape_len >= 3: mappings = [(x_shape_len - 2, -1, x_shape_len - 2), - (x_shape_len - 1, x_shape_len - 2, -1), - (-1, x_shape_len - 1, x_shape_len - 1)] + (x_shape_len - 1, x_shape_len - 2, + -1), (-1, x_shape_len - 1, x_shape_len - 1)] if var == x: for x_i, y_i, out_i in mappings: if pruned_axis == x_i: @@ -953,8 +955,9 @@ class flatten_contiguous_range(PruneWorker): out_pruned_axis = pruned_axis if pruned_axis >= start_axis and pruned_axis <= stop_axis: out_pruned_axis = start_axis - for i in range(pruned_axis + 1, stop_axis + 1): - stride *= in_var.shape()[i] + for i in range(start_axis, stop_axis + 1): + if i != pruned_axis: + stride *= in_var.shape()[i] elif pruned_axis > stop_axis: out_pruned_axis = start_axis + pruned_axis - stop_axis diff --git a/tests/dygraph/test_filter_pruner.py b/tests/dygraph/test_filter_pruner.py index 90c13478..68294a3a 100644 --- a/tests/dygraph/test_filter_pruner.py +++ b/tests/dygraph/test_filter_pruner.py @@ -149,10 +149,9 @@ class MulNet(paddle.nn.Layer): def forward(self, x): conv_a = self.conv_a(x) - return paddle.fluid.layers.mul(self.b, - conv_a, - x_num_col_dims=1, - y_num_col_dims=3) + tmp = paddle.flatten(conv_a, start_axis=0, stop_axis=2) + res = paddle.matmul(self.b, tmp) + return res class TestPruningMul(unittest.TestCase): -- GitLab