# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import collections
from .... import layers

__all__ = ['Pruner', 'StructurePruner']


class Pruner(object):
    """Base class of all pruners.

    The methods defined here are placeholders: construction stores no
    state and ``prune`` performs no work.
    """

    def __init__(self):
        """Create a pruner. The base class keeps no state."""

    def prune(self, param):
        """Prune ``param``. In the base class this is a no-op returning None."""
        return None


class StructurePruner(Pruner):
    """
    Pruner used to prune parameters by groups.

    A group is one index along the pruning axis of a parameter. For example,
    a parameter of shape [3, 4] pruned on axis 1 has 4 groups, each a column
    of 3 elements. Groups are scored with a criterion (currently only
    'l1_norm'), and the lowest-scoring groups are pruned.
    """

    def __init__(self, pruning_axis, criterions):
        """
        Args:
            pruning_axis(dict|None): The key is the name of parameter to be pruned,
                                '*' means all the parameters.
                                The value is the axis to be used. Given a parameter
                                with shape [3, 4], the result of pruning 50% on axis 1
                                is a parameter with shape [3, 2].
                                It may be None only if `axis` is always passed
                                explicitly to `cal_pruned_idx`.
            criterions(dict): The key is the name of parameter to be pruned,
                              '*' means all the parameters.
                              The value is the criterion used to sort groups for pruning.
                              It only supports 'l1_norm' currently.
        """
        super(StructurePruner, self).__init__()
        self.pruning_axis = pruning_axis
        self.criterions = criterions

    def cal_pruned_idx(self, name, param, ratio, axis=None):
        """
        Calculate the indexes to be pruned on axis by given pruning ratio.

        Each index along `axis` is scored with the configured criterion. The
        `round(param.shape[axis] * ratio)` lowest-scoring indexes are
        returned, ordered by ascending score.

        Args:
            name(str): The name of parameter to be pruned. It is used to look
                       up the criterion (and the axis, when `axis` is None).
                       The '*' entry is used when the name is not present.
            param(np.array): The data of parameter to be pruned.
            ratio(float): The ratio to be pruned. A ratio that rounds to a
                          negative count prunes nothing.
            axis(int): The axis to be used for pruning given parameter.
                       Negative values count from the last axis.
                       If it is None, the value in self.pruning_axis will be used.
                       default: None.
        Returns:
            numpy.ndarray: The integer indexes to be pruned on axis.
        Raises:
            ValueError: If the criterion configured for `name` is not supported.
        """
        criterion = self.criterions[
            name] if name in self.criterions else self.criterions['*']
        if axis is None:
            assert self.pruning_axis is not None, "pruning_axis should set if axis is None."
            axis = self.pruning_axis[
                name] if name in self.pruning_axis else self.pruning_axis['*']
        param = np.asarray(param)
        # Index shape first so an out-of-range axis still raises IndexError.
        group_num = param.shape[axis]
        # Normalize a negative axis. Otherwise the `i != axis` filter below
        # never matches, every dimension gets reduced, and the scores
        # collapse to a single scalar.
        if axis < 0:
            axis += param.ndim
        # Clamp at 0: a negative count would make the slice below select
        # indexes from the end of the sorted order.
        prune_num = max(int(round(group_num * ratio)), 0)
        reduce_dims = tuple(i for i in range(param.ndim) if i != axis)
        if criterion == 'l1_norm':
            scores = np.sum(np.abs(param), axis=reduce_dims)
        else:
            raise ValueError(
                "Unsupported pruning criterion '{}' for parameter '{}'; "
                "only 'l1_norm' is supported.".format(criterion, name))
        pruned_idx = scores.argsort()[:prune_num]
        return pruned_idx

    def prune_tensor(self, tensor, pruned_idx, pruned_axis, lazy=False):
        """
        Prune an array by indexes on given axis.

        Args:
            tensor(numpy.array): The target array to be pruned. It is not
                                 modified; a new array is returned.
            pruned_idx(list<int>): The indexes to be pruned.
            pruned_axis(int): The axis of given array to be pruned on.
            lazy(bool): True means setting the pruned elements to zero
                        (the shape is preserved).
                        False means removing the pruned elements.
                        default: False.
        Returns:
            numpy.array: The pruned array.
        """
        tensor = np.asarray(tensor)
        mask = np.zeros(tensor.shape[pruned_axis], dtype=bool)
        mask[pruned_idx] = True

        if lazy:
            # Work on a copy. Zeroing slices obtained via np.apply_along_axis
            # writes through views into the caller's array.
            pruned = tensor.copy()
            index = [slice(None)] * tensor.ndim
            index[pruned_axis] = mask
            pruned[tuple(index)] = 0
            return pruned
        # Keep only the unmasked slices along pruned_axis. This is
        # vectorized and, unlike np.apply_along_axis, also works when other
        # dimensions have size 0.
        return np.compress(~mask, tensor, axis=pruned_axis)