# pruner.py
import os
import pickle
import numpy as np
import logging
from .pruning_plan import PruningPlan
from paddleslim.common import get_logger

__all__ = ["Pruner"]

_logger = get_logger(__name__, level=logging.INFO)


class Pruner(object):
    """
    Pruner used to resize or mask dimensions of variables.

    This is a base class: ``status`` and ``prune_var`` are placeholders that
    raise ``NotImplementedError`` and are meant to be overridden, while
    ``prune_vars`` builds a combined plan by calling ``prune_var`` for each
    requested variable.

    Args:
        model(paddle.nn.Layer): The target model to be pruned.
        inputs: The inputs of the model. It is only stored on the instance
            here; the original documentation states it is used to trace the
            graph of the model.
    """

    def __init__(self, model, inputs):
        self.model = model
        self.inputs = inputs
        # Snapshot of every parameter's shape at construction time, keyed by
        # the parameter name.
        self._var_shapes = {var.name: var.shape for var in model.parameters()}
        # The plan produced by the most recent call to ``prune_vars``.
        self.plan = None

    def status(self, data=None, eval_func=None, status_file=None):
        """
        Placeholder to be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        # ``NotImplemented`` is a non-callable singleton; the exception type
        # is ``NotImplementedError``.
        raise NotImplementedError("status is not implemented")

    def prune_var(self, var_name, axis, pruned_ratio, apply="impretive"):
        """
        Placeholder to be overridden by subclasses.

        ``prune_vars`` calls this with ``apply=None`` and expects the returned
        object to be accepted by ``PruningPlan.extend``.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("prune_var is not implemented")

    def prune_vars(self, ratios, axis, apply="impretive"):
        """
        Pruning variables by given ratios.

        Args:
            ratios(dict<str, float>): The key is the name of variable to be pruned and the
                                      value is the pruned ratio.
            axis(int|list<int>): The dimension to be pruned on. If a list is
                                 given, only its first element is used.
            apply(str): ``"lazy"`` applies the plan with ``lazy=True``;
                        ``"impretive"`` (spelling kept for backward
                        compatibility) applies it with ``lazy=False``; any
                        other value builds the plan without applying it.

        Returns:
            plan(PruningPlan): The pruning plan.
        """
        if isinstance(axis, list):
            axis = axis[0]
        # NOTE(review): ``full_name`` is passed as-is (not called); confirm
        # whether PruningPlan expects the attribute or the resulting string.
        global_plan = PruningPlan(self.model.full_name)
        for var, ratio in ratios.items():
            # Skip variables that an earlier plan in this loop already covers.
            if global_plan.contains(var, axis):
                continue
            # apply=None: collect the per-variable plan only; the merged plan
            # is applied once below.
            plan = self.prune_var(var, axis, ratio, apply=None)
            global_plan.extend(plan)
        if apply == "lazy":
            global_plan.apply(self.model, lazy=True)
        elif apply == "impretive":
            global_plan.apply(self.model, lazy=False)
        self.plan = global_plan
        return global_plan