# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import os
import logging
import pickle
import numpy as np
import paddle.fluid as fluid
from ..core import GraphWrapper
from ..common import get_logger
from ..analysis import flops
from ..prune import Pruner

_logger = get_logger(__name__, level=logging.INFO)

__all__ = [
    "sensitivity", "flops_sensitivity", "load_sensitivities",
    "merge_sensitive", "get_ratios_by_loss"
]


def sensitivity(program,
                place,
                param_names,
                eval_func,
                sensitivities_file=None,
                pruned_ratios=None):
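    """
    Compute the sensitivity of each parameter in `param_names`.

    Each parameter is pruned at every ratio in `pruned_ratios`; the relative
    drop of the metric returned by `eval_func` is recorded, the pruned weights
    are restored, and partial results are cached in `sensitivities_file` so an
    interrupted run can be resumed.

    Args:
      program: The program whose parameters will be analysed.
      place: The device place used to update parameter tensors.
      param_names(list<str>): Names of the parameters to be analysed.
      eval_func: A callable that accepts a program and returns a scalar metric.
      sensitivities_file(str): Optional cache file for saving/resuming results.
      pruned_ratios(list<float>): Ratios to test; defaults to 0.1, 0.2, ..., 0.9.

    Returns:
      sensitivities(dict): sensitivities[param_name][ratio] is the relative
                           loss of the evaluation metric.
    """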
    scope = fluid.global_scope()
    graph = GraphWrapper(program)
    sensitivities = load_sensitivities(sensitivities_file)

    if pruned_ratios is None:
        pruned_ratios = np.arange(0.1, 1, step=0.1)

    for name in param_names:
        if name not in sensitivities:
            sensitivities[name] = {}
    baseline = None
    for name in sensitivities:
        for ratio in pruned_ratios:
            if ratio in sensitivities[name]:
                _logger.debug('{}, {} has been computed.'.format(name, ratio))
                continue
            if baseline is None:
                baseline = eval_func(graph.program)

            pruner = Pruner()
            _logger.info("sensitive - param: {}; ratios: {}".format(name,
                                                                    ratio))
            pruned_program, param_backup, _ = pruner.prune(
                program=graph.program,
                scope=scope,
                params=[name],
                ratios=[ratio],
                place=place,
                lazy=True,
                only_graph=False,
                param_backup=True)
            pruned_metric = eval_func(pruned_program)
            loss = (baseline - pruned_metric) / baseline
            _logger.info("pruned param: {}; {}; loss={}".format(name, ratio,
                                                                loss))

            sensitivities[name][ratio] = loss

            _save_sensitivities(sensitivities, sensitivities_file)

            # restore pruned parameters
            for param_name in param_backup.keys():
                param_t = scope.find_var(param_name).get_tensor()
                param_t.set(param_backup[param_name], place)
    return sensitivities
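
# A minimal usage sketch for `sensitivity` (not part of the original module;
# `eval_acc`, `eval_program` and the parameter names are hypothetical
# placeholders for your own evaluation function, program and conv weights):
#
#     import paddle.fluid as fluid
#
#     def eval_acc(program):
#         # Run evaluation on `program` and return a scalar metric,
#         # e.g. top-1 accuracy.
#         ...
#
#     place = fluid.CPUPlace()
#     sens = sensitivity(
#         eval_program,
#         place,
#         ["conv1_weights", "conv2_weights"],
#         eval_acc,
#         sensitivities_file="sensitivities.pkl",
#         pruned_ratios=[0.1, 0.2, 0.3])
#     # sens["conv1_weights"][0.2] is the relative metric loss after pruning
#     # that parameter at ratio 0.2.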


def flops_sensitivity(program,
                      place,
                      param_names,
                      eval_func,
                      sensitivities_file=None,
                      pruned_flops_rate=0.1):
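    """
    Compute the sensitivity of each parameter under a FLOPs budget.

    For every parameter, a single pruning ratio is derived so that pruning
    that parameter alone removes about `pruned_flops_rate` of the network's
    total FLOPs; the relative drop of the metric returned by `eval_func` at
    that ratio is then recorded.

    Args:
      program: The program whose parameters will be analysed.
      place: The device place used to update parameter tensors.
      param_names(list<str>): Names of the parameters to be analysed.
      eval_func: A callable that accepts a program and returns a scalar metric.
      sensitivities_file(str): Optional cache file for saving/resuming results.
      pruned_flops_rate(float): Fraction of total FLOPs to prune per parameter;
                                must be less than 1.0 / len(param_names).

    Returns:
      sensitivities(dict): sensitivities[param_name][ratio] is the relative
                           loss of the evaluation metric.
    """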

    assert (1.0 / len(param_names) > pruned_flops_rate)

    scope = fluid.global_scope()
    graph = GraphWrapper(program)
    sensitivities = load_sensitivities(sensitivities_file)

    for name in param_names:
        if name not in sensitivities:
            sensitivities[name] = {}
    base_flops = flops(program)
    target_pruned_flops = base_flops * pruned_flops_rate

    pruner = Pruner()
    baseline = None
    for name in sensitivities:

        pruned_program, _, _ = pruner.prune(
            program=graph.program,
            scope=None,
            params=[name],
            ratios=[0.5],
            place=None,
            lazy=False,
            only_graph=True)
        param_flops = (base_flops - flops(pruned_program)) * 2
        channel_size = graph.var(name).shape()[0]
        pruned_ratio = target_pruned_flops / float(param_flops)
        pruned_ratio = round(pruned_ratio, 3)
        pruned_size = round(pruned_ratio * channel_size)
        pruned_ratio = 1 if pruned_size >= channel_size else pruned_ratio

        if len(sensitivities[name].keys()) > 0:
            _logger.debug(
                '{} exists; computed ratios: {}; expected ratio: {}'.format(
                    name, sensitivities[name].keys(), pruned_ratio))
            continue
        if baseline is None:
            baseline = eval_func(graph.program)
        param_backup = {}
        pruner = Pruner()
        _logger.info("sensitive - param: {}; ratios: {}".format(name,
                                                                pruned_ratio))
        loss = 1
        if pruned_ratio < 1:
            pruned_program, _, _ = pruner.prune(
                program=graph.program,
                scope=scope,
                params=[name],
                ratios=[pruned_ratio],
                place=place,
                lazy=True,
                only_graph=False,
                param_backup=param_backup)
            pruned_metric = eval_func(pruned_program)
            loss = (baseline - pruned_metric) / baseline
        _logger.info("pruned param: {}; {}; loss={}".format(name, pruned_ratio,
                                                            loss))
        sensitivities[name][pruned_ratio] = loss
        _save_sensitivities(sensitivities, sensitivities_file)

        # restore pruned parameters
        for param_name in param_backup.keys():
            param_t = scope.find_var(param_name).get_tensor()
            param_t.set(param_backup[param_name], place)
    return sensitivities
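
# A hedged usage sketch for `flops_sensitivity` (same hypothetical names as
# above; `eval_program`, `place` and `eval_acc` come from your own setup):
#
#     # Derive, for each parameter, the ratio that removes ~10% of total
#     # FLOPs, and record the metric loss at that single ratio.
#     sens = flops_sensitivity(
#         eval_program,
#         place,
#         ["conv1_weights", "conv2_weights"],
#         eval_acc,
#         sensitivities_file="flops_sensitivities.pkl",
#         pruned_flops_rate=0.1)
#
# The assertion at the top of the function requires
# pruned_flops_rate < 1.0 / len(param_names).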


def merge_sensitive(sensitivities):
    """
    Merge sensitivities.
    Args:
      sensitivities(list<dict> | list<str>): The sensitivities to be merged. It can be a list of sensitivity files or a list of dicts.

    Returns:
      sensitivities(dict): A dict with sensitivities.
    """
    assert len(sensitivities) > 0
    if not isinstance(sensitivities[0], dict):
        sensitivities = [pickle.load(open(sen, 'rb')) for sen in sensitivities]

    new_sensitivities = {}
    for sen in sensitivities:
        for param, losses in sen.items():
            if param not in new_sensitivities:
                new_sensitivities[param] = {}
            for percent, loss in losses.items():
                new_sensitivities[param][percent] = loss
    return new_sensitivities
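
# Example with hypothetical data: partial results from two runs can be merged,
# whether given as already-loaded dicts or as pickle file paths:
#
#     merged = merge_sensitive([{"conv1": {0.1: 0.01}}, {"conv1": {0.2: 0.03}}])
#     # -> {"conv1": {0.1: 0.01, 0.2: 0.03}}
#     merged = merge_sensitive(["sens_run0.pkl", "sens_run1.pkl"])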


def load_sensitivities(sensitivities_file):
    """
    Load sensitivities from file.
    """
    sensitivities = {}
    if sensitivities_file and os.path.exists(sensitivities_file):
        with open(sensitivities_file, 'rb') as f:
            if sys.version_info < (3, 0):
                sensitivities = pickle.load(f)
            else:
                sensitivities = pickle.load(f, encoding='bytes')
    return sensitivities


def _save_sensitivities(sensitivities, sensitivities_file):
    """
    Save sensitivities into file.
    """
    if not sensitivities_file:
        # No cache file was given; skip saving.
        return
    with open(sensitivities_file, 'wb') as f:
        pickle.dump(sensitivities, f)


def get_ratios_by_loss(sensitivities, loss):
    """
    Get the max ratio of each parameter. The loss of accuracy must be less than given `loss`
    when the single parameter was pruned by the max ratio. 
    
    Args:
      
      sensitivities(dict): The sensitivities used to generate a group of pruning ratios. The key of dict
                           is name of parameters to be pruned. The value of dict is a list of tuple with
                           format `(pruned_ratio, accuracy_loss)`.
      loss(float): The threshold of accuracy loss.

    Returns:

      ratios(dict): A group of ratios. The key of dict is name of parameters while the value is the ratio to be pruned.
    """
    ratios = {}
    for param, losses in sensitivities.items():
        losses = list(losses.items())
        losses.sort()
        for i in range(len(losses))[::-1]:
            if losses[i][1] <= loss:
                if i == (len(losses) - 1):
                    ratios[param] = losses[i][0]
                else:
                    r0, l0 = losses[i]
                    r1, l1 = losses[i + 1]
                    # Linearly interpolate between the two measured points
                    # that bracket the target loss.
                    ratio = r0 + (loss - l0) * (r1 - r0) / (l1 - l0)
                    ratios[param] = ratio
                    if ratio > 1:
                        print(losses, ratio, (r1 - r0) / (l1 - l0), i)

                break
    return ratios
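
# Worked example with hypothetical numbers: given
#
#     sensitivities = {"conv1_weights": {0.1: 0.005, 0.2: 0.02, 0.3: 0.08}}
#
# and loss=0.05, the largest measured ratio whose loss stays within the
# threshold is 0.2, and the result is interpolated towards 0.3:
#
#     0.2 + (0.05 - 0.02) * (0.3 - 0.2) / (0.08 - 0.02) = 0.25
#     get_ratios_by_loss(sensitivities, 0.05)  # -> {"conv1_weights": 0.25}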