diff --git a/paddleslim/prune/prune_worker.py b/paddleslim/prune/prune_worker.py
index e130074f7dbcb2c4b327f54d7d77ec7126ea2ad4..25703c66db44af5363df511fd59080ac922cc317 100644
--- a/paddleslim/prune/prune_worker.py
+++ b/paddleslim/prune/prune_worker.py
@@ -105,6 +105,9 @@ class PruneWorker(object):
 
     def _visit_and_search(self, var, axis, transforms):
         self._visit(var, axis)
+        if var.name() in self.skip_vars:
+            raise UnsupportOpError("Variable {} was skipped.".format(var.name(
+            )))
         pre_ops = var.inputs()
         for op in pre_ops:
             self._prune_op(op, var, axis, transforms)
@@ -123,7 +126,6 @@
         if op.type() in self.ops_unsupported:
             raise UnsupportOpError("Unsupported operator named {}".format(
                 op.type()))
-
         cls = PRUNE_WORKER.get(op.type())
         if cls is None:
             if op.type() in SKIPPED_OPS:
@@ -214,10 +216,7 @@ class conv2d(PruneWorker):
             filter_var = self.op.inputs("Filter")[0]
             self._visit(filter_var, 0)
             self.append_pruned_vars(filter_var, 0, pruned_idx)
-
-            for op in filter_var.outputs():
-                self._prune_op(op, filter_var, 0, pruned_idx)
-
+            self._visit_and_search(filter_var, 0, pruned_idx)
             if len(self.op.inputs("Bias")) > 0:
                 self.append_pruned_vars(
                     self.op.inputs("Bias")[0], channel_axis, pruned_idx)
@@ -240,8 +239,7 @@ class conv2d_transpose(PruneWorker):
             filter_var = self.op.inputs("Filter")[0]
             self._visit(filter_var, 0)
             self.append_pruned_vars(filter_var, 0, pruned_idx)
-            for op in filter_var.outputs():
-                self._prune_op(op, filter_var, 0, pruned_idx)
+            self._visit_and_search(filter_var, 0, pruned_idx)
 
         elif var in self.op.inputs("Filter"):
             _logger.warn("Skip pruning output channels of conv2d_transpose!")
@@ -252,20 +250,15 @@
 
             filter_var = self.op.inputs("Filter")[0]
             self._visit(filter_var, 1)
-
             self.append_pruned_vars(filter_var, 1, pruned_idx)
-            for op in filter_var.outputs():
-                self._prune_op(op, filter_var, 1, pruned_idx)
+            self._visit_and_search(filter_var, 1, pruned_idx)
 
             if len(self.op.inputs("Bias")) > 0:
                 self.append_pruned_vars(
                     self.op.inputs("Bias")[0], channel_axis, pruned_idx)
 
-
             output_var = self.op.outputs("Output")[0]
-            next_ops = output_var.outputs()
-            for op in next_ops:
-                self._prune_op(op, output_var, channel_axis, pruned_idx)
+            self._visit_and_search(output_var, channel_axis, pruned_idx)
 
 
 @PRUNE_WORKER.register
@@ -281,22 +274,15 @@
 
         if var in self.op.outputs("Y"):
             in_var = self.op.inputs("X")[0]
-            self._visit(in_var, pruned_axis)
-            pre_ops = in_var.inputs()
-            for op in pre_ops:
-                self._prune_op(op, in_var, pruned_axis, pruned_idx)
+            self._visit_and_search(in_var, pruned_axis, pruned_idx)
 
         for param in ["Scale", "Bias", "Mean", "Variance"]:
             param_var = self.op.inputs(param)[0]
-            for op in param_var.outputs():
-                self._prune_op(op, param_var, 0, pruned_idx)
+            self._visit_and_search(param_var, 0, pruned_idx)
             self.append_pruned_vars(param_var, 0, pruned_idx)
 
         out_var = self.op.outputs("Y")[0]
-        self._visit(out_var, pruned_axis)
-        next_ops = out_var.outputs()
-        for op in next_ops:
-            self._prune_op(op, out_var, pruned_axis, pruned_idx)
+        self._visit_and_search(out_var, pruned_axis, pruned_idx)
 
 
 @PRUNE_WORKER.register
@@ -475,20 +461,13 @@ class sum(PruneWorker):
     def _prune(self, var, pruned_axis, pruned_idx):
         if var in self.op.outputs("Out"):
             for in_var in self.op.inputs("X"):
-                pre_ops = in_var.inputs()
-                for op in pre_ops:
-                    self._prune_op(op, in_var, pruned_axis, pruned_idx)
+                self._visit_and_search(in_var, pruned_axis, pruned_idx)
         elif var in self.op.inputs("X"):
             for in_var in self.op.inputs("X"):
                 if in_var != var:
-                    pre_ops = in_var.inputs()
-                    for op in pre_ops:
-                        self._prune_op(op, in_var, pruned_axis, pruned_idx)
+                    self._visit_and_search(in_var, pruned_axis, pruned_idx)
 
         out_var = self.op.outputs("Out")[0]
-        self._visit(out_var, pruned_axis)
-        next_ops = out_var.outputs()
-        for op in next_ops:
-            self._prune_op(op, out_var, pruned_axis, pruned_idx)
+        self._visit_and_search(out_var, pruned_axis, pruned_idx)
 
 @PRUNE_WORKER.register
@@ -756,12 +735,10 @@ class scale(PruneWorker):
     def _prune(self, var, pruned_axis, pruned_idx):
         if var in self.op.inputs("X"):
             out_var = self.op.outputs("Out")[0]
-            for op in out_var.outputs():
-                self._prune_op(op, out_var, pruned_axis, pruned_idx)
+            self._visit_and_search(out_var, pruned_axis, pruned_idx)
         elif var in self.op.outputs("Out"):
             in_var = self.op.inputs("X")[0]
-            for op in in_var.inputs():
-                self._prune_op(op, in_var, pruned_axis, pruned_idx)
+            self._visit_and_search(in_var, pruned_axis, pruned_idx)
 
 
 @PRUNE_WORKER.register
@@ -802,22 +779,15 @@ class affine_channel(PruneWorker):
 
         if var in self.op.outputs("Out"):
             in_var = self.op.inputs("X")[0]
-            self._visit(in_var, pruned_axis)
-            pre_ops = in_var.inputs()
-            for op in pre_ops:
-                self._prune_op(op, in_var, pruned_axis, pruned_idx)
+            self._visit_and_search(in_var, pruned_axis, pruned_idx)
 
         for param in ["Scale", "Bias"]:
             param_var = self.op.inputs(param)[0]
-            for op in param_var.outputs():
-                self._prune_op(op, param_var, 0, pruned_idx)
+            self._visit_and_search(param_var, 0, pruned_idx)
             self.append_pruned_vars(param_var, 0, pruned_idx)
 
         out_var = self.op.outputs("Out")[0]
-        self._visit(out_var, pruned_axis)
-        next_ops = out_var.outputs()
-        for op in next_ops:
-            self._prune_op(op, out_var, pruned_axis, pruned_idx)
+        self._visit_and_search(out_var, pruned_axis, pruned_idx)
 
 
 @PRUNE_WORKER.register
@@ -843,12 +813,9 @@ class flatten_contiguous_range(PruneWorker):
                 out_pruned_axis = start_axis + pruned_axis - stop_axis
 
             self._visit(in_var, pruned_axis)
-            self._visit(out_var, out_pruned_axis)
             transform = {'stride': stride}
-            next_ops = out_var.outputs()
-            for op in next_ops:
-                self._prune_op(op, out_var, out_pruned_axis,
-                               transforms + [transform])
+            self._visit_and_search(out_var, out_pruned_axis,
+                                   transforms + [transform])
 
 
 @PRUNE_WORKER.register
diff --git a/tests/test_prune_walker.py b/tests/test_prune_walker.py
index bebe5459b320c47fce227a840c156832fa782835..46290633a80ce5ae5040161914de583eba3e4e87 100644
--- a/tests/test_prune_walker.py
+++ b/tests/test_prune_walker.py
@@ -201,6 +201,50 @@ class TestSqueeze2(StaticCase):
         self.assertTrue(ret == {})
 
 
+class TestSum(StaticCase):
+    def test_prune(self):
+        main_program = fluid.Program()
+        startup_program = fluid.Program()
+        with fluid.unique_name.guard():
+            with fluid.program_guard(main_program, startup_program):
+                input = fluid.data(name="image", shape=[1, 3, 16, 16])
+                conv1 = conv_bn_layer(
+                    input, 8, 3, "conv1", act='relu')  #[1, 8, 1, 1]
+                conv2 = conv_bn_layer(
+                    input, 8, 3, "conv2", act='relu')  #[1, 8, 1, 1]
+                out = conv1 + conv2
+
+        graph = GraphWrapper(main_program)
+        cls = PRUNE_WORKER.get("sum")
+        out_var = graph.var(out.name)
+        in_var = graph.var(conv1.name)
+        op = out_var.inputs()[0]
+        # pruning out
+        pruned_params = []
+        ret = {}
+        worker = cls(op, pruned_params, {}, True)
+        worker.prune(out_var, 1, [])
+        for var, axis, _, _ in pruned_params:
+            ret[var.name()] = axis
+        self.assertTrue(ret == {
+            'conv1_weights': 0,
+            'conv1_bn_scale': 0,
+            'conv1_bn_offset': 0,
+            'conv1_bn_mean': 0,
+            'conv1_bn_variance': 0
+        })
+
+        # pruning inputs
+        pruned_params = []
+        worker = cls(op, pruned_params, {}, True)
+        worker.skip_vars = [out.name]
+        try:
+            worker.prune(in_var, 0, [])
+        except UnsupportOpError as e:
+            print(e)
+        self.assertTrue(pruned_params == [])
+
+
 class TestUnsupportAndDefault(StaticCase):
     def test_prune(self):
         main_program = fluid.Program()
@@ -324,7 +368,6 @@ class TestPruneWorker(unittest.TestCase):
             if var.name() not in ret:
                 ret[var.name()] = []
             ret[var.name()].append(axis)
-        print(f"excepted: {_ret}; but get {ret}")
         self.assertTrue(ret == _ret)
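Taken as a whole, the patch does two things: every worker's hand-rolled loop over `var.inputs()`/`var.outputs()` is funneled through the shared `_visit_and_search` helper, and that helper now raises `UnsupportOpError` as soon as the search reaches a variable listed in `skip_vars` (the behaviour the new TestSum case pins down). The sketch below models that traversal in isolation; `Var` and `Worker` are simplified hypothetical stand-ins, not PaddleSlim's GraphWrapper classes:

```python
# Minimal, self-contained sketch of the traversal this patch converges on.
# Var and Worker are hypothetical stand-ins, not PaddleSlim's GraphWrapper API.


class UnsupportOpError(Exception):
    pass


class Var(object):
    def __init__(self, name, input_ops=None, output_ops=None):
        self._name = name
        self.input_ops = input_ops or []    # ops producing this variable
        self.output_ops = output_ops or []  # ops consuming this variable

    def name(self):
        return self._name


class Worker(object):
    def __init__(self, skip_vars=None):
        self.skip_vars = skip_vars or []
        self.visited = set()

    def _visit(self, var, axis):
        self.visited.add((var.name(), axis))

    def _prune_op(self, op, var, axis, transforms):
        pass  # in PaddleSlim this dispatches to the op-specific worker

    def _visit_and_search(self, var, axis, transforms):
        # Mark the variable first, then bail out if it is protected:
        # a pruning plan must not cross a variable listed in skip_vars.
        self._visit(var, axis)
        if var.name() in self.skip_vars:
            raise UnsupportOpError(
                "Variable {} was skipped.".format(var.name()))
        # One helper searches both directions; before this patch, conv2d,
        # sum, scale, affine_channel, etc. each open-coded these loops.
        for op in var.input_ops:
            self._prune_op(op, var, axis, transforms)
        for op in var.output_ops:
            self._prune_op(op, var, axis, transforms)


# Usage mirrors the new TestSum case: protect a variable, then expect
# UnsupportOpError when the traversal reaches it.
worker = Worker(skip_vars=["sum_out"])
try:
    worker._visit_and_search(Var("sum_out"), 0, [])
except UnsupportOpError as e:
    print(e)  # -> Variable sum_out was skipped.
```

Under this model, a caller fences off part of the graph the same way TestSum does: assign `worker.skip_vars` before pruning and treat the exception as a signal that the pruning plan would touch a protected variable.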