interpretation.py 1.8 KB
Newer Older
S
sunyanfang01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

S
sunyanfang01 已提交
15
from .interpretation_algorithms import CAM, LIME, NormLIME
S
sunyanfang01 已提交
16
from .normlime_base import precompute_normlime_weights
S
sunyanfang01 已提交
17 18


S
sunyanfang01 已提交
19
class Interpretation(object):
    """Facade that dispatches to one of the supported interpretation algorithms.

    The algorithm is chosen by name (case-insensitive) when the object is
    constructed; ``interpret`` then forwards every call to that algorithm
    instance.
    """

    def __init__(self, interpretation_algorithm_name, predict_fn, label_names, **kwargs):
        """Select and initialize an interpretation algorithm.

        Args:
            interpretation_algorithm_name: name of the algorithm to use; one of
                'cam', 'lime' or 'normlime' (matched case-insensitively).
            predict_fn: prediction function passed to the algorithm's
                constructor.
            label_names: label names passed to the algorithm's constructor.
            **kwargs: extra keyword arguments forwarded unchanged to the
                algorithm's constructor.

        Raises:
            ValueError: if ``interpretation_algorithm_name`` does not name a
                supported algorithm.
        """
        supported_algorithms = {
            'cam': CAM,
            'lime': LIME,
            'normlime': NormLIME,
        }

        self.algorithm_name = interpretation_algorithm_name.lower()
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would turn a bad name into an opaque KeyError.
        if self.algorithm_name not in supported_algorithms:
            raise ValueError(
                "Unsupported interpretation algorithm {!r}; expected one of: {}".format(
                    interpretation_algorithm_name,
                    ', '.join(sorted(supported_algorithms))))
        self.predict_fn = predict_fn

        # Initialize the selected interpretation algorithm.
        self.algorithm = supported_algorithms[self.algorithm_name](
            self.predict_fn, label_names, **kwargs)

    def interpret(self, data_, visualization=True, save_to_disk=True, save_dir='./tmp'):
        """Run the selected interpretation algorithm on ``data_``.

        Args:
            data_: data_ can be a path or numpy.ndarray.
            visualization: whether to show using matplotlib.
            save_to_disk: whether to save the figure in local disk.
            save_dir: dir to save figure if save_to_disk is True.

        Returns:
            Whatever the selected algorithm's ``interpret`` method returns;
            this method forwards its result unchanged.
        """
        return self.algorithm.interpret(data_, visualization, save_to_disk, save_dir)