# fpgm_pruner.py
import logging
import numpy as np
import paddle
from paddleslim.common import get_logger
from .var_group import *
from .pruning_plan import *
from .filter_pruner import FilterPruner

__all__ = ['FPGMFilterPruner']

_logger = get_logger(__name__, logging.INFO)


class FPGMFilterPruner(FilterPruner):
    """Filter pruner that ranks filters by their summed distance to all others.

    For every filter along axis 0 of a weight, the score is the sum of the
    Euclidean distances between that filter and every filter of the same
    weight. Filters with the smallest scores are pruned first.

    NOTE(review): the class name suggests the FPGM ("Filter Pruning via
    Geometric Median") criterion -- confirm against the project docs.
    """

    def __init__(self, model, inputs, sen_file=None):
        """Forward all arguments unchanged to ``FilterPruner.__init__``."""
        super(FPGMFilterPruner, self).__init__(model, inputs, sen_file=sen_file)

    def cal_mask(self, var_name, pruned_ratio, group):
        """Compute the 0/1 pruning mask for ``var_name``.

        Args:
            var_name: Key into ``group`` selecting the entries to inspect.
            pruned_ratio: Fraction of filters to prune; the number of pruned
                filters is ``int(round(num_filters * pruned_ratio))``.
            group: Mapping from variable name to an iterable of dict-like
                items exposing ``'pruned_dims'`` and ``'value'``. The
                ``'value'`` entry is used as a NumPy array.

        Returns:
            A 1-D ``int32`` array with one entry per filter (axis 0 of the
            value): ``0`` for filters to prune, ``1`` for filters to keep.

        Raises:
            ValueError: If no item of ``group[var_name]`` is pruned along
                axis 0 (``pruned_dims == [0]``).
        """
        value = None
        pruned_dims = None
        # No early exit on purpose: when several items match, the last one
        # wins, exactly as in the original loop.
        for _item in group[var_name]:
            if _item['pruned_dims'] == [0]:
                value = _item['value']
                pruned_dims = _item['pruned_dims']
        if value is None:
            raise ValueError(
                "No entry with pruned_dims == [0] found for variable "
                "'{}'.".format(var_name))

        # One score per output filter; goes through get_distance_sum so that
        # subclasses overriding it keep working.
        scores = np.array([
            self.get_distance_sum(value, out_i)
            for out_i in range(value.shape[0])
        ])

        # Ascending order: the filters with the smallest summed distance to
        # the others are the ones selected for pruning.
        sorted_idx = scores.argsort()
        pruned_num = int(round(len(sorted_idx) * pruned_ratio))
        pruned_idx = sorted_idx[:pruned_num]
        mask_shape = [value.shape[i] for i in pruned_dims]
        mask = np.ones(mask_shape, dtype="int32")
        mask[pruned_idx] = 0
        return mask

    def get_distance_sum(self, value, out_idx):
        """Return the summed Euclidean distance from one filter to all filters.

        Args:
            value: NumPy array whose axis 0 enumerates the filters; all
                remaining axes are flattened into one feature vector.
            out_idx: Index along axis 0 of the reference filter.

        Returns:
            The sum over all filters ``j`` of ``||value[j] - value[out_idx]||``
            (the reference filter itself contributes 0).
        """
        # np.prod replaces the np.product alias removed in NumPy 2.0. The
        # int() cast matters for 1-D input, where np.prod(()) is the float 1.0.
        flat_dim = int(np.prod(value.shape[1:]))
        # reshape (instead of assigning .shape on a view) also accepts
        # non-contiguous arrays; `value` itself is never modified.
        w = value.reshape(value.shape[0], flat_dim)
        # Broadcasting subtracts the reference filter from every row.
        x = w - w[out_idx]
        x = np.sqrt(np.sum(x * x, -1))
        return x.sum()