"""Define some functions to collect related parameters into groups."""
# Copyright (c) 2020  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

from ..core import GraphWrapper
from ..common import get_logger
from .prune_walker import PRUNE_WORKER

__all__ = ["collect_convs"]

# Module-level logger used to report parameters whose consumer op has no
# registered prune walker.
_logger = get_logger(__name__, level=logging.INFO)


def collect_convs(params, graph, visited=None):
    """Collect convolution layers of graph into groups. The layers in the same
    group are related by the pruning operation: pruning one of them requires
    pruning the others on the corresponding axis.

    A group is a list of tuples with format ``(param_name, axis, pruned_idx)``
    in which ``param_name`` is the name of a parameter, ``axis`` is the axis to
    be pruned on, and ``pruned_idx`` is the index information recorded by the
    prune walker for that entry (the walk is started with ``pruned_idx=[]``).

    .. code-block:: text

       conv1->conv2->conv3->conv4

    As shown above, the demo has 4 convolution layers. The shape of a
    convolution's parameter is
    `[out_channel, in_channel, filter_size, filter_size]`. If the parameter of
    `conv1` is pruned on axis 0, then the parameter of `conv2` should be pruned
    on axis 1. So `conv1` and `conv2` form a group that can be represented as
    (``pruned_idx`` omitted for brevity):

    .. code-block:: python

       [("conv1", 0), ("conv2", 1)]

    If `params` is `["conv1", "conv2"]`, then the returned groups are:

    .. code-block:: python

       [[("conv1", 0), ("conv2", 1)],
        [("conv2", 0), ("conv3", 1)]]

    Args:
       params(list): A list of convolution layer's parameter names. All the
           groups that contain any of these parameters are collected.
       graph(paddle.static.Program | GraphWrapper): The graph used to search
           the groups.
       visited(dict, optional): Traversal record handed to every prune walker
           created during this call. It is shared by all entries of `params`.
           When omitted, a fresh dict is created for each call, so no state
           leaks between calls. Default: None.

    Returns:
       list<list<tuple>>: The groups.

    """
    # A fresh dict per call. A `{}` default would be created once at function
    # definition time and shared by every call.
    if visited is None:
        visited = {}
    if not isinstance(graph, GraphWrapper):
        graph = GraphWrapper(graph)

    groups = []
    for param in params:
        pruned_params = []
        param = graph.var(param)

        # NOTE(review): assumes every parameter has at least one consumer op;
        # `[0]` raises IndexError otherwise.
        target_op = param.outputs()[0]
        walker_cls = None
        walker_op = target_op
        if target_op.type() == 'conditional_block':
            # The first consumer is a control-flow op. Start the walk from the
            # first consumer of the parameter that has a registered walker.
            for op in param.outputs():
                if op.type() in PRUNE_WORKER._module_dict.keys():
                    walker_cls = PRUNE_WORKER.get(op.type())
                    walker_op = op
                    break
        else:
            walker_cls = PRUNE_WORKER.get(target_op.type())

        # Without a walker, record an empty group. In the conditional_block
        # case this also avoids reusing the walker from a previous iteration.
        if walker_cls is None:
            _logger.info("No walker for operator: {}".format(target_op.type(
            )))
            groups.append(pruned_params)
            continue

        walker = walker_cls(walker_op,
                            pruned_params=pruned_params,
                            visited=visited)
        # The walker appends (var, axis, pruned_idx) entries to pruned_params.
        walker.prune(param, pruned_axis=0, pruned_idx=[])
        groups.append(pruned_params)

    # De-duplicate. A group is dropped when any parameter it prunes on axis 0
    # has already been recorded, either by an earlier group or earlier in the
    # same group. Parameters of dropped groups are still recorded.
    seen = set()
    uniq_groups = []
    for group in groups:
        repeat_group = False
        simple_group = []
        for var, axis, pruned_idx in group:
            name = var.name()
            if axis == 0:
                if name in seen:
                    repeat_group = True
                else:
                    seen.add(name)
            simple_group.append((name, axis, pruned_idx))
        if not repeat_group:
            uniq_groups.append(simple_group)
    return uniq_groups