提交 c263531a 编写于 作者: W wanghaoshuang

Add group_param to collect relative covolution into groups.

上级 bc14da69
...@@ -25,6 +25,8 @@ from .prune_walker import * ...@@ -25,6 +25,8 @@ from .prune_walker import *
from ..prune import prune_walker from ..prune import prune_walker
from .prune_io import * from .prune_io import *
from ..prune import prune_io from ..prune import prune_io
from .group_param import *
from ..prune import group_param
__all__ = [] __all__ = []
...@@ -34,3 +36,4 @@ __all__ += sensitive_pruner.__all__ ...@@ -34,3 +36,4 @@ __all__ += sensitive_pruner.__all__
__all__ += sensitive.__all__ __all__ += sensitive.__all__
__all__ += prune_walker.__all__ __all__ += prune_walker.__all__
__all__ += prune_io.__all__ __all__ += prune_io.__all__
__all__ += group_param.__all__
"""Define some functions to collect ralated parameters into groups."""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..core import GraphWrapper
from .prune_walker import conv2d as conv2d_walker
__all__ = ["collect_convs"]
def collect_convs(params, graph):
"""Collect convolution layers of graph into groups. The layers in the same group is relative on pruning operation.
A group is a list of tuple with format (param_name, axis) in which `param_name` is the name of parameter and `axis` is the axis to be pruned on.
.. code-block:: text
conv1->conv2->conv3->conv4
As shown above, the demo has 4 convolution layers. And the shape of convolution's parameter is `[out_channel, in_channel, filter_size, filter_size]`. If parameter of `conv1` was pruned on axis 0, then the parameter of `conv2` should be pruned on axis 1. So the `conv1` and `conv2` is a group that can be represented as:
.. code-block:: python
[("conv1", 0), ("conv2", 1)]
If `params` is `["conv1", "conv2"]`, then the returned groups is:
.. code-block:: python
[[("conv1", 0), ("conv2", 1)],
[("conv2", 0), ("conv3", 1)]]
Args:
params(list): A list of convolution layer's parameter names. It will collect all the groups that contains anyone of these parameters.
graph(paddle.fluid.Program | GraphWrapper): The graph used to search the groups.
Returns:
list<list<tuple>>: The groups.
"""
if not isinstance(graph, GraphWrapper):
graph = GraphWrapper(graph)
groups = []
for param in params:
visited = {}
pruned_params = []
param = graph.var(param)
conv_op = param.outputs()[0]
walker = conv2d_walker(
conv_op, pruned_params=pruned_params, visited=visited)
walker.prune(param, pruned_axis=0, pruned_idx=[])
groups.append(pruned_params)
visited = set()
uniq_groups = []
for group in groups:
repeat_group = False
simple_group = []
for param, axis, _ in group:
param = param.name()
if axis == 0:
if param in visited:
repeat_group = True
else:
visited.add(param)
simple_group.append((param, axis))
if not repeat_group:
uniq_groups.append(simple_group)
return uniq_groups
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import unittest
import paddle.fluid as fluid
from layers import conv_bn_layer
from paddleslim.prune import collect_convs
class TestPrune(unittest.TestCase):
def test_prune(self):
main_program = fluid.Program()
startup_program = fluid.Program()
# X X O X O
# conv1-->conv2-->sum1-->conv3-->conv4-->sum2-->conv5-->conv6
# | ^ | ^
# |____________| |____________________|
#
# X: prune output channels
# O: prune input channels
with fluid.program_guard(main_program, startup_program):
input = fluid.data(name="image", shape=[None, 3, 16, 16])
conv1 = conv_bn_layer(input, 8, 3, "conv1")
conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
sum1 = conv1 + conv2
conv3 = conv_bn_layer(sum1, 8, 3, "conv3")
conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
sum2 = conv4 + sum1
conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
groups = collect_convs(
["conv1_weights", "conv2_weights", "conv3_weights"], main_program)
self.assertTrue(len(groups) == 2)
self.assertTrue(len(groups[0]) == 18)
self.assertTrue(len(groups[1]) == 6)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册