# l1norm_pruner.py (1.6 KB)
# Blame: lines 1-3 committed by whs.
import logging
import numpy as np
import paddle
# Blame: line 4.
from paddleslim.common import get_logger
# Blame: lines 5-14 committed by whs.
from .var_group import *
from .pruning_plan import *
from .filter_pruner import FilterPruner

# Names exported by ``from <this module> import *``.
__all__ = [
    'L1NormFilterPruner',
]

# Module logger, created at INFO level via paddleslim's ``get_logger``.
_logger = get_logger(__name__, logging.INFO)


class L1NormFilterPruner(FilterPruner):
    """Filter pruner that ranks channels by an L1-norm criterion.

    ``cal_mask`` scores each channel along a collection's master axis by the
    mean absolute value of its weights. It then masks out the lowest-scoring
    channels, or whole channel groups when a ``groups`` attribute > 1 is
    found. All constructor arguments are forwarded unchanged to
    ``FilterPruner``.
    """

    def __init__(self, model, inputs, sen_file=None):
        super(L1NormFilterPruner, self).__init__(
            model, inputs, sen_file=sen_file)

    def cal_mask(self, pruned_ratio, collection):
        """Build a 0/1 keep-mask along the master axis of ``collection``.

        Each channel of the master variable is scored by ``mean(|w|)`` over
        every axis except the pruned one. The ``round(n * pruned_ratio)``
        lowest-scoring entries are pruned. If an op pruned along axis 1
        carries a ``groups`` attribute > 1, channels are split into that many
        contiguous groups. Each group is scored by the mean of its channels'
        scores, and whole groups are pruned instead of single channels.

        Args:
            pruned_ratio (float): Fraction of channels (or groups) to prune.
                It is clamped to ``[0, 1]``.
            collection: Pruning collection. This method reads its
                ``master_name``, ``master_axis``, ``values`` and
                ``all_pruning_details()``.

        Returns:
            numpy.ndarray: int32 array of shape ``[num_channels]``. It is 0
            for pruned channels and 1 for kept channels.

        Raises:
            ValueError: If the number of channels is not divisible by the
                detected ``groups``.
        """
        var_name = collection.master_name
        value = collection.values[var_name]
        ndim = len(value.shape)

        # Normalize a negative axis. Otherwise ``i != pruned_axis`` below
        # would never match, and every dimension would be reduced away.
        pruned_axis = collection.master_axis
        if pruned_axis < 0:
            pruned_axis += ndim
        num_channels = value.shape[pruned_axis]

        # Use the first ``groups`` > 1 found on an op pruned along axis 1
        # (presumably a grouped convolution -- confirm against the caller).
        groups = 1
        for detail in collection.all_pruning_details():
            assert isinstance(detail.axis, int)
            if detail.axis != 1:
                continue
            detail_groups = detail.op.attr('groups')
            if detail_groups is not None and detail_groups > 1:
                groups = detail_groups
                break

        # Mean of |w| per channel. Every channel has the same number of
        # elements, so this ranks channels exactly like their L1 norm.
        reduce_dims = tuple(i for i in range(ndim) if i != pruned_axis)
        scores = np.mean(np.abs(value), axis=reduce_dims)
        if groups > 1:
            if num_channels % groups != 0:
                raise ValueError(
                    "Cannot split {} channels of '{}' into {} groups.".format(
                        num_channels, var_name, groups))
            # Row g holds the g-th contiguous block of channels. Score each
            # group by the mean of its channels' scores.
            scores = scores.reshape([groups, -1]).mean(axis=1)

        # Clamp the ratio. A negative ratio would otherwise become a
        # negative slice bound and silently prune ``len - |k|`` entries.
        ratio = min(max(float(pruned_ratio), 0.0), 1.0)
        pruned_num = int(round(len(scores) * ratio))
        # A stable sort breaks ties by lowest index, so the mask is
        # reproducible whatever sort implementation numpy picks by default.
        pruned_idx = np.argsort(scores, kind='stable')[:pruned_num]

        mask_shape = [num_channels]
        mask = np.ones(mask_shape, dtype="int32")
        if groups > 1:
            # Zeroing one row of the [groups, channels_per_group] layout
            # prunes every channel of that group.
            mask = mask.reshape([groups, -1])
        mask[pruned_idx] = 0
        return mask.reshape(mask_shape)