sensitive.py 3.6 KB
Newer Older
W
wanghaoshuang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

W
wanghaoshuang 已提交
15 16
import sys
import os
17
import logging
W
wanghaoshuang 已提交
18
import pickle
W
wanghaoshuang 已提交
19 20
import numpy as np
from ..core import GraphWrapper
21
from ..common import get_logger
W
wanghaoshuang 已提交
22
from ..prune import Pruner
23 24

_logger = get_logger(__name__, level=logging.INFO)
W
wanghaoshuang 已提交
25 26 27 28 29 30

__all__ = ["sensitivity"]


def sensitivity(program,
                scope,
                place,
                param_names,
                eval_func,
                sensitivities_file=None,
                step_size=0.2):
    """
    Measure how the evaluation metric changes as each parameter is pruned.

    For every parameter, the pruning ratio is swept over ``step_size``,
    ``2 * step_size``, ... (each rounded to two decimals) while it stays
    below 1. At each ratio the parameter is pruned alone, the pruned
    program is evaluated, and the relative drop against the unpruned
    baseline, ``(baseline - pruned_metric) / baseline``, is recorded.
    The original parameter values are written back into ``scope`` after
    every evaluation.

    Results found in ``sensitivities_file`` are loaded first; ratios
    already recorded there are skipped. Every parameter present in that
    file is processed, not only those listed in ``param_names``.

    Args:
        program: Program to analyse; it is wrapped in a ``GraphWrapper``.
        scope: Scope handed to the pruner and to ``eval_func``; parameter
            tensors are looked up in it via ``find_var``.
        place: Device place forwarded to the pruner and used when the
            backed-up parameter values are restored.
        param_names: Iterable of parameter names to analyse.
        eval_func: Callable invoked as ``eval_func(program, scope)`` that
            returns the metric. The baseline metric is used as a divisor,
            so it must be non-zero.
        sensitivities_file: Optional path of a pickle file used to load
            earlier results and to checkpoint progress after every
            evaluation. When it is ``None`` (or empty) nothing is saved.
        step_size: Increment of the pruning ratio; must be positive.

    Returns:
        dict: Maps each parameter name to a dict with keys
        ``'pruned_percent'`` (list of ratios), ``'loss'`` (list of
        relative metric drops, aligned with the ratios) and ``'size'``
        (first dimension of the parameter's shape).

    Raises:
        ValueError: If ``step_size`` is not positive (the ratio sweep
            would never terminate).
    """
    if step_size <= 0:
        raise ValueError(
            "step_size must be positive, but received {}".format(step_size))

    graph = GraphWrapper(program)
    sensitivities = _load_sensitivities(sensitivities_file)

    # Register the requested parameters that have no record yet.
    for name in param_names:
        if name not in sensitivities:
            size = graph.var(name).shape()[0]
            sensitivities[name] = {
                'pruned_percent': [],
                'loss': [],
                'size': size
            }

    # The baseline is evaluated lazily so that nothing is run when every
    # ratio has already been computed.
    baseline = None
    for name in sensitivities:
        ratio = step_size
        while ratio < 1:
            # Round so the ratio compares equal to the rounded ratios
            # stored in (and loaded from) the sensitivities record.
            ratio = round(ratio, 2)
            if ratio in sensitivities[name]['pruned_percent']:
                _logger.debug('{}, {} has computed.'.format(name, ratio))
                ratio += step_size
                continue
            if baseline is None:
                baseline = eval_func(graph.program, scope)

            # The pruner fills param_backup with the original values of
            # the tensors it modifies in scope.
            param_backup = {}
            pruner = Pruner()
            pruned_program = pruner.prune(
                program=graph.program,
                scope=scope,
                params=[name],
                ratios=[ratio],
                place=place,
                lazy=True,
                only_graph=False,
                param_backup=param_backup)
            try:
                pruned_metric = eval_func(pruned_program, scope)
            finally:
                # Restore pruned parameters even if the evaluation fails,
                # so the scope is never left holding pruned weights.
                for param_name, param_value in param_backup.items():
                    param_t = scope.find_var(param_name).get_tensor()
                    param_t.set(param_value, place)

            loss = (baseline - pruned_metric) / baseline
            _logger.info("pruned param: {}; {}; loss={}".format(name, ratio,
                                                                loss))
            sensitivities[name]['pruned_percent'].append(ratio)
            sensitivities[name]['loss'].append(loss)
            # Checkpoint after every evaluation; skipped when no file was
            # given, since there is nowhere to write to.
            if sensitivities_file:
                _save_sensitivities(sensitivities, sensitivities_file)
            ratio += step_size

    # Return only after every parameter has been processed.
    return sensitivities
W
wanghaoshuang 已提交
85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103


def _load_sensitivities(sensitivities_file):
    """
    Load sensitivities from file.
    """
    sensitivities = {}
    if sensitivities_file and os.path.exists(sensitivities_file):
        with open(sensitivities_file, 'rb') as f:
            if sys.version_info < (3, 0):
                sensitivities = pickle.load(f)
            else:
                sensitivities = pickle.load(f, encoding='bytes')

    for param in sensitivities:
        sensitivities[param]['pruned_percent'] = [
            round(p, 2) for p in sensitivities[param]['pruned_percent']
        ]
    return sensitivities
W
wanghaoshuang 已提交
104 105 106 107 108 109 110 111


def _save_sensitivities(sensitivities, sensitivities_file):
    """
        Save sensitivities into file.
        """
    with open(sensitivities_file, 'wb') as f:
        pickle.dump(sensitivities, f)