diff --git a/paddleslim/core/registry.py b/paddleslim/core/registry.py
index b746e5089f4c37a9d059d2c05fc7be44fc92a957..8d222cf0199b9d98646b261ad6b79bf3819e525c 100644
--- a/paddleslim/core/registry.py
+++ b/paddleslim/core/registry.py
@@ -25,9 +25,6 @@ class Registry(object):
         return self._module_dict.get(key, None)
 
     def _register_module(self, module_class):
-        if not inspect.isclass(module_class):
-            raise TypeError('module must be a class, but receive {}.'.format(
-                type(module_class)))
         module_name = module_class.__name__
         if module_name in self._module_dict:
             raise KeyError('{} is already registered in {}.'.format(
diff --git a/paddleslim/prune/idx_selector.py b/paddleslim/prune/idx_selector.py
index b17348ea6f3e8866f0a9df50188703d023008668..58cf1111aad35bc5101a6ab5a670c7ef1235360f 100644
--- a/paddleslim/prune/idx_selector.py
+++ b/paddleslim/prune/idx_selector.py
@@ -52,12 +52,11 @@ def default_idx_selector(group, ratio):
       list: pruned indexes
 
     """
-    assert (isinstance(graph, GraphWrapper))
     name, axis, score = group[
         0]  # sort channels by the first convolution's score
     sorted_idx = score.argsort()
 
-    pruned_num = len(sorted_idx) * ratio
+    pruned_num = int(round(len(sorted_idx) * ratio))
     pruned_idx = sorted_idx[:pruned_num]
 
     idxs = []
@@ -94,7 +93,6 @@ def optimal_threshold(group, ratio):
      list: pruned indexes
 
    """
-    assert (isinstance(graph, GraphWrapper))
    name, axis, score = group[
        0]  # sort channels by the first convolution's score
 
diff --git a/paddleslim/prune/prune_walker.py b/paddleslim/prune/prune_walker.py
index 09bb547e6bf81fc16d45988b8a7831abeeaa1e92..6e85be11a3f53926b4d44ad89f641233367f4a4d 100644
--- a/paddleslim/prune/prune_walker.py
+++ b/paddleslim/prune/prune_walker.py
@@ -71,8 +71,11 @@ class PruneWorker(object):
         if visited is not None:
             self.visited = visited
         cls = PRUNE_WORKER.get(op.type())
-        assert cls is not None, "The walker of {} is not registered.".format(
-            op.type())
+        if cls is None:
+            _logger.warn(
+                "{} op will be pruned by the default walker to keep the shapes of input and output the same because its walker is not registered.".
+                format(op.type()))
+            cls = PRUNE_WORKER.get("default_walker")
         _logger.debug("\nfrom: {}\nto: {}\npruned_axis: {}; var: {}".format(
             self.op, op, pruned_axis, var.name()))
         walker = cls(op,
@@ -236,6 +239,7 @@ class elementwise_mul(elementwise_op):
         super(elementwise_mul, self).__init__(op, pruned_params, visited)
 
 
+@PRUNE_WORKER.register
 class activation(PruneWorker):
     def __init__(self, op, pruned_params, visited):
         super(activation, self).__init__(op, pruned_params, visited)
@@ -256,6 +260,27 @@ class activation(PruneWorker):
                 self._prune_op(op, out_var, pruned_axis, pruned_idx)
 
 
+@PRUNE_WORKER.register
+class default_walker(PruneWorker):
+    def __init__(self, op, pruned_params, visited):
+        super(default_walker, self).__init__(op, pruned_params, visited)
+
+    def _prune(self, var, pruned_axis, pruned_idx):
+        if var in self.op.all_outputs():
+            for in_var in self.op.inputs():
+                if len(in_var.shape()) == len(var.shape()):
+                    pre_ops = in_var.inputs()
+                    for op in pre_ops:
+                        self._prune_op(op, in_var, pruned_axis, pruned_idx)
+
+        for out_var in self.op.all_outputs():
+            if len(out_var.shape()) == len(var.shape()):
+                self._visit(out_var, pruned_axis)
+                next_ops = out_var.outputs()
+                for op in next_ops:
+                    self._prune_op(op, out_var, pruned_axis, pruned_idx)
+
+
 @PRUNE_WORKER.register
 class uniform_random_batch_size_like(activation):
     def __init__(self, op, pruned_params, visited):
diff --git a/paddleslim/prune/pruner.py b/paddleslim/prune/pruner.py
index 317c5a9c914075ad17df240ddf081259a3954872..8169c56ba30cb3ad7c604fce389ef34028a4ae7e 100644
--- a/paddleslim/prune/pruner.py
+++ b/paddleslim/prune/pruner.py
@@ -41,8 +41,6 @@ class Pruner():
     def __init__(self,
                  criterion="l1_norm",
                  idx_selector="default_idx_selector"):
-        self.criterion = criterion
-        self.channel_sortor = channel_sortor
         if isinstance(criterion, str):
             self.criterion = CRITERION.get(criterion)
         else:
@@ -98,7 +96,7 @@ class Pruner():
                 param_v = graph.var(param)
                 pruned_num = int(round(param_v.shape()[0] * ratio))
                 pruned_idx = [0] * pruned_num
-                for name, aixs in group:
+                for name, axis in group:
                     pruned_params.append((name, axis, pruned_idx))
 
             else:
@@ -109,10 +107,10 @@ class Pruner():
                     values = np.array(scope.find_var(name).get_tensor())
                     group_values.append((name, values, axis))
 
-                scores = self.criterion(group_with_values,
+                scores = self.criterion(group_values,
                                         graph)  # [(name, axis, score)]
 
-                pruned_params = self.idx_selector(scores)
+                pruned_params.extend(self.idx_selector(scores, ratio))
 
         merge_pruned_params = {}
         for param, pruned_axis, pruned_idx in pruned_params:
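Note on the `idx_selector.py` change: the old `pruned_num = len(sorted_idx) * ratio` yields a float, which is not a valid slice bound, so the diff rounds it to a whole channel count before slicing. A minimal sketch of the fixed computation, using made-up scores rather than anything from the repo:

```python
import numpy as np

# Toy per-channel scores; argsort orders channel indexes from least to most important.
scores = np.array([0.9, 0.1, 0.5, 0.3, 0.8, 0.2, 0.7, 0.4, 0.6, 0.05])
sorted_idx = scores.argsort()

ratio = 0.3
# Old code: len(sorted_idx) * ratio == 3.0, a float, which cannot be used to slice.
# Fixed code: round to a whole number of channels first.
pruned_num = int(round(len(sorted_idx) * ratio))
pruned_idx = sorted_idx[:pruned_num]

print(pruned_num)  # 3
print(pruned_idx)  # the 3 channel indexes with the lowest scores -> [9 1 5]
```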
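Note on the `prune_walker.py` change: instead of asserting when an op type has no registered walker, `_prune_op` now warns and falls back to the new `default_walker`, which forwards the pruned indexes to any inputs and outputs of the same rank. A toy, self-contained sketch of that lookup-with-fallback pattern; the registry and worker names below are illustrative, not the PaddleSlim classes themselves:

```python
WORKERS = {}

def register(cls):
    """Register a worker class under its own name, mirroring PRUNE_WORKER.register."""
    WORKERS[cls.__name__] = cls
    return cls

@register
class default_walker:
    def prune(self, op_type):
        return "pass pruned indexes straight through {}".format(op_type)

@register
class conv2d:
    def prune(self, op_type):
        return "prune filters of {}".format(op_type)

def get_worker(op_type):
    cls = WORKERS.get(op_type)
    if cls is None:
        # Same spirit as the diff: warn and fall back instead of failing an assertion.
        print("{} has no registered walker; using default_walker".format(op_type))
        cls = WORKERS["default_walker"]
    return cls()

print(get_worker("conv2d").prune("conv2d"))  # handled by its own walker
print(get_worker("swish").prune("swish"))    # handled by the fallback
```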