"""Define some functions to collect ralated parameters into groups."""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from..coreimportGraphWrapper
from.prune_walkerimportconv2dasconv2d_walker
__all__=["collect_convs"]
defcollect_convs(params,graph):
"""Collect convolution layers of graph into groups. The layers in the same group is relative on pruning operation.
A group is a list of tuple with format (param_name, axis) in which `param_name` is the name of parameter and `axis` is the axis to be pruned on.
.. code-block:: text
conv1->conv2->conv3->conv4
As shown above, the demo has 4 convolution layers. And the shape of convolution's parameter is `[out_channel, in_channel, filter_size, filter_size]`. If parameter of `conv1` was pruned on axis 0, then the parameter of `conv2` should be pruned on axis 1. So the `conv1` and `conv2` is a group that can be represented as:
.. code-block:: python
[("conv1", 0), ("conv2", 1)]
If `params` is `["conv1", "conv2"]`, then the returned groups is:
.. code-block:: python
[[("conv1", 0), ("conv2", 1)],
[("conv2", 0), ("conv3", 1)]]
Args:
params(list): A list of convolution layer's parameter names. It will collect all the groups that contains anyone of these parameters.
graph(paddle.fluid.Program | GraphWrapper): The graph used to search the groups.