# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import sys
import numpy as np
from functools import reduce
import paddle.fluid as fluid
import copy
from ..core import VarWrapper, OpWrapper, GraphWrapper
from .prune_walker import conv2d as conv2d_walker
from ..common import get_logger

__all__ = ["Pruner"]

_logger = get_logger(__name__, level=logging.INFO)


class Pruner():
    """The pruner used to prune channels of convolution.

    Args:
        criterion(str): the criterion used to sort channels for pruning. It supports 'l1_norm', 'geometry_median' and 'batch_norm_scale' currently. Default: 'l1_norm'.

    """

    def __init__(self, criterion="l1_norm"):
        self.criterion = criterion

    def prune(self,
              program,
              scope,
              params,
              ratios,
              place=None,
              lazy=False,
              only_graph=False,
              param_backup=False,
              param_shape_backup=False):
        """Pruning the given parameters.

        Args:

            program(fluid.Program): The program to be pruned.
            scope(fluid.Scope): The scope storing parameters to be pruned.
            params(list<str>): A list of parameter names to be pruned.
            ratios(list<float>): A list of ratios used to prune the parameters, one for each name in ``params``.
            place(fluid.Place): The device place of filter parameters. Default: None.
            lazy(bool): True means setting the pruned elements to zero.
                        False means cutting down the pruned elements. Default: False.
            only_graph(bool): True means only modifying the graph.
                              False means modifying graph and variables in scope. Default: False.
            param_backup(bool): Whether to return a dict to backup the values of parameters. Default: False.
            param_shape_backup(bool): Whether to return a dict to backup the shapes of parameters. Default: False.

        Returns:
            tuple: ``(pruned_program, param_backup, param_shape_backup)``. ``pruned_program`` is the pruned program. ``param_backup`` is a dict to backup the values of parameters. ``param_shape_backup`` is a dict to backup the shapes of parameters.
        """

        self.pruned_list = []
        graph = GraphWrapper(program.clone())
        param_backup = {} if param_backup else None
        param_shape_backup = {} if param_shape_backup else None

        visited = {}
        pruned_params = []
        for param, ratio in zip(params, ratios):
            if only_graph:
                param_v = graph.var(param)
                pruned_num = int(round(param_v.shape()[0] * ratio))
                pruned_idx = [0] * pruned_num
            else:
                pruned_idx = self._cal_pruned_idx(
                    graph, scope, param, ratio, axis=0)
            param = graph.var(param)
            conv_op = param.outputs()[0]
            walker = conv2d_walker(
                conv_op, pruned_params=pruned_params, visited=visited)
            walker.prune(param, pruned_axis=0, pruned_idx=pruned_idx)

        merge_pruned_params = {}
        for param, pruned_axis, pruned_idx in pruned_params:
            if param.name() not in merge_pruned_params:
                merge_pruned_params[param.name()] = {}
            if pruned_axis not in merge_pruned_params[param.name()]:
                merge_pruned_params[param.name()][pruned_axis] = []
            merge_pruned_params[param.name()][pruned_axis].append(pruned_idx)

        for param_name in merge_pruned_params:
            for pruned_axis in merge_pruned_params[param_name]:
                pruned_idx = np.concatenate(merge_pruned_params[param_name][
                    pruned_axis])
                param = graph.var(param_name)
                if not lazy:
                    _logger.debug("{}\t{}\t{}".format(param.name(
                    ), pruned_axis, len(pruned_idx)))
                    if param_shape_backup is not None:
                        origin_shape = copy.deepcopy(param.shape())
                        param_shape_backup[param.name()] = origin_shape
                    new_shape = list(param.shape())
                    new_shape[pruned_axis] -= len(pruned_idx)
                    param.set_shape(new_shape)
                if not only_graph:
                    param_t = scope.find_var(param.name()).get_tensor()
                    if param_backup is not None and (
                            param.name() not in param_backup):
                        param_backup[param.name()] = copy.deepcopy(
                            np.array(param_t))
                    try:
                        pruned_param = self._prune_tensor(
                            np.array(param_t),
                            pruned_idx,
                            pruned_axis=pruned_axis,
                            lazy=lazy)
                    except IndexError as e:
                        _logger.error("Pruning {}, but get [{}]".format(
                            param.name(), e))

                    param_t.set(pruned_param, place)
        graph.update_groups_of_conv()
        graph.infer_shape()
        return graph.program, param_backup, param_shape_backup
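    # A minimal usage sketch of ``prune``, assuming this class is exposed as
    # ``paddleslim.prune.Pruner`` and that the main program holds a conv2d
    # weight named "conv1_weights" (a placeholder name, not a guarantee):
    #
    #   import paddle.fluid as fluid
    #   from paddleslim.prune import Pruner
    #
    #   place = fluid.CPUPlace()
    #   exe = fluid.Executor(place)
    #   exe.run(fluid.default_startup_program())
    #
    #   pruner = Pruner(criterion="l1_norm")
    #   pruned_program, _, _ = pruner.prune(
    #       fluid.default_main_program(),
    #       fluid.global_scope(),
    #       params=["conv1_weights"],
    #       ratios=[0.3],
    #       place=place)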

    def _cal_pruned_idx(self, graph, scope, param, ratio, axis):
        """
        Calculate the indexes to be pruned on the given axis by the given pruning ratio.

        Args:
            graph(GraphWrapper): The graph wrapping the program to be pruned.
            scope(fluid.Scope): The scope storing the parameter to be pruned.
            param(str): The name of the parameter to be pruned.
            ratio(float): The ratio of channels to be pruned.
            axis(int): The axis on which the parameter is pruned.

        Returns:
            list<int>: The indexes to be pruned on axis.
        """
        if self.criterion == 'l1_norm':
            param_t = np.array(scope.find_var(param).get_tensor())
            prune_num = int(round(param_t.shape[axis] * ratio))
            reduce_dims = [i for i in range(len(param_t.shape)) if i != axis]
            criterions = np.sum(np.abs(param_t), axis=tuple(reduce_dims))
            pruned_idx = criterions.argsort()[:prune_num]
        elif self.criterion == 'geometry_median':
            param_t = np.array(scope.find_var(param).get_tensor())
            prune_num = int(round(param_t.shape[axis] * ratio))

            def get_distance_sum(param, out_idx):
                w = param.view()
                reduce_dims = reduce(lambda x, y: x * y, param.shape[1:])
                w.shape = param.shape[0], reduce_dims
                selected_filter = np.tile(w[out_idx], (w.shape[0], 1))
                x = w - selected_filter
                x = np.sqrt(np.sum(x * x, -1))
                return x.sum()

            dist_sum_list = []
            for out_i in range(param_t.shape[0]):
                dist_sum = get_distance_sum(param_t, out_i)
                dist_sum_list.append((dist_sum, out_i))
            min_gm_filters = sorted(
                dist_sum_list, key=lambda x: x[0])[:prune_num]
            pruned_idx = [x[1] for x in min_gm_filters]

        elif self.criterion == "batch_norm_scale":
            param_var = graph.var(param)
            conv_op = param_var.outputs()[0]
            conv_output = conv_op.outputs("Output")[0]
            bn_op = conv_output.outputs()[0]
            if bn_op is not None:
                bn_scale_param = bn_op.inputs("Scale")[0].name()
                bn_scale_np = np.array(
                    scope.find_var(bn_scale_param).get_tensor())
                prune_num = int(round(bn_scale_np.shape[axis] * ratio))
                pruned_idx = np.abs(bn_scale_np).argsort()[:prune_num]
            else:
                raise SystemExit(
                    "Can't find BatchNorm op after Conv op in Network.")
        return pruned_idx
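    # A small NumPy sketch of the 'l1_norm' ranking above: each filter of a
    # (num_filters, channels, k, k) weight is scored by the sum of absolute
    # values over every axis except ``axis``, and the lowest-scoring filters
    # are selected for pruning.
    #
    #   w = np.random.rand(4, 2, 3, 3)          # 4 filters
    #   scores = np.abs(w).sum(axis=(1, 2, 3))  # one L1 score per filter
    #   pruned_idx = scores.argsort()[:2]       # indexes of the 2 weakest filters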

    def _prune_tensor(self, tensor, pruned_idx, pruned_axis, lazy=False):
        """
        Prune an array by the given indexes on the given axis.

        Args:
            tensor(numpy.array): The target array to be pruned.
            pruned_idx(list<int>): The indexes to be pruned.
            pruned_axis(int): The axis of given array to be pruned on. 
            lazy(bool): True means setting the pruned elements to zero.
                        False means remove the pruned elements from memory.
                        default: False.

        Returns:
            numpy.array: The pruned array.
        """
        mask = np.zeros(tensor.shape[pruned_axis], dtype=bool)
        mask[pruned_idx] = True

        def func(data):
            return data[~mask]

        def lazy_func(data):
            data[mask] = 0
            return data

        if lazy:
            return np.apply_along_axis(lazy_func, pruned_axis, tensor)
        else:
            return np.apply_along_axis(func, pruned_axis, tensor)
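    # A tiny sketch of ``_prune_tensor`` on a (4, 3) array (illustrative only):
    #
    #   t = np.arange(12).reshape(4, 3).astype("float32")
    #   # lazy=False removes rows 1 and 3, giving an array of shape (2, 3):
    #   Pruner()._prune_tensor(t, [1, 3], pruned_axis=0, lazy=False)
    #   # lazy=True keeps the shape (4, 3) and zeroes rows 1 and 3 instead:
    #   Pruner()._prune_tensor(t, [1, 3], pruned_axis=0, lazy=True)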