import os
import pickle
import numpy as np
import logging
from .pruning_plan import PruningPlan
from paddleslim.common import get_logger

__all__ = ["Pruner"]

_logger = get_logger(__name__, level=logging.INFO)


class Pruner(object):
    """
    Pruner used to resize or mask dimensions of variables.
    Args:
        model(paddle.nn.Layer): The target model to be pruned.
        inputs(list): The inputs of the model, used to trace the graph of the model.
        opt(paddle.optimizer.Optimizer): The model's optimizer. Default: None.
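
    Examples:
        A minimal usage sketch. ``L1NormFilterPruner`` is assumed to be a concrete
        subclass of ``Pruner`` provided elsewhere in PaddleSlim; it is not defined
        in this file.

        .. code-block:: python

            import paddle
            from paddle.vision.models import mobilenet_v1
            # Assumed concrete subclass; the base Pruner leaves prune_var unimplemented.
            from paddleslim.dygraph import L1NormFilterPruner

            net = mobilenet_v1(pretrained=False)
            # The inputs argument is used to trace the graph of the model.
            pruner = L1NormFilterPruner(net, [1, 3, 224, 224])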
    """

    def __init__(self, model, inputs, opt=None):
        self.model = model
        self.inputs = inputs
        self._var_shapes = {}
        for var in model.parameters():
            self._var_shapes[var.name] = var.shape
        self.plan = None
        self.opt = opt

    def status(self, data=None, eval_func=None, status_file=None):
        raise NotImplementedError("status is not implemented")

    def prune_var(self, var_name, axis, pruned_ratio, apply="impretive"):
        raise NotImplementedError("prune_var is not implemented")

    def prune_vars(self, ratios, axis, apply="impretive"):
        """
        Prune variables by the given ratios.
        Args:
            ratios(dict<str, float>): The key is the name of the variable to be pruned and the
                                      value is the pruned ratio.
            axis(int|list<int>): The axis to be pruned on. If a list is given, only its first
                                 element is used.
            apply(str): How to apply the returned plan. "lazy" applies the plan in lazy mode,
                        "impretive" applies it immediately and updates the optimizer, and any
                        other value leaves the plan unapplied. Default: "impretive".

        Returns:
            plan(PruningPlan): The pruning plan.
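
        Examples:
            A minimal sketch; ``pruner`` is assumed to be an instance of a concrete
            ``Pruner`` subclass, and the parameter name ``conv2d_0.w_0`` is a
            placeholder that depends on the actual model.

            .. code-block:: python

                # Prune 25% of axis 0 of the given weight and apply the plan
                # immediately to the parameters and the optimizer state.
                plan = pruner.prune_vars({"conv2d_0.w_0": 0.25}, axis=0, apply="impretive")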
        """
        axis = axis[0] if isinstance(axis, list) else axis
        global_plan = PruningPlan(self.model.full_name())
        for var, ratio in ratios.items():
            if not global_plan.contains(var, axis):
                plan = self.prune_var(var, axis, ratio, apply=None)
                global_plan.extend(plan)
        if apply == "lazy":
            global_plan.apply(self.model, lazy=True)
        elif apply == "impretive":
            global_plan.apply(self.model, lazy=False, opt=self.opt)
        self.plan = global_plan
        return global_plan