diff --git a/paddleslim/prune/__init__.py b/paddleslim/prune/__init__.py
index c46fd75dd3220abffcaabcadc78b271e48cb5489..42ea0b4c3923cafcebdd2f1acc684defdd9faddb 100644
--- a/paddleslim/prune/__init__.py
+++ b/paddleslim/prune/__init__.py
@@ -25,6 +25,8 @@ from .prune_walker import *
 from ..prune import prune_walker
 from .prune_io import *
 from ..prune import prune_io
+from .group_param import *
+from ..prune import group_param
 
 __all__ = []
 
@@ -34,3 +36,4 @@ __all__ += sensitive_pruner.__all__
 __all__ += sensitive.__all__
 __all__ += prune_walker.__all__
 __all__ += prune_io.__all__
+__all__ += group_param.__all__
diff --git a/paddleslim/prune/group_param.py b/paddleslim/prune/group_param.py
new file mode 100644
index 0000000000000000000000000000000000000000..52075c9a47d34723d0f90b8c69b982a610aeb2f7
--- /dev/null
+++ b/paddleslim/prune/group_param.py
@@ -0,0 +1,96 @@
+"""Define some functions to collect related parameters into groups."""
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..core import GraphWrapper
+from .prune_walker import conv2d as conv2d_walker
+
+__all__ = ["collect_convs"]
+
+
+def collect_convs(params, graph):
+    """Collect the convolution layers of a graph into groups.
+
+    The layers in the same group are related to each other by the pruning
+    operation. A group is a list of tuples with format (param_name, axis), in
+    which `param_name` is the name of a parameter and `axis` is the axis to
+    be pruned on.
+
+    .. code-block:: text
+
+       conv1->conv2->conv3->conv4
+
+    As shown above, the demo has 4 convolution layers, and the shape of a
+    convolution's parameter is
+    `[out_channel, in_channel, filter_size, filter_size]`. If the parameter of
+    `conv1` is pruned on axis 0, then the parameter of `conv2` should be
+    pruned on axis 1. So `conv1` and `conv2` form a group that can be
+    represented as:
+
+    .. code-block:: python
+
+       [("conv1", 0), ("conv2", 1)]
+
+    If `params` is `["conv1", "conv2"]`, then the returned groups are:
+
+    .. code-block:: python
+
+       [[("conv1", 0), ("conv2", 1)],
+        [("conv2", 0), ("conv3", 1)]]
+
+    A group is left out of the result when one of the parameters it prunes on
+    axis 0 was already seen on axis 0 while scanning this or an earlier group.
+
+    Args:
+        params(list): A list of convolution layers' parameter names. All the
+            groups that contain any one of these parameters are collected.
+        graph(paddle.fluid.Program | GraphWrapper): The graph used to search
+            the groups.
+
+    Returns:
+        list<list<tuple>>: The groups.
+
+    """
+    wrapper = graph if isinstance(graph, GraphWrapper) else GraphWrapper(graph)
+    walked_groups = []
+    for param_name in params:
+        collected = []
+        param_var = wrapper.var(param_name)
+        # NOTE(review): the first entry of `outputs()` is taken to be the
+        # conv2d op that uses this parameter -- confirm for shared parameters.
+        first_consumer = param_var.outputs()[0]
+        walker = conv2d_walker(
+            first_consumer, pruned_params=collected, visited={})
+        # `collected` is handed to the walker, which is expected to record one
+        # (variable, axis, ...) entry per parameter touched by the prune.
+        walker.prune(param_var, pruned_axis=0, pruned_idx=[])
+        walked_groups.append(collected)
+
+    seen_on_axis0 = set()
+    uniq_groups = []
+    for walked in walked_groups:
+        is_repeat = False
+        pairs = []
+        for var, axis, _ in walked:
+            name = var.name()
+            if axis == 0:
+                # Adding an already-seen name is a no-op, so only the
+                # membership test decides whether the group is a repeat.
+                is_repeat = is_repeat or name in seen_on_axis0
+                seen_on_axis0.add(name)
+            pairs.append((name, axis))
+        if not is_repeat:
+            uniq_groups.append(pairs)
+
+    return uniq_groups
diff --git a/tests/test_group_param.py b/tests/test_group_param.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd699bfd68bf2d28e670a03f9200944be9b1a562
--- /dev/null
+++ b/tests/test_group_param.py
@@ -0,0 +1,51 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+sys.path.append("../")  # presumably lets `paddleslim` resolve from the repo root -- confirm
+import unittest
+import paddle.fluid as fluid
+from layers import conv_bn_layer
+from paddleslim.prune import collect_convs
+
+
+class TestPrune(unittest.TestCase):
+    def test_prune(self):
+        """Collect groups for three convs of a network with two shortcuts."""
+        main_program = fluid.Program()
+        startup_program = fluid.Program()
+        #   X       X              O       X              O
+        # conv1-->conv2-->sum1-->conv3-->conv4-->sum2-->conv5-->conv6
+        #   |            ^ |                    ^
+        #   |____________| |____________________|
+        #
+        # X: prune output channels; O: prune input channels
+        with fluid.program_guard(main_program, startup_program):
+            image = fluid.data(name="image", shape=[None, 3, 16, 16])
+            conv1 = conv_bn_layer(image, 8, 3, "conv1")
+            conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
+            sum1 = conv1 + conv2
+            conv3 = conv_bn_layer(sum1, 8, 3, "conv3")
+            conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
+            sum2 = conv4 + sum1
+            conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
+            conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
+        groups = collect_convs(
+            ["conv1_weights", "conv2_weights", "conv3_weights"], main_program)
+        self.assertEqual(len(groups), 2)  # three requested params, two unique groups
+        self.assertEqual(len(groups[0]), 18)
+        self.assertEqual(len(groups[1]), 6)
+
+
+if __name__ == '__main__':
+    unittest.main()