# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import sys
import copy
import numpy as np
from functools import reduce
from ..core import VarWrapper, OpWrapper, GraphWrapper
from .collections import StaticPruningCollections
from .criterion import CRITERION
from .idx_selector import IDX_SELECTOR
from ..common import get_logger

__all__ = ["Pruner"]

_logger = get_logger(__name__, level=logging.INFO)


class Pruner():
    """The pruner used to prune channels of convolution.

    Args:
        criterion(str|function): The criterion used to sort channels for
            pruning; a callable is invoked as
            ``criterion(collection, values, graph)``. Default: "l1_norm".
        idx_selector(str|function): The strategy used to select the indices
            of pruned channels from the computed scores; a callable is
            invoked as ``idx_selector(collection, scores, ratios)``.
            Default: "default_idx_selector".
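
    Examples:
        A minimal construction sketch; custom callables (hypothetical here)
        must match the call signatures used in ``prune`` below::

            pruner = Pruner(criterion="l1_norm",
                            idx_selector="default_idx_selector")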

    """

    def __init__(self, criterion="l1_norm",
                 idx_selector="default_idx_selector"):
        if isinstance(criterion, str):
            self.criterion = CRITERION.get(criterion)
        else:
            self.criterion = criterion
        if isinstance(idx_selector, str):
            self.idx_selector = IDX_SELECTOR.get(idx_selector)
        else:
            self.idx_selector = idx_selector

        self.pruned_weights = False

    def prune(self,
              program,
              scope,
              params,
              ratios,
              place=None,
              lazy=False,
              only_graph=False,
              param_backup=False,
              param_shape_backup=False):
        """Pruning the given parameters.

        Args:
            program(paddle.static.Program): The program to be pruned.
            scope(paddle.static.Scope): The scope storing parameters to be pruned.
            params(list<str>): A list of parameter names to be pruned.
            ratios(list<float>): A list of ratios used to prune the corresponding parameters.
            place(paddle.CUDAPlace|paddle.CPUPlace): The device place of filter parameters. Default: None.
            lazy(bool): True means setting the pruned elements to zero.
                        False means removing the pruned elements. Default: False.
            only_graph(bool): True means only modifying the graph.
                              False means modifying both the graph and the variables in scope. Default: False.
            param_backup(bool): Whether to return a dict to backup the values of parameters. Default: False.
            param_shape_backup(bool): Whether to return a dict to backup the shapes of parameters. Default: False.

        Returns:
            tuple: ``(pruned_program, param_backup, param_shape_backup)``. ``pruned_program`` is the pruned program. ``param_backup`` is a dict to backup the values of parameters. ``param_shape_backup`` is a dict to backup the shapes of parameters.
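
        Examples:
            A minimal usage sketch; ``main_program`` and ``conv1_weights``
            are hypothetical names, not part of this API::

                import paddle
                paddle.enable_static()
                # ... build and initialize main_program ...
                pruner = Pruner(criterion="l1_norm")
                pruned_program, _, _ = pruner.prune(
                    main_program,
                    paddle.static.global_scope(),
                    params=["conv1_weights"],
                    ratios=[0.5],
                    place=paddle.CPUPlace())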
        """
        self.pruned_list = []
        graph = GraphWrapper(program.clone())
        param_backup = {} if param_backup else None
        param_shape_backup = {} if param_shape_backup else None

        pruned_params = []
        collections = StaticPruningCollections(params, graph)
        ratios = dict(zip(params, ratios))
        values = {}
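        # Read the tensors of all variables referenced by the collections
        # out of the scope once; the criterion scores channels from them.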
        for _collection in collections:
            for _var_name in _collection.variables():
                var = scope.find_var(_var_name)
                if var is not None:
                    value = np.array(var.get_tensor())
                    values[_var_name] = value

        for _collection in collections:
            scores = self.criterion(_collection, values, graph)
            idx = self.idx_selector(_collection, scores,
                                    ratios)  # (name, axis, pruned_idx, transforms)
            idx = self._transform(idx)
            pruned_params.extend(idx)

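        # Group pruned indices by parameter name and axis so that each
        # (parameter, axis) pair is pruned exactly once.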
        merge_pruned_params = {}
        for param, pruned_axis, pruned_idx in pruned_params:
            if param not in merge_pruned_params:
                merge_pruned_params[param] = {}
            if pruned_axis not in merge_pruned_params[param]:
                merge_pruned_params[param][pruned_axis] = []
            merge_pruned_params[param][pruned_axis].append(pruned_idx)
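        # Concatenate the merged index groups and apply them to the graph
        # and, unless `only_graph` is True, to the tensors in the scope.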
        for param_name in merge_pruned_params:
            for pruned_axis in merge_pruned_params[param_name]:
                pruned_idx = np.concatenate(
                    merge_pruned_params[param_name][pruned_axis])
                param = graph.var(param_name)
                _groups = 1
                if not lazy:
                    # update groups of conv2d
                    if pruned_axis == 1:
                        for op in param.outputs():
                            if (op.type() in ["conv2d", "depthwise_conv2d"]
                                    and op.attr("groups") > 1):
                                _groups = op.attr("groups")
                                _filter_num = param.shape()[1]
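                                # Reduce `groups` so the remaining input
                                # channels (groups * _filter_num minus the
                                # pruned ones) still split into groups of
                                # `_filter_num`; e.g. (illustrative)
                                # groups=32, _filter_num=1 and 8 pruned
                                # indices gives new_groups=24.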
                                new_groups = int(
                                    (_groups * _filter_num - len(pruned_idx)) /
                                    _filter_num)
                                _logger.info(
                                    f"change groups of {op.type()}({param.name()}) from {op.attr('groups')} to {new_groups};"
                                )
                                op.set_attr("groups", new_groups)
                    if _groups == 1:
                        origin_shape = copy.deepcopy(param.shape())
                        if param_shape_backup is not None:
                            param_shape_backup[param.name()] = origin_shape
                        new_shape = list(param.shape())
                        new_shape[pruned_axis] -= len(pruned_idx)
                        param.set_shape(new_shape)

                if not only_graph and (_groups == 1 or pruned_axis != 1):
                    _var = scope.find_var(param.name())
                    if _var is None:
                        continue
                    param_t = _var.get_tensor()
                    if param_backup is not None and (
                            param.name() not in param_backup):
                        param_backup[param.name()] = copy.deepcopy(
                            np.array(param_t))
                    try:
                        pruned_param = self._prune_tensor(
                            np.array(param_t),
                            pruned_idx,
                            pruned_axis=pruned_axis,
                            lazy=lazy)
                        param_t.set(pruned_param, place)
                    except IndexError as e:
                        _logger.error(
                            "Failed to prune {} with shape {} on axis {}: {}".
                            format(param.name(),
                                   param_t.shape(), pruned_axis, e))

        graph.infer_shape()
        self.pruned_weights = (not only_graph)
        return graph.program, param_backup, param_shape_backup

    def _transform(self, items):
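        """Map pruned indices to their final positions via transforms.

        Each item is a ``(name, axis, pruned_idx, transforms)`` tuple. A
        transform tiles indices that fall in ``[src_start, src_end)`` onto
        ``[target_start, target_end)`` with a stride of
        ``src_end - src_start``; e.g. (illustrative) ``src_start=0,
        src_end=4, target_start=0, target_end=8`` maps index 1 to [1, 5].
        """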
        ret = []
        for name, axis, pruned_idx, transforms in items:
            src = pruned_idx
            for trans in transforms:
                src_start = trans['src_start']
                src_end = trans['src_end']
                src_len = src_end - src_start
                target_start = trans['target_start']
                target_end = trans['target_end']
                starts = np.array(range(target_start, target_end, src_len))
                target = []
                for idx in src:
                    if src_start <= idx < src_end:
                        idx -= src_start
                        target.extend(list(idx + starts))
                src = target
            ret.append((name, axis, src))
        return ret

    def _prune_tensor(self, tensor, pruned_idx, pruned_axis, lazy=False):
        """
        Prune an array by the given indices on the given axis.

        Args:
            tensor(numpy.array): The target array to be pruned.
            pruned_idx(list<int>): The indices to be pruned.
            pruned_axis(int): The axis of the given array to be pruned on.
            lazy(bool): True means setting the pruned elements to zero.
                        False means removing the pruned elements from memory.
                        Default: False.

        Returns:
            numpy.array: The pruned array.
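
        Example:
            A worked illustration: for ``tensor`` of shape (4, 3),
            ``pruned_idx=[1]`` and ``pruned_axis=0``, the result has shape
            (3, 3) with row 1 removed; with ``lazy=True`` the shape stays
            (4, 3) and row 1 is zeroed.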
        """
        mask = np.zeros(tensor.shape[pruned_axis], dtype=bool)
        mask[pruned_idx] = True

        def func(data):
            return data[~mask]

        def lazy_func(data):
            data[mask] = 0
            return data

        if lazy:
            return np.apply_along_axis(lazy_func, pruned_axis, tensor)
        else:
            return np.apply_along_axis(func, pruned_axis, tensor)