diff --git a/paddleslim/prune/criterion.py b/paddleslim/prune/criterion.py
index 6a8fb7d913bd10537ee0181be76596869f87b530..a32ec6c94d3b792987e03a99b177b9e7cfef1e95 100644
--- a/paddleslim/prune/criterion.py
+++ b/paddleslim/prune/criterion.py
@@ -43,11 +43,11 @@ def l1_norm(group, graph):
         list: A list of tuple storing l1-norm on given axis.
     """
     scores = []
-    for name, value, axis in group:
+    for name, value, axis, pruned_idx in group:
         reduce_dims = [i for i in range(len(value.shape)) if i != axis]
         score = np.sum(np.abs(value), axis=tuple(reduce_dims))
-        scores.append((name, axis, score))
+        scores.append((name, axis, score, pruned_idx))
 
     return scores
 
@@ -55,7 +55,7 @@ def l1_norm(group, graph):
 @CRITERION.register
 def geometry_median(group, graph):
     scores = []
-    name, value, axis = group[0]
+    name, value, axis, _ = group[0]
     assert (len(value.shape) == 4)
 
     def get_distance_sum(value, out_idx):
@@ -73,8 +73,8 @@ def geometry_median(group, graph):
 
     tmp = np.array(dist_sum_list)
 
-    for name, value, axis in group:
-        scores.append((name, axis, tmp))
+    for name, value, axis, idx in group:
+        scores.append((name, axis, tmp, idx))
 
     return scores
 
@@ -97,7 +97,7 @@ def bn_scale(group, graph):
     assert (isinstance(graph, GraphWrapper))
 
     # step1: Get first convolution
-    conv_weight, value, axis = group[0]
+    conv_weight, value, axis, _ = group[0]
     param_var = graph.var(conv_weight)
     conv_op = param_var.outputs()[0]
 
@@ -111,12 +111,12 @@ def bn_scale(group, graph):
 
     # steps3: Find scale of bn
     score = None
-    for name, value, aixs in group:
+    for name, value, aixs, _ in group:
         if bn_scale_param == name:
             score = np.abs(value.reshape([-1]))
 
     scores = []
-    for name, value, axis in group:
-        scores.append((name, axis, score))
+    for name, value, axis, idx in group:
+        scores.append((name, axis, score, idx))
 
     return scores
diff --git a/paddleslim/prune/group_param.py b/paddleslim/prune/group_param.py
index fa9690f35215fad1f54aa246e83af5119ad4c408..c237b84685b90567738c55e8d1611357481aeb2c 100644
--- a/paddleslim/prune/group_param.py
+++ b/paddleslim/prune/group_param.py
@@ -57,21 +57,21 @@ def collect_convs(params, graph, visited={}):
         conv_op = param.outputs()[0]
         walker = conv2d_walker(
             conv_op, pruned_params=pruned_params, visited=visited)
-        walker.prune(param, pruned_axis=0, pruned_idx=[])
+        walker.prune(param, pruned_axis=0, pruned_idx=[0])
         groups.append(pruned_params)
     visited = set()
     uniq_groups = []
     for group in groups:
         repeat_group = False
         simple_group = []
-        for param, axis, _ in group:
+        for param, axis, pruned_idx in group:
             param = param.name()
             if axis == 0:
                 if param in visited:
                     repeat_group = True
                 else:
                     visited.add(param)
-            simple_group.append((param, axis))
+            simple_group.append((param, axis, pruned_idx))
         if not repeat_group:
             uniq_groups.append(simple_group)
 
diff --git a/paddleslim/prune/idx_selector.py b/paddleslim/prune/idx_selector.py
index a46a76ab1ba108604ef7a8403a35c50ef3a65350..7e90d1ac11b7f0717cbfc56e9a0987ce1a4da520 100644
--- a/paddleslim/prune/idx_selector.py
+++ b/paddleslim/prune/idx_selector.py
@@ -52,7 +52,7 @@ def default_idx_selector(group, ratio):
         list: pruned indexes
     """
-    name, axis, score = group[
+    name, axis, score, _ = group[
         0]  # sort channels by the first convolution's score
     sorted_idx = score.argsort()
 
@@ -60,8 +60,9 @@
     pruned_idx = sorted_idx[:pruned_num]
 
     idxs = []
-    for name, axis, score in group:
-        idxs.append((name, axis, pruned_idx))
+    for name, axis, score, offsets in group:
+        r_idx = [i + offsets[0] for i in pruned_idx]
+        idxs.append((name, axis, r_idx))
     return idxs
diff --git a/paddleslim/prune/prune_walker.py b/paddleslim/prune/prune_walker.py
index c695a3dd37bf2b3369f532053515f30281fb5931..9407bdda02d13efeb59345f1d292cd6b0b9f432d 100644
--- a/paddleslim/prune/prune_walker.py
+++ b/paddleslim/prune/prune_walker.py
@@ -77,9 +77,10 @@ class PruneWorker(object):
         if op.type() in SKIP_OPS:
             _logger.warn("Skip operator [{}]".format(op.type()))
             return
-        _logger.warn(
-            "{} op will be pruned by default walker to keep the shapes of input and output being same because its walker is not registered.".
-            format(op.type()))
+
+# _logger.warn(
+#     "{} op will be pruned by default walker to keep the shapes of input and output being same because its walker is not registered.".
+#     format(op.type()))
         cls = PRUNE_WORKER.get("default_walker")
         _logger.debug("\nfrom: {}\nto: {}\npruned_axis: {}; var: {}".format(
             self.op, op, pruned_axis, var.name()))
@@ -263,6 +264,8 @@ class elementwise_op(PruneWorker):
                 if name == "Y":
                     actual_axis = pruned_axis - axis
                 in_var = self.op.inputs(name)[0]
+                if len(in_var.shape()) == 1 and in_var.shape()[0] == 1:
+                    continue
                 pre_ops = in_var.inputs()
                 for op in pre_ops:
                     self._prune_op(op, in_var, actual_axis, pruned_idx)
@@ -270,19 +273,21 @@ class elementwise_op(PruneWorker):
         else:
             if var in self.op.inputs("X"):
                 in_var = self.op.inputs("Y")[0]
-
-                if in_var.is_parameter():
-                    self.pruned_params.append(
-                        (in_var, pruned_axis - axis, pruned_idx))
-                pre_ops = in_var.inputs()
-                for op in pre_ops:
-                    self._prune_op(op, in_var, pruned_axis - axis, pruned_idx)
+                if not (len(in_var.shape()) == 1 and in_var.shape()[0] == 1):
+                    if in_var.is_parameter():
+                        self.pruned_params.append(
+                            (in_var, pruned_axis - axis, pruned_idx))
+                    pre_ops = in_var.inputs()
+                    for op in pre_ops:
+                        self._prune_op(op, in_var, pruned_axis - axis,
+                                       pruned_idx)
 
             elif var in self.op.inputs("Y"):
                 in_var = self.op.inputs("X")[0]
-                pre_ops = in_var.inputs()
-                pruned_axis = pruned_axis + axis
-                for op in pre_ops:
-                    self._prune_op(op, in_var, pruned_axis, pruned_idx)
+                if not (len(in_var.shape()) == 1 and in_var.shape()[0] == 1):
+                    pre_ops = in_var.inputs()
+                    pruned_axis = pruned_axis + axis
+                    for op in pre_ops:
+                        self._prune_op(op, in_var, pruned_axis, pruned_idx)
 
         out_var = self.op.outputs("Out")[0]
         self._visit(out_var, pruned_axis)
diff --git a/paddleslim/prune/pruner.py b/paddleslim/prune/pruner.py
index 0e1a54572d75c23fc339fda8e07478247e35429c..9bc2de09df095a4b50cbbd1be41a345d06667b2f 100644
--- a/paddleslim/prune/pruner.py
+++ b/paddleslim/prune/pruner.py
@@ -90,12 +90,14 @@ class Pruner():
         visited = {}
         pruned_params = []
         for param, ratio in zip(params, ratios):
+            _logger.info("pruning: {}".format(param))
             if graph.var(param) is None:
                 _logger.warn(
                     "Variable[{}] to be pruned is not in current graph.".
                     format(param))
                 continue
-            group = collect_convs([param], graph, visited)[0]  # [(name, axis)]
+            group = collect_convs([param], graph,
+                                  visited)[0]  # [(name, axis, pruned_idx)]
            if group is None or len(group) == 0:
                 continue
             if only_graph and self.idx_selector.__name__ == "default_idx_selector":
@@ -103,30 +105,33 @@
                 param_v = graph.var(param)
                 pruned_num = int(round(param_v.shape()[0] * ratio))
                 pruned_idx = [0] * pruned_num
-                for name, axis in group:
+                for name, axis, _ in group:
                     pruned_params.append((name, axis, pruned_idx))
             else:
                 assert ((not self.pruned_weights),
                         "The weights have been pruned once.")
                 group_values = []
-                for name, axis in group:
+                for name, axis, pruned_idx in group:
                     values = np.array(scope.find_var(name).get_tensor())
-                    group_values.append((name, values, axis))
+                    group_values.append((name, values, axis, pruned_idx))
 
-                scores = self.criterion(group_values,
-                                        graph)  # [(name, axis, score)]
+                scores = self.criterion(
+                    group_values, graph)  # [(name, axis, score, pruned_idx)]
 
                 pruned_params.extend(self.idx_selector(scores, ratio))
 
         merge_pruned_params = {}
         for param, pruned_axis, pruned_idx in pruned_params:
+            print("{}\t{}\t{}".format(param, pruned_axis, len(pruned_idx)))
             if param not in merge_pruned_params:
                 merge_pruned_params[param] = {}
             if pruned_axis not in merge_pruned_params[param]:
                 merge_pruned_params[param][pruned_axis] = []
             merge_pruned_params[param][pruned_axis].append(pruned_idx)
+        print("param name: stage.0.conv_layer.conv.weights; idx: {}".format(
+            merge_pruned_params["stage.0.conv_layer.conv.weights"][1]))
         for param_name in merge_pruned_params:
             for pruned_axis in merge_pruned_params[param_name]:
                 pruned_idx = np.concatenate(merge_pruned_params[param_name][
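
The net effect of the patch is that every group entry now carries the offset produced by the walker: collect_convs() yields (name, axis, pruned_idx) tuples, the criterion functions pass that field through as (name, axis, score, pruned_idx), and default_idx_selector shifts the selected channels by the first offset. Below is a minimal standalone sketch of that flow; the parameter names, shapes, and NumPy arrays are made up for illustration and stand in for the real graph and scope objects.

import numpy as np

# Hypothetical group in the post-patch format returned by collect_convs():
# (name, axis, pruned_idx), where pruned_idx[0] acts as the parameter's
# offset inside its group.
group = [("conv_a.weights", 0, [0]), ("conv_b.weights", 1, [0])]
tensors = {name: np.random.rand(8, 4, 3, 3) for name, _, _ in group}

# l1_norm-style criterion: compute a per-channel score and keep the offset
# so it reaches the selector, producing (name, axis, score, pruned_idx).
scores = []
for name, axis, pruned_idx in group:
    value = tensors[name]
    reduce_dims = tuple(i for i in range(value.ndim) if i != axis)
    scores.append((name, axis, np.abs(value).sum(axis=reduce_dims), pruned_idx))

# default_idx_selector-style selection: rank channels by the first
# parameter's score, then shift the chosen indexes by each offset.
ratio = 0.5
first_score = scores[0][2]
pruned = first_score.argsort()[:int(round(len(first_score) * ratio))]
for name, axis, _, offsets in scores:
    print(name, axis, [int(i) + offsets[0] for i in pruned])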