diff --git a/paddleslim/prune/prune_walker.py b/paddleslim/prune/prune_walker.py
index 6e85be11a3f53926b4d44ad89f641233367f4a4d..c695a3dd37bf2b3369f532053515f30281fb5931 100644
--- a/paddleslim/prune/prune_walker.py
+++ b/paddleslim/prune/prune_walker.py
@@ -23,6 +23,8 @@ _logger = get_logger(__name__, level=logging.INFO)
 
 PRUNE_WORKER = Registry('prune_worker')
 
+SKIP_OPS = ["conditional_block"]
+
 
 class PruneWorker(object):
     def __init__(self, op, pruned_params=[], visited={}):
@@ -72,6 +74,9 @@ class PruneWorker(object):
             self.visited = visited
         cls = PRUNE_WORKER.get(op.type())
         if cls is None:
+            if op.type() in SKIP_OPS:
+                _logger.warn("Skip operator [{}]".format(op.type()))
+                return
             _logger.warn(
                 "{} op will be pruned by default walker to keep the shapes of input and output being same because its walker is not registered.".
                 format(op.type()))
@@ -149,6 +154,71 @@ class conv2d(PruneWorker):
                 self._prune_op(op, output_var, channel_axis, pruned_idx)
 
 
+@PRUNE_WORKER.register
+class conv2d_transpose(PruneWorker):
+    def __init__(self, op, pruned_params, visited={}):
+        super(conv2d_transpose, self).__init__(op, pruned_params, visited)
+
+    def _prune(self, var, pruned_axis, pruned_idx):
+        data_format = self.op.attr("data_format")
+        channel_axis = 1
+        if data_format == "NHWC":
+            channel_axis = 3
+        if var in self.op.inputs("Input"):
+            assert pruned_axis == channel_axis, "The Input of conv2d_transpose can only be pruned at channel axis, but got {}; var: {}".format(
+                pruned_axis, var.name())
+            filter_var = self.op.inputs("Filter")[0]
+            self._visit(filter_var, 0)
+            self.pruned_params.append((filter_var, 0, pruned_idx))
+            for op in filter_var.outputs():
+                self._prune_op(op, filter_var, 0, pruned_idx)
+
+        elif var in self.op.inputs("Filter"):
+            assert pruned_axis in [0, 1]
+
+            self.pruned_params.append((var, pruned_axis, pruned_idx))
+
+            for op in var.outputs():
+                self._prune_op(op, var, pruned_axis, pruned_idx)
+
+            if pruned_axis == 1:
+                if len(self.op.inputs("Bias")) > 0:
+                    self.pruned_params.append(
+                        (self.op.inputs("Bias")[0], channel_axis, pruned_idx))
+                output_var = self.op.outputs("Output")[0]
+                self._visit(output_var, channel_axis)
+                next_ops = output_var.outputs()
+                for op in next_ops:
+                    self._prune_op(op, output_var, channel_axis, pruned_idx)
+
+            elif pruned_axis == 0:
+                input_var = self.op.inputs("Input")[0]
+                self._visit(input_var, channel_axis)
+                pre_ops = input_var.inputs()
+                for op in pre_ops:
+                    self._prune_op(op, input_var, channel_axis, pruned_idx)
+        elif var in self.op.outputs("Output"):
+            assert pruned_axis == channel_axis, "pruned_axis: {}; var: {}".format(
+                pruned_axis, var.name())
+
+            filter_var = self.op.inputs("Filter")[0]
+            self._visit(filter_var, 1)
+
+            self.pruned_params.append((filter_var, 1, pruned_idx))
+
+            for op in filter_var.outputs():
+                self._prune_op(op, filter_var, 1, pruned_idx)
+
+            if len(self.op.inputs("Bias")) > 0:
+                self.pruned_params.append(
+                    (self.op.inputs("Bias")[0], channel_axis, pruned_idx))
+
+            output_var = self.op.outputs("Output")[0]
+            next_ops = output_var.outputs()
+            for op in next_ops:
+                self._prune_op(op, output_var, channel_axis, pruned_idx)
+
+
 @PRUNE_WORKER.register
 class batch_norm(PruneWorker):
     def __init__(self, op, pruned_params, visited):
@@ -267,7 +337,7 @@ class default_walker(PruneWorker):
 
     def _prune(self, var, pruned_axis, pruned_idx):
         if var in self.op.all_outputs():
-            for in_var in self.op.inputs():
+            for in_var in self.op.all_inputs():
                 if len(in_var.shape()) == len(var.shape()):
                     pre_ops = in_var.inputs()
                     for op in pre_ops:
@@ -549,3 +619,33 @@ class adam(PruneWorker):
             self.pruned_params.append((moment1_var, pruned_axis, pruned_idx))
             moment2_var = self.op.inputs("Moment2")[0]
             self.pruned_params.append((moment2_var, pruned_axis, pruned_idx))
+
+
+@PRUNE_WORKER.register
+class affine_channel(PruneWorker):
+    def __init__(self, op, pruned_params, visited):
+        super(affine_channel, self).__init__(op, pruned_params, visited)
+
+    def _prune(self, var, pruned_axis, pruned_idx):
+        if (var not in self.op.outputs("Out")) and (
+                var not in self.op.inputs("X")):
+            return
+
+        if var in self.op.outputs("Out"):
+            in_var = self.op.inputs("X")[0]
+            self._visit(in_var, pruned_axis)
+            pre_ops = in_var.inputs()
+            for op in pre_ops:
+                self._prune_op(op, in_var, pruned_axis, pruned_idx)
+
+        for param in ["Scale", "Bias"]:
+            param_var = self.op.inputs(param)[0]
+            for op in param_var.outputs():
+                self._prune_op(op, param_var, 0, pruned_idx)
+            self.pruned_params.append((param_var, 0, pruned_idx))
+
+        out_var = self.op.outputs("Out")[0]
+        self._visit(out_var, pruned_axis)
+        next_ops = out_var.outputs()
+        for op in next_ops:
+            self._prune_op(op, out_var, pruned_axis, pruned_idx)
diff --git a/paddleslim/prune/pruner.py b/paddleslim/prune/pruner.py
index 52f388ac93523d400586b244523604e6758bd843..0e1a54572d75c23fc339fda8e07478247e35429c 100644
--- a/paddleslim/prune/pruner.py
+++ b/paddleslim/prune/pruner.py
@@ -90,6 +90,11 @@ class Pruner():
         visited = {}
         pruned_params = []
         for param, ratio in zip(params, ratios):
+            if graph.var(param) is None:
+                _logger.warn(
+                    "Variable[{}] to be pruned is not in current graph.".
+                    format(param))
+                continue
             group = collect_convs([param], graph, visited)[0]  # [(name, axis)]
             if group is None or len(group) == 0:
                 continue
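
Usage note (not part of the patch): the sketch below is one way the new behaviour could be exercised, assuming the PaddleSlim 1.x fluid-style API in which Pruner.prune(program, scope, params, ratios, place=...) returns the pruned program plus parameter backups. The network, parameter names, and ratios are illustrative assumptions, not taken from the patch.

import paddle.fluid as fluid
from paddleslim.prune import Pruner

main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    image = fluid.data(name="image", shape=[None, 3, 32, 32], dtype="float32")
    # conv2d followed by conv2d_transpose: the transpose op is now handled by
    # its own registered prune worker instead of the default walker.
    conv = fluid.layers.conv2d(
        image, num_filters=8, filter_size=3,
        param_attr=fluid.ParamAttr(name="conv1_weights"))
    deconv = fluid.layers.conv2d_transpose(
        conv, num_filters=4, filter_size=3,
        param_attr=fluid.ParamAttr(name="deconv1_weights"))

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_prog)

pruner = Pruner()
# "missing_param" does not exist in the graph; with this patch it is skipped
# with a warning instead of aborting the whole pruning pass.
pruned_prog, _, _ = pruner.prune(
    main_prog,
    fluid.global_scope(),
    params=["conv1_weights", "missing_param"],
    ratios=[0.5, 0.5],
    place=place)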