sensitive.py 4.0 KB
Newer Older
W
wanghaoshuang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

W
wanghaoshuang 已提交
15 16
import sys
import os
17
import logging
W
wanghaoshuang 已提交
18
import pickle
W
wanghaoshuang 已提交
19
import numpy as np
W
wanghaoshuang 已提交
20
import paddle.fluid as fluid
W
wanghaoshuang 已提交
21
from ..core import GraphWrapper
22
from ..common import get_logger
W
wanghaoshuang 已提交
23
from ..prune import Pruner
24 25

# Module logger at INFO level: sensitivity() reports each new measurement
# with _logger.info and skipped (already computed) ratios with _logger.debug.
_logger = get_logger(__name__, level=logging.INFO)

__all__ = ["sensitivity"]


def sensitivity(program,
                place,
                param_names,
                eval_func,
                sensitivities_file=None,
                step_size=0.2,
                max_pruned_times=None):
    """Measure how much a metric drops when each parameter is pruned alone.

    Every parameter is pruned on its own with the ratios ``step_size``,
    ``2 * step_size``, ... (rounded to two decimals, always strictly below 1).
    For each ratio the program is pruned in ``fluid.global_scope()`` and
    scored with ``eval_func``. The parameters are then restored from the
    backup the pruner filled in. The recorded loss is
    ``(baseline - pruned_metric) / baseline``. Here ``baseline`` is
    ``eval_func`` on the unpruned program, evaluated once and lazily.

    Args:
        program: Program holding the parameters to analyse. It is wrapped
            in a ``GraphWrapper``.
        place: Device place passed to the pruner. It is also used when
            writing the backed-up parameter values back into the scope.
        param_names(list<str>): Names of the parameters to analyse.
        eval_func(callable): Called with a program; returns a scalar metric.
        sensitivities_file(str|None): Optional pickle file. Ratios already
            recorded in it are skipped. The file is rewritten after every new
            measurement, so an interrupted run can resume. When falsy,
            results are only kept in memory.
        step_size(float): Positive increment between successive ratios.
        max_pruned_times(int|None): If given, at most this many ratios
            (already-recorded ones included) are visited per parameter.

    Returns:
        dict: ``{param_name: {'pruned_percent': [ratio, ...],
        'loss': [loss, ...], 'size': <shape()[0] of the parameter>}}``.
        Entries loaded from ``sensitivities_file`` but absent from
        ``param_names`` are also processed and returned.

    Raises:
        ValueError: If ``step_size`` is not positive.
    """
    if step_size <= 0:
        raise ValueError(
            "step_size must be positive, but got {}.".format(step_size))
    scope = fluid.global_scope()
    graph = GraphWrapper(program)
    sensitivities = _load_sensitivities(sensitivities_file)

    for name in param_names:
        if name not in sensitivities:
            size = graph.var(name).shape()[0]
            sensitivities[name] = {
                'pruned_percent': [],
                'loss': [],
                'size': size
            }

    # The candidate ratios do not depend on the parameter; compute them once.
    ratios = _pruned_ratios(step_size)
    baseline = None
    for name in sensitivities:
        for pruned_times, ratio in enumerate(ratios):
            if max_pruned_times is not None and pruned_times >= max_pruned_times:
                break
            if ratio in sensitivities[name]['pruned_percent']:
                _logger.debug('{}, {} has computed.'.format(name, ratio))
                continue
            if baseline is None:
                baseline = eval_func(graph.program)

            param_backup = {}
            pruner = Pruner()
            _logger.info("sensitive - param: {}; ratios: {}".format(name,
                                                                    ratio))
            try:
                # The pruner is expected to fill param_backup with the
                # original values of the tensors it changes; they are
                # written back in the finally clause below.
                pruned_program = pruner.prune(
                    program=graph.program,
                    scope=scope,
                    params=[name],
                    ratios=[ratio],
                    place=place,
                    lazy=True,
                    only_graph=False,
                    param_backup=param_backup)
                pruned_metric = eval_func(pruned_program)
            finally:
                # Restore the pruned parameters even if pruning or evaluation
                # raised, so the shared global scope is never left pruned.
                for param_name, value in param_backup.items():
                    param_t = scope.find_var(param_name).get_tensor()
                    param_t.set(value, place)

            # NOTE(review): a zero baseline raises ZeroDivisionError here.
            loss = (baseline - pruned_metric) / baseline
            _logger.info("pruned param: {}; {}; loss={}".format(name, ratio,
                                                                loss))
            sensitivities[name]['pruned_percent'].append(ratio)
            sensitivities[name]['loss'].append(loss)
            # Persist after every measurement, but only when a file was given
            # (the default None would otherwise crash in open()).
            if sensitivities_file:
                _save_sensitivities(sensitivities, sensitivities_file)
    return sensitivities


def _pruned_ratios(step_size):
    """Return the pruning ratios to try for one parameter.

    Each ratio is ``round(k * step_size, 2)`` for k = 1, 2, ..., stopping
    before the first one that reaches 1.0. Deriving every ratio from an
    integer multiple, instead of repeatedly adding ``step_size``, keeps
    floating-point error from accumulating. Checking the *rounded* value
    guarantees a ratio of 1.0 (pruning everything) is never produced.
    Non-positive and duplicate values, which occur when ``step_size`` is
    below 0.005, are skipped, so tiny steps cannot loop forever.

    Args:
        step_size(float): Positive increment between successive ratios.

    Returns:
        list<float>: Strictly increasing ratios in the open interval (0, 1).
    """
    ratios = []
    k = 1
    while True:
        ratio = round(step_size * k, 2)
        if ratio >= 1:
            return ratios
        if ratio > 0 and (not ratios or ratio > ratios[-1]):
            ratios.append(ratio)
        k += 1
W
wanghaoshuang 已提交
93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111


def _load_sensitivities(sensitivities_file):
    """
    Load sensitivities from file.

    Args:
        sensitivities_file(str|None): Path of a pickle file holding a dict of
            per-parameter records. If it is falsy or does not exist, an
            empty dict is returned.

    Returns:
        dict: The loaded mapping. Every ``'pruned_percent'`` list is rounded
        to two decimals, so stored ratios compare equal to ratios that are
        rounded the same way.
    """
    sensitivities = {}
    if sensitivities_file and os.path.exists(sensitivities_file):
        # NOTE: unpickling can execute arbitrary code; only load files you
        # produced yourself.
        with open(sensitivities_file, 'rb') as f:
            if sys.version_info < (3, 0):
                sensitivities = pickle.load(f)
            else:
                # 'latin1' decodes Python 2 ``str`` objects (such as the
                # 'pruned_percent' / 'loss' keys) to ``str``. 'bytes' would
                # turn them into ``bytes`` and break every key lookup below.
                sensitivities = pickle.load(f, encoding='latin1')

    for param in sensitivities:
        sensitivities[param]['pruned_percent'] = [
            round(p, 2) for p in sensitivities[param]['pruned_percent']
        ]
    return sensitivities
W
wanghaoshuang 已提交
112 113 114 115 116 117 118 119


def _save_sensitivities(sensitivities, sensitivities_file):
    """
    Save sensitivities into file.

    Args:
        sensitivities(dict): Mapping to pickle.
        sensitivities_file(str|None): Destination path, overwritten if it
            exists. When falsy (e.g. the optional ``None`` default), nothing
            is written. This mirrors ``_load_sensitivities``, which likewise
            treats a falsy path as "no file".
    """
    if not sensitivities_file:
        return
    with open(sensitivities_file, 'wb') as f:
        pickle.dump(sensitivities, f)