sensitive_pruner.py
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import logging
import copy
from scipy.optimize import leastsq
import numpy as np
import paddle.fluid as fluid
from ..common import get_logger
from .sensitive import sensitivity
from .sensitive import flops_sensitivity, get_ratios_by_loss
from ..analysis import flops
from .pruner import Pruner

__all__ = ["SensitivePruner"]

_logger = get_logger(__name__, level=logging.INFO)


class SensitivePruner(object):
    """
    Pruner used to prune parameters iteratively according to sensitivities
    of parameters in each step.

    Args:
        place(fluid.CUDAPlace | fluid.CPUPlace): The device place where
            program execute.
        eval_func(function): A callback function used to evaluate pruned
            program. The argument of this function is pruned program.
            And it return a score of given program.
        scope(fluid.scope): The scope used to execute program.
    """

    def __init__(self, place, eval_func, scope=None, checkpoints=None):
        self._eval_func = eval_func
        self._iter = 0
        self._place = place
        self._scope = fluid.global_scope() if scope is None else scope
        self._pruner = Pruner()
        self._checkpoints = checkpoints

    def save_checkpoint(self, train_program, eval_program):
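        """Save the programs and parameters of the last pruning step.

        Resulting layout on disk (one directory per step):
            <checkpoints>/<iter>/__params__    - persistable parameters
            <checkpoints>/<iter>/main_program  - serialized train program desc
            <checkpoints>/<iter>/eval_program  - serialized eval program desc
        """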
        checkpoint = os.path.join(self._checkpoints, str(self._iter - 1))
        exe = fluid.Executor(self._place)
        fluid.io.save_persistables(
            exe, checkpoint, main_program=train_program, filename="__params__")

        with open(checkpoint + "/main_program", "wb") as f:
            f.write(train_program.desc.serialize_to_string())
        with open(checkpoint + "/eval_program", "wb") as f:
            f.write(eval_program.desc.serialize_to_string())

    def restore(self, checkpoints=None):
        """Load the newest checkpoint (checkpoint directories are named by
        iteration number) and return the restored training program,
        evaluation program, and the current iteration counter."""
        exe = fluid.Executor(self._place)
        checkpoints = self._checkpoints if checkpoints is None else checkpoints
        _logger.info("check points: {}".format(checkpoints))
        main_program = None
        eval_program = None
        if checkpoints is not None:
            cks = os.listdir(checkpoints)
            if len(cks) > 0:
                latest = max([int(ck) for ck in cks])
                latest_ck_path = os.path.join(checkpoints, str(latest))
                self._iter += 1

                with open(latest_ck_path + "/main_program", "rb") as f:
                    program_desc_str = f.read()
                main_program = fluid.Program.parse_from_string(
                    program_desc_str)

                with open(latest_ck_path + "/eval_program", "rb") as f:
                    program_desc_str = f.read()
                eval_program = fluid.Program.parse_from_string(
                    program_desc_str)

                with fluid.scope_guard(self._scope):
                    fluid.io.load_persistables(exe, latest_ck_path,
                                               main_program, "__params__")
                _logger.info("load checkpoint from: {}".format(latest_ck_path))
                _logger.info("flops of eval program: {}".format(
                    flops(eval_program)))
        return main_program, eval_program, self._iter

    def greedy_prune(self,
                     train_program,
                     eval_program,
                     params,
                     pruned_flops_rate,
                     topk=1):
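        """Greedily prune the `topk` parameters whose pruning causes the
        smallest accuracy loss, as measured by `flops_sensitivity` at the
        given `pruned_flops_rate`."""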

        sensitivities_file = "greedy_sensitivities_iter{}.data".format(
            self._iter)
        with fluid.scope_guard(self._scope):
            sensitivities = flops_sensitivity(
                eval_program,
                self._place,
                params,
                self._eval_func,
                sensitivities_file=sensitivities_file,
                pruned_flops_rate=pruned_flops_rate)
        _logger.info(sensitivities)
        params, ratios = self._greedy_ratio_by_sensitive(sensitivities, topk)

        _logger.info("Pruning: {} by {}".format(params, ratios))
        pruned_program = self._pruner.prune(
            train_program,
            self._scope,
            params,
            ratios,
            place=self._place,
            only_graph=False)
        pruned_val_program = None
        if eval_program is not None:
            pruned_val_program = self._pruner.prune(
                eval_program,
                self._scope,
                params,
                ratios,
                place=self._place,
                only_graph=True)
        self._iter += 1
        return pruned_program, pruned_val_program

    def prune(self, train_program, eval_program, params, pruned_flops):
        """
        Prune the parameters of the training and evaluation networks according to sensitivities in the current step.

        Args:
            train_program(fluid.Program): The training program to be pruned.
            eval_program(fluid.Program): The evaluation program to be pruned. It is also used to compute the sensitivities of parameters.
            params(list<str>): The parameters to be pruned.
            pruned_flops(float): The ratio of FLOPs to be pruned in the current step.

        Returns:
            tuple: A tuple of the pruned training program and the pruned evaluation program.
        """
        _logger.info("Pruning: {}".format(params))
        sensitivities_file = "sensitivities_iter{}.data".format(self._iter)
        with fluid.scope_guard(self._scope):
            sensitivities = sensitivity(
                eval_program,
                self._place,
                params,
                self._eval_func,
                sensitivities_file=sensitivities_file,
                step_size=0.1)
        _logger.info(sensitivities)
        ratios = self.get_ratios_by_sensitive(sensitivities, pruned_flops,
                                              eval_program)
W
wanghaoshuang 已提交
163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182

        pruned_program = self._pruner.prune(
            train_program,
            self._scope,
            params,
            ratios,
            place=self._place,
            only_graph=False)
        pruned_val_program = None
        if eval_program is not None:
            pruned_val_program = self._pruner.prune(
                eval_program,
                self._scope,
                params,
                ratios,
                place=self._place,
                only_graph=True)
        self._iter += 1
        return pruned_program, pruned_val_program

    def _greedy_ratio_by_sensitive(self, sensitivities, topk=1):
        losses = {}
        percents = {}
        for param in sensitivities:
            losses[param] = sensitivities[param]['loss'][0]
            percents[param] = sensitivities[param]['pruned_percent'][0]
        topk_params = sorted(losses, key=losses.__getitem__)[:topk]
        topk_percents = [percents[param] for param in topk_params]
        return topk_params, topk_percents

    def get_ratios_by_sensitive(self, sensitivities, pruned_flops,
                                eval_program):
        """
        Search for a group of pruning ratios that reaches the target FLOPs reduction.

        Args:
          sensitivities(dict): The sensitivities used to generate a group of pruning ratios. The keys of the
                               dict are the names of the parameters to be pruned. Each value is a dict with
                               keys `pruned_percent` and `loss`, holding index-aligned lists of pruning
                               ratios and the accuracy losses they cause.
          pruned_flops(float): The fraction of FLOPs to be pruned.
          eval_program(Program): The program whose FLOPs are measured.

        Returns:
          dict: A group of ratios. The keys are the names of the parameters and the values are the ratios to be pruned.
        """

        min_loss = 0.
        max_loss = 0.
        # step 1: Use the largest accuracy loss observed in the
        # sensitivities as the upper bound of the binary search;
        # otherwise the loop below would never execute.
        for param in sensitivities:
            max_loss = max(max_loss, max(sensitivities[param]['loss']))
        # step 2: Find a group of ratios by binary searching.
        base_flops = flops(eval_program)
        ratios = None
        max_times = 20
        while min_loss < max_loss and max_times > 0:
            loss = (max_loss + min_loss) / 2
            _logger.info(
                '-----------Try pruned ratios while acc loss={}-----------'.
                format(loss))
            ratios = get_ratios_by_loss(sensitivities, loss)
            _logger.info('Pruned ratios={}'.format(
                [round(ratio, 3) for ratio in ratios.values()]))
            pruned_program = self._pruner.prune(
                eval_program,
                None,  # scope
                ratios.keys(),
                ratios.values(),
                None,  # place
                only_graph=True)
            pruned_ratio = 1 - (float(flops(pruned_program)) / base_flops)
            _logger.info('Pruned flops: {:.4f}'.format(pruned_ratio))

            # Check whether the current ratios are sufficient.
            if abs(pruned_ratio - pruned_flops) < 0.015:
                break
            if pruned_ratio > pruned_flops:
                max_loss = loss
            else:
                min_loss = loss
            max_times -= 1
        return ratios