From a3b9d1184b1afc88bd4d3e55aeb628bcfee45b03 Mon Sep 17 00:00:00 2001 From: chenluyan Date: Wed, 19 Aug 2020 10:16:12 +0800 Subject: [PATCH] (1)add ops for prune_walker;(2)fix bugs of pruning mul --- paddleslim/prune/prune_walker.py | 56 +++++++++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/paddleslim/prune/prune_walker.py b/paddleslim/prune/prune_walker.py index 189b75a7..7d2cafea 100644 --- a/paddleslim/prune/prune_walker.py +++ b/paddleslim/prune/prune_walker.py @@ -550,7 +550,7 @@ class mul(PruneWorker): self.pruned_params.append((param_var, 0, idx)) for op in param_var.outputs(): - self._prune_op(op, param_var, 0, pruned_idx) + self._prune_op(op, param_var, 0, idx) @PRUNE_WORKER.register @@ -623,3 +623,57 @@ class affine_channel(PruneWorker): next_ops = out_var.outputs() for op in next_ops: self._prune_op(op, out_var, pruned_axis, pruned_idx) + + +@PRUNE_WORKER.register +class roi_align(PruneWorker): + def __init__(self, op, pruned_params, visited): + super(roi_align, self).__init__(op, pruned_params, visited) + + def _prune(self, var, pruned_axis, pruned_idx): + if var in self.op.outputs("Out"): + in_var = self.op.inputs("X")[0] + pre_ops = in_var.inputs() + for op in pre_ops: + self._prune_op(op, in_var, pruned_axis, pruned_idx) + + out_var = self.op.outputs("Out")[0] + self._visit(out_var, pruned_axis) + next_ops = out_var.outputs() + for op in next_ops: + self._prune_op(op, out_var, pruned_axis, pruned_idx) + +@PRUNE_WORKER.register +class lod_reset(PruneWorker): + def __init__(self, op, pruned_params, visited={}): + super(lod_reset, self).__init__(op, pruned_params, visited) + + def _prune(self, var, pruned_axis, pruned_idx): + if var in self.op.outputs("Out"): + in_var = self.op.inputs("X")[0] + pre_ops = in_var.inputs() + for op in pre_ops: + self._prune_op(op, in_var, pruned_axis, pruned_idx) + + out_var = self.op.outputs("Out")[0] + self._visit(out_var, pruned_axis) + next_ops = out_var.outputs() + for op in next_ops: + self._prune_op(op, out_var, pruned_axis, pruned_idx) + +@PRUNE_WORKER.register +class gather(PruneWorker): + def __init__(self, op, pruned_params, visited): + super(gather, self).__init__(op, pruned_params, visited) + def _prune(self, var, pruned_axis, pruned_idx): + if var in self.op.outputs("Out"): + in_var = self.op.inputs("X")[0] + pre_ops = in_var.inputs() + for op in pre_ops: + self._prune_op(op, in_var, pruned_axis, pruned_idx) + + out_var = self.op.outputs("Out")[0] + self._visit(out_var, pruned_axis) + next_ops = out_var.outputs() + for op in next_ops: + self._prune_op(op, out_var, pruned_axis, pruned_idx) \ No newline at end of file -- GitLab