diff --git a/paddleslim/prune/__init__.py b/paddleslim/prune/__init__.py
index c46fd75dd3220abffcaabcadc78b271e48cb5489..361a3af13db508a1d1b697b9136e79d065d00a52 100644
--- a/paddleslim/prune/__init__.py
+++ b/paddleslim/prune/__init__.py
@@ -25,7 +25,13 @@ from .prune_walker import *
 from ..prune import prune_walker
 from .prune_io import *
 from ..prune import prune_io
+from .group_param import *
+from ..prune import group_param
+from .criterion import *
+from ..prune import criterion
+from .idx_selector import *
+from ..prune import idx_selector
 
 __all__ = []
 __all__ += pruner.__all__
@@ -34,3 +40,6 @@ __all__ += sensitive_pruner.__all__
 __all__ += sensitive.__all__
 __all__ += prune_walker.__all__
 __all__ += prune_io.__all__
+__all__ += group_param.__all__
+__all__ += criterion.__all__
+__all__ += idx_selector.__all__
diff --git a/paddleslim/prune/criterion.py b/paddleslim/prune/criterion.py
new file mode 100644
index 0000000000000000000000000000000000000000..36410ee8d9f242c7a5f8f6d11977b3291212da02
--- /dev/null
+++ b/paddleslim/prune/criterion.py
@@ -0,0 +1,115 @@
+"""Define functions that compute the importance of structures to be pruned.
+"""
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import numpy as np
+from ..common import get_logger
+from ..core import GraphWrapper, Registry
+
+__all__ = ["l1_norm", "CRITERION"]
+
+_logger = get_logger(__name__, level=logging.INFO)
+
+CRITERION = Registry('criterion')
+
+
+@CRITERION.register
+def l1_norm(group, graph):
+    """Compute l1-norm scores of parameters on the given axis.
+
+    This function returns a list of the parameters' l1-norm scores on the given
+    axis. Each element of the list is a tuple in the format (name, axis, score),
+    in which `name` is the parameter's name, `axis` is the axis being reduced
+    and `score` is a np.array storing the l1-norm of the structure on `axis`.
+
+    Args:
+        group(list): A group of parameters. The first parameter of the group is
+                     the convolution layer's weight while the others are
+                     parameters affected by pruning the first one. Each parameter
+                     in the group is represented as a tuple '(name, values, axis)'
+                     in which `name` is the parameter's name, `values` is the
+                     values of the parameter and `axis` is the axis to be pruned on.
+
+    Returns:
+        list: A list of tuples storing the l1-norm on the given axis.
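+
+    Example:
+        A minimal sketch of calling the criterion directly; the weight values
+        and names below are made up for illustration, and `graph` is unused by
+        this criterion:
+
+        .. code-block:: python
+
+            import numpy as np
+            conv_w = np.random.rand(8, 3, 3, 3)  # [out_c, in_c, k_h, k_w]
+            group = [("conv1_weights", conv_w, 0)]
+            scores = l1_norm(group, None)
+            # scores: [("conv1_weights", 0, <np.array of 8 per-channel l1-norms>)]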
+ """ + scores = [] + for name, value, axis in group: + + reduce_dims = [i for i in range(len(value.shape)) if i != axis] + score = np.sum(np.abs(value), axis=tuple(reduce_dims)) + scores.append((name, axis, score)) + + return scores + + +@CRITERION.register +def geometry_median(group, graph): + scores = [] + name, value, axis = group[0] + assert (len(value.shape) == 4) + w = value.view() + channel_num = value.shape[0] + w.shape = value.shape[0], np.product(value.shape[1:]) + x = w.repeat(channel_num, axis=0) + y = np.tile(channel_num, (channel_num, 1)) + tmp = np.sqrt(np.sum((x - y)**2, -1)) + tmp = tmp.reshape((channel_num, channel_num)) + tmp = np.sum(tmp, -1) + + for name, value, axis in group: + scores.append(name, axis, tmp) + return scores + + +@CRITERION.register +def bn_scale(group, graph): + """Compute l1-norm scores of parameter on given axis. + + This function return a list of parameters' l1-norm scores on given axis. + Each element of list is a tuple with format (name, axis, score) in which 'name' is parameter's name + and 'axis' is the axis reducing on and `score` is a np.array storing the l1-norm of strucure on `axis`. + + Args: + group(list): A group of parameters. The first parameter of the group is convolution layer's weight + while the others are parameters affected by pruning the first one. Each parameter in group + is represented as tuple '(name, values, axis)' in which `name` is the parameter's name and + and `values` is the values of parameter and `axis` is the axis reducing on pruning on. + Returns: + list: A list of tuple storing l1-norm on given axis. + """ + assert (isinstance(graph, GraphWrapper)) + + # step1: Get first convolution + conv_weight, value, axis = group[0] + param_var = graph.var(conv_weight) + conv_op = param_var.outputs()[0] + + # step2: Get bn layer after first convolution + conv_output = conv_op.outputs("Output")[0] + bn_op = conv_output.outputs()[0] + if bn_op is not None: + bn_scale_param = bn_op.inputs("Scale")[0].name() + else: + raise SystemExit("Can't find BatchNorm op after Conv op in Network.") + + # steps3: Find scale of bn + score = None + for name, value, aixs in group: + if bn_scale_param == name: + score = np.abs(value.reshape([-1])) + + scores = [] + for name, value, axis in group: + scores.append((name, axis, score)) + + return scores diff --git a/paddleslim/prune/group_param.py b/paddleslim/prune/group_param.py new file mode 100644 index 0000000000000000000000000000000000000000..52075c9a47d34723d0f90b8c69b982a610aeb2f7 --- /dev/null +++ b/paddleslim/prune/group_param.py @@ -0,0 +1,79 @@ +"""Define some functions to collect ralated parameters into groups.""" +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..core import GraphWrapper +from .prune_walker import conv2d as conv2d_walker + +__all__ = ["collect_convs"] + + +def collect_convs(params, graph): + """Collect convolution layers of graph into groups. The layers in the same group is relative on pruning operation. 
+
+    A group is a list of tuples in the format (param_name, axis), in which
+    `param_name` is the name of a parameter and `axis` is the axis to be pruned on.
+
+    .. code-block:: text
+
+        conv1->conv2->conv3->conv4
+
+    As shown above, the demo has 4 convolution layers, and the shape of each
+    convolution's parameter is `[out_channel, in_channel, filter_size, filter_size]`.
+    If the parameter of `conv1` is pruned on axis 0, then the parameter of `conv2`
+    must be pruned on axis 1, so `conv1` and `conv2` form a group that can be
+    represented as:
+
+    .. code-block:: python
+
+        [("conv1", 0), ("conv2", 1)]
+
+    If `params` is `["conv1", "conv2"]`, then the returned groups are:
+
+    .. code-block:: python
+
+        [[("conv1", 0), ("conv2", 1)],
+         [("conv2", 0), ("conv3", 1)]]
+
+    Args:
+        params(list): A list of convolution layers' parameter names. All the
+                      groups that contain any of these parameters are collected.
+        graph(paddle.fluid.Program | GraphWrapper): The graph used to search the groups.
+
+    Returns:
+        list<list<tuple>>: The groups.
+
+    """
+    if not isinstance(graph, GraphWrapper):
+        graph = GraphWrapper(graph)
+    groups = []
+    for param in params:
+        visited = {}
+        pruned_params = []
+        param = graph.var(param)
+        conv_op = param.outputs()[0]
+        walker = conv2d_walker(
+            conv_op, pruned_params=pruned_params, visited=visited)
+        walker.prune(param, pruned_axis=0, pruned_idx=[])
+        groups.append(pruned_params)
+    visited = set()
+    uniq_groups = []
+    for group in groups:
+        repeat_group = False
+        simple_group = []
+        for param, axis, _ in group:
+            param = param.name()
+            if axis == 0:
+                if param in visited:
+                    repeat_group = True
+                else:
+                    visited.add(param)
+            simple_group.append((param, axis))
+        # Drop groups whose output-channel parameters were already collected.
+        if not repeat_group:
+            uniq_groups.append(simple_group)
+
+    return uniq_groups
diff --git a/paddleslim/prune/idx_selector.py b/paddleslim/prune/idx_selector.py
new file mode 100644
index 0000000000000000000000000000000000000000..b17348ea6f3e8866f0a9df50188703d023008668
--- /dev/null
+++ b/paddleslim/prune/idx_selector.py
@@ -0,0 +1,117 @@
+"""Define functions that sort substructures of parameters by importance.
+"""
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import numpy as np
+from ..common import get_logger
+from ..core import Registry
+
+__all__ = ["IDX_SELECTOR"]
+
+IDX_SELECTOR = Registry('idx_selector')
+
+
+@IDX_SELECTOR.register
+def default_idx_selector(group, ratio):
+    """Get the indexes to be pruned by the given ratio.
+
+    This function returns a list of the parameters' pruned indexes on the given
+    axis. Each element of the list is a tuple in the format (name, axis, indexes),
+    in which `name` is the parameter's name, `axis` is the axis to prune on and
+    `indexes` are the indexes to be pruned.
+
+    Args:
+        group(list): A group of parameters. The first parameter of the group is
+                     the convolution layer's weight while the others are
+                     parameters affected by pruning the first one. Each parameter
+                     in the group is represented as a tuple '(name, axis, score)'
+                     in which `name` is the parameter's name, `axis` is the axis
+                     to prune on and `score` is a np.array storing the importance
+                     of the structure on `axis`, as shown below:
+
+                     .. code-block:: text
+
+                         [("conv1_weights", 0, [0.7, 0.5, 0.6]),
+                          ("conv1_bn.scale", 0, [0.1, 0.2, 0.4])]
+
+                     The shape of "conv1_weights" is
+                     `[out_channel, in_channel, filter_size, filter_size]`, so
+                     `[0.7, 0.5, 0.6]` are the importance scores of the output
+                     channels of "conv1_weights" while the axis is 0.
+        ratio(float): The ratio of structures to be pruned.
+
+    Returns:
+        list: A list of tuples storing the pruned indexes per parameter.
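+
+    Example:
+        A minimal sketch with made-up scores; with three channels and
+        `ratio=0.34`, one channel (the lowest-scored, index 1) is selected:
+
+        .. code-block:: python
+
+            import numpy as np
+            group = [("conv1_weights", 0, np.array([0.7, 0.5, 0.6]))]
+            default_idx_selector(group, 0.34)
+            # [("conv1_weights", 0, array([1]))]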
+    """
+    # Sort channels by the scores of the first convolution in the group.
+    name, axis, score = group[0]
+    sorted_idx = score.argsort()
+
+    pruned_num = int(round(len(sorted_idx) * ratio))
+    pruned_idx = sorted_idx[:pruned_num]
+
+    idxs = []
+    for name, axis, score in group:
+        idxs.append((name, axis, pruned_idx))
+    return idxs
+
+
+@IDX_SELECTOR.register
+def optimal_threshold(group, ratio):
+    """Get the indexes to be pruned by an optimal threshold searched from the scores.
+
+    This function returns a list of the parameters' pruned indexes on the given
+    axis. Each element of the list is a tuple in the format (name, axis, indexes),
+    in which `name` is the parameter's name, `axis` is the axis to prune on and
+    `indexes` are the indexes to be pruned.
+
+    Args:
+        group(list): A group of parameters. The first parameter of the group is
+                     the convolution layer's weight while the others are
+                     parameters affected by pruning the first one. Each parameter
+                     in the group is represented as a tuple '(name, axis, score)'
+                     in which `name` is the parameter's name, `axis` is the axis
+                     to prune on and `score` is a np.array storing the importance
+                     of the structure on `axis`, as shown below:
+
+                     .. code-block:: text
+
+                         [("conv1_weights", 0, [0.7, 0.5, 0.6]),
+                          ("conv1_bn.scale", 0, [0.1, 0.2, 0.4])]
+
+                     The shape of "conv1_weights" is
+                     `[out_channel, in_channel, filter_size, filter_size]`, so
+                     `[0.7, 0.5, 0.6]` are the importance scores of the output
+                     channels of "conv1_weights" while the axis is 0.
+        ratio(float): The ratio of the accumulated squared scores used to search
+                      the threshold.
+
+    Returns:
+        list: A list of tuples storing the pruned indexes per parameter.
+    """
+    # Search the threshold on the scores of the first convolution in the group.
+    name, axis, score = group[0]
+
+    score[score < 1e-18] = 1e-18
+    score_sorted = np.sort(score)
+    score_square = score_sorted**2
+    total_sum = score_square.sum()
+    acc_sum = 0
+    for i in range(score_square.size):
+        acc_sum += score_square[i]
+        if acc_sum / total_sum > ratio:
+            break
+    th = (score_sorted[i - 1] + score_sorted[i]) / 2 if i > 0 else 0
+
+    pruned_idx = np.squeeze(np.argwhere(score < th))
+
+    idxs = []
+    for name, axis, score in group:
+        idxs.append((name, axis, pruned_idx))
+    return idxs
diff --git a/paddleslim/prune/pruner.py b/paddleslim/prune/pruner.py
index 436c13b3c033acd7df805162ecd8083f686715aa..317c5a9c914075ad17df240ddf081259a3954872 100644
--- a/paddleslim/prune/pruner.py
+++ b/paddleslim/prune/pruner.py
@@ -19,7 +19,9 @@ from functools import reduce
 import paddle.fluid as fluid
 import copy
 from ..core import VarWrapper, OpWrapper, GraphWrapper
-from .prune_walker import conv2d as conv2d_walker
+from .group_param import collect_convs
+from .criterion import CRITERION
+from .idx_selector import IDX_SELECTOR
 from ..common import get_logger
 
 __all__ = ["Pruner"]
@@ -31,12 +33,26 @@ class Pruner():
     """The pruner used to prune channels of convolution.
 
     Args:
-        criterion(str): the criterion used to sort channels for pruning. It only supports 'l1_norm' currently.
+        criterion(str|function): The criterion used to sort channels for pruning.
+                                 It supports 'l1_norm', 'geometry_median' and
+                                 'bn_scale' by name, or a customized function
+                                 registered in CRITERION.
+        idx_selector(str|function): The function used to select the indexes to
+                                    be pruned from the scores. It supports
+                                    'default_idx_selector' and 'optimal_threshold'
+                                    by name, or a customized function registered
+                                    in IDX_SELECTOR.
 
     """
 
-    def __init__(self, criterion="l1_norm"):
-        self.criterion = criterion
+    def __init__(self,
+                 criterion="l1_norm",
+                 idx_selector="default_idx_selector"):
+        if isinstance(criterion, str):
+            self.criterion = CRITERION.get(criterion)
+        else:
+            self.criterion = criterion
+        if isinstance(idx_selector, str):
+            self.idx_selector = IDX_SELECTOR.get(idx_selector)
+        else:
+            self.idx_selector = idx_selector
+
+        self.pruned_weights = False
 
     def prune(self,
               program,
@@ -76,30 +92,35 @@ class Pruner():
         visited = {}
         pruned_params = []
         for param, ratio in zip(params, ratios):
-            if only_graph:
+            group = collect_convs([param], graph)[0]  # [(name, axis)]
+            if only_graph and self.idx_selector.__name__ == "default_idx_selector":
+                param_v = graph.var(param)
                 pruned_num = int(round(param_v.shape()[0] * ratio))
-                if self.criterion == "optimal_threshold":
-                    pruned_idx = self._cal_pruned_idx(
-                        graph, scope, param, ratio, axis=0)
-                else:
-                    pruned_idx = [0] * pruned_num
+                pruned_idx = [0] * pruned_num
+                for name, axis in group:
+                    pruned_params.append((name, axis, pruned_idx))
+
             else:
-                pruned_idx = self._cal_pruned_idx(
-                    graph, scope, param, ratio, axis=0)
-                param = graph.var(param)
-                conv_op = param.outputs()[0]
-                walker = conv2d_walker(
-                    conv_op, pruned_params=pruned_params, visited=visited)
-                walker.prune(param, pruned_axis=0, pruned_idx=pruned_idx)
+                assert not self.pruned_weights, (
+                    "The weights have been pruned once.")
+                group_values = []
+                for name, axis in group:
+                    values = np.array(scope.find_var(name).get_tensor())
+                    group_values.append((name, values, axis))
+
+                scores = self.criterion(group_values,
+                                        graph)  # [(name, axis, score)]
+                pruned_params.extend(self.idx_selector(scores, ratio))
 
         merge_pruned_params = {}
         for param, pruned_axis, pruned_idx in pruned_params:
-            if param.name() not in merge_pruned_params:
-                merge_pruned_params[param.name()] = {}
-            if pruned_axis not in merge_pruned_params[param.name()]:
-                merge_pruned_params[param.name()][pruned_axis] = []
-            merge_pruned_params[param.name()][pruned_axis].append(pruned_idx)
+            if param not in merge_pruned_params:
+                merge_pruned_params[param] = {}
+            if pruned_axis not in merge_pruned_params[param]:
+                merge_pruned_params[param][pruned_axis] = []
+            merge_pruned_params[param][pruned_axis].append(pruned_idx)
 
         for param_name in merge_pruned_params:
             for pruned_axis in merge_pruned_params[param_name]:
@@ -134,86 +155,9 @@ class Pruner():
             param_t.set(pruned_param, place)
         graph.update_groups_of_conv()
         graph.infer_shape()
+        self.pruned_weights = (not only_graph)
         return graph.program, param_backup, param_shape_backup
 
-    def _cal_pruned_idx(self, graph, scope, param, ratio, axis):
-        """
-        Calculate the index to be pruned on axis by given pruning ratio.
-
-        Args:
-            name(str): The name of parameter to be pruned.
-            param(np.array): The data of parameter to be pruned.
-            ratio(float): The ratio to be pruned.
-            axis(int): The axis to be used for pruning given parameter.
-                If it is None, the value in self.pruning_axis will be used.
-                default: None.
-
-        Returns:
-            list: The indexes to be pruned on axis.
- """ - if self.criterion == 'l1_norm': - param_t = np.array(scope.find_var(param).get_tensor()) - prune_num = int(round(param_t.shape[axis] * ratio)) - reduce_dims = [i for i in range(len(param_t.shape)) if i != axis] - criterions = np.sum(np.abs(param_t), axis=tuple(reduce_dims)) - pruned_idx = criterions.argsort()[:prune_num] - elif self.criterion == 'geometry_median': - param_t = np.array(scope.find_var(param).get_tensor()) - prune_num = int(round(param_t.shape[axis] * ratio)) - - def get_distance_sum(param, out_idx): - w = param.view() - reduce_dims = reduce(lambda x, y: x * y, param.shape[1:]) - w.shape = param.shape[0], reduce_dims - selected_filter = np.tile(w[out_idx], (w.shape[0], 1)) - x = w - selected_filter - x = np.sqrt(np.sum(x * x, -1)) - return x.sum() - - dist_sum_list = [] - for out_i in range(param_t.shape[0]): - dist_sum = get_distance_sum(param_t, out_i) - dist_sum_list.append((dist_sum, out_i)) - min_gm_filters = sorted( - dist_sum_list, key=lambda x: x[0])[:prune_num] - pruned_idx = np.array([x[1] for x in min_gm_filters]) - - elif self.criterion == "batch_norm_scale" or self.criterion == "optimal_threshold": - param_var = graph.var(param) - conv_op = param_var.outputs()[0] - conv_output = conv_op.outputs("Output")[0] - bn_op = conv_output.outputs()[0] - if bn_op is not None: - bn_scale_param = bn_op.inputs("Scale")[0].name() - bn_scale_np = np.array( - scope.find_var(bn_scale_param).get_tensor()) - if self.criterion == "batch_norm_scale": - prune_num = int(round(bn_scale_np.shape[axis] * ratio)) - pruned_idx = np.abs(bn_scale_np).argsort()[:prune_num] - elif self.criterion == "optimal_threshold": - - def get_optimal_threshold(weight, percent=0.001): - weight[weight < 1e-18] = 1e-18 - weight_sorted = np.sort(weight) - weight_square = weight_sorted**2 - total_sum = weight_square.sum() - acc_sum = 0 - for i in range(weight_square.size): - acc_sum += weight_square[i] - if acc_sum / total_sum > percent: - break - th = (weight_sorted[i - 1] + weight_sorted[i] - ) / 2 if i > 0 else 0 - return th - - optimal_th = get_optimal_threshold(bn_scale_np, ratio) - pruned_idx = np.squeeze( - np.argwhere(bn_scale_np < optimal_th)) - else: - raise SystemExit( - "Can't find BatchNorm op after Conv op in Network.") - return pruned_idx - def _prune_tensor(self, tensor, pruned_idx, pruned_axis, lazy=False): """ Pruning a array by indexes on given axis. diff --git a/tests/test_group_param.py b/tests/test_group_param.py new file mode 100644 index 0000000000000000000000000000000000000000..cd699bfd68bf2d28e670a03f9200944be9b1a562 --- /dev/null +++ b/tests/test_group_param.py @@ -0,0 +1,51 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import sys
+sys.path.append("../")
+import unittest
+import paddle.fluid as fluid
+from layers import conv_bn_layer
+from paddleslim.prune import collect_convs
+
+
+class TestPrune(unittest.TestCase):
+    def test_prune(self):
+        main_program = fluid.Program()
+        startup_program = fluid.Program()
+        #   X       X              O       X              O
+        # conv1-->conv2-->sum1-->conv3-->conv4-->sum2-->conv5-->conv6
+        #     |            ^ |                    ^
+        #     |____________| |____________________|
+        #
+        # X: prune output channels
+        # O: prune input channels
+        with fluid.program_guard(main_program, startup_program):
+            input = fluid.data(name="image", shape=[None, 3, 16, 16])
+            conv1 = conv_bn_layer(input, 8, 3, "conv1")
+            conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
+            sum1 = conv1 + conv2
+            conv3 = conv_bn_layer(sum1, 8, 3, "conv3")
+            conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
+            sum2 = conv4 + sum1
+            conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
+            conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
+        groups = collect_convs(
+            ["conv1_weights", "conv2_weights", "conv3_weights"], main_program)
+        self.assertTrue(len(groups) == 2)
+        self.assertTrue(len(groups[0]) == 18)
+        self.assertTrue(len(groups[1]) == 6)
+
+
+if __name__ == '__main__':
+    unittest.main()
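
Note on usage: the sketch below shows how the refactored `Pruner` composes the new
registries. It is a minimal illustration, not part of the patch; the program, scope
and parameter name are placeholders, and the call follows the
`prune(program, scope, params, ratios, place, ...)` signature visible in the hunks
above:

.. code-block:: python

    import paddle.fluid as fluid
    from paddleslim.prune import Pruner

    # Built-in criteria: "l1_norm", "geometry_median", "bn_scale".
    # Built-in index selectors: "default_idx_selector", "optimal_threshold".
    pruner = Pruner(criterion="l1_norm", idx_selector="default_idx_selector")

    pruned_program, _, _ = pruner.prune(
        main_program,             # a fluid.Program containing conv2d layers
        fluid.global_scope(),     # the scope holding the parameter tensors
        params=["conv1_weights"],  # placeholder parameter name
        ratios=[0.5],             # pruning ratio per parameter
        place=fluid.CPUPlace())

A customized criterion can also be registered with `CRITERION.register` and passed
either by its name or as a function, which is what the `str|function` types in the
`Pruner` docstring refer to.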