visualize.py 7.8 KB
Newer Older
F
FlyingQianMM 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
S
sunyanfang01 已提交
14 15 16 17 18 19

import os
import cv2
import copy
import os.path as osp
import numpy as np
S
sunyanfang01 已提交
20
import paddlex as pdx
S
sunyanfang01 已提交
21
from .interpretation_predict import interpretation_predict
S
sunyanfang01 已提交
22
from .core.interpretation import Interpretation
S
seven 已提交
23
from .core.normlime_base import precompute_global_classifier
S
sunyanfang01 已提交
24
from .core._session_preparation import gen_user_home
S
seven 已提交
25 26 27 28 29


def lime(img_file, model, num_samples=3000, batch_size=50, save_dir='./'):
    """使用LIME算法将模型预测结果的可解释性可视化。

S
sunyanfang01 已提交
30 31 32
    LIME表示与模型无关的局部可解释性,可以解释任何模型。LIME的思想是以输入样本为中心,
    在其附近的空间中进行随机采样,每个采样通过原模型得到新的输出,这样得到一系列的输入
    和对应的输出,LIME用一个简单的、可解释的模型(比如线性回归模型)来拟合这个映射关系,
S
seven 已提交
33 34
    得到每个输入维度的权重,以此来解释模型。

S
sunyanfang01 已提交
35
    注意:LIME可解释性结果可视化目前只支持分类模型。
S
seven 已提交
36

S
sunyanfang01 已提交
37 38 39 40 41
    Args:
        img_file (str): 预测图像路径。
        model (paddlex.cv.models): paddlex中的模型。
        num_samples (int): LIME用于学习线性模型的采样数,默认为3000。
        batch_size (int): 预测数据batch大小,默认为50。
S
seven 已提交
42
        save_dir (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。
S
sunyanfang01 已提交
43 44 45 46
    """
    assert model.model_type == 'classifier', \
        'Now the interpretation visualize only be supported in classifier!'
    if model.status != 'Normal':
S
seven 已提交
47 48
        raise Exception(
            'The interpretation only can deal with the Normal model')
S
sunyanfang01 已提交
49 50
    if not osp.exists(save_dir):
        os.makedirs(save_dir)
S
seven 已提交
51
    model.arrange_transforms(transforms=model.test_transforms, mode='test')
S
sunyanfang01 已提交
52 53 54 55 56 57
    tmp_transforms = copy.deepcopy(model.test_transforms)
    tmp_transforms.transforms = tmp_transforms.transforms[:-2]
    img = tmp_transforms(img_file)[0]
    img = np.around(img).astype('uint8')
    img = np.expand_dims(img, axis=0)
    interpreter = None
S
seven 已提交
58 59
    interpreter = get_lime_interpreter(
        img, model, num_samples=num_samples, batch_size=batch_size)
S
sunyanfang01 已提交
60
    img_name = osp.splitext(osp.split(img_file)[-1])[0]
S
seven 已提交
61
    interpreter.interpret(img, save_dir=osp.join(save_dir, img_name))
S
seven 已提交
62 63 64 65 66 67 68 69 70


def normlime(img_file,
             model,
             dataset=None,
             num_samples=3000,
             batch_size=50,
             save_dir='./',
             normlime_weights_file=None):
S
sunyanfang01 已提交
71
    """使用NormLIME算法将模型预测结果的可解释性可视化。
S
seven 已提交
72

S
sunyanfang01 已提交
73 74 75 76
    NormLIME是利用一定数量的样本来出一个全局的解释。由于NormLIME计算量较大,此处采用一种简化的方式:
    使用一定数量的测试样本(目前默认使用所有测试样本),对每个样本进行特征提取,映射到同一个特征空间;
    然后以此特征做为输入,以模型输出做为输出,使用线性回归对其进行拟合,得到一个全局的输入和输出的关系。
    之后,对一测试样本进行解释时,使用NormLIME全局的解释,来对LIME的结果进行滤波,使最终的可视化结果更加稳定。
S
seven 已提交
77

S
sunyanfang01 已提交
78 79
    注意1:dataset读取的是一个数据集,该数据集不宜过大,否则计算时间会较长,但应包含所有类别的数据。
    注意2:NormLIME可解释性结果可视化目前只支持分类模型。
S
seven 已提交
80

S
sunyanfang01 已提交
81 82 83 84
    Args:
        img_file (str): 预测图像路径。
        model (paddlex.cv.models): paddlex中的模型。
        dataset (paddlex.datasets): 数据集读取器,默认为None。
S
sunyanfang01 已提交
85
        num_samples (int): LIME用于学习线性模型的采样数,默认为3000。
S
sunyanfang01 已提交
86
        batch_size (int): 预测数据batch大小,默认为50。
S
seven 已提交
87 88
        save_dir (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。
        normlime_weights_file (str): NormLIME初始化文件名,若不存在,则计算一次,保存于该路径;若存在,则直接载入。
S
sunyanfang01 已提交
89 90 91
    """
    assert model.model_type == 'classifier', \
        'Now the interpretation visualize only be supported in classifier!'
S
SunAhong1993 已提交
92
    if model.status != 'Normal':
S
seven 已提交
93 94
        raise Exception(
            'The interpretation only can deal with the Normal model')
S
sunyanfang01 已提交
95 96
    if not osp.exists(save_dir):
        os.makedirs(save_dir)
S
seven 已提交
97
    model.arrange_transforms(transforms=model.test_transforms, mode='test')
S
sunyanfang01 已提交
98 99 100 101 102
    tmp_transforms = copy.deepcopy(model.test_transforms)
    tmp_transforms.transforms = tmp_transforms.transforms[:-2]
    img = tmp_transforms(img_file)[0]
    img = np.around(img).astype('uint8')
    img = np.expand_dims(img, axis=0)
S
sunyanfang01 已提交
103
    interpreter = None
S
sunyanfang01 已提交
104
    if dataset is None:
S
seven 已提交
105 106 107 108 109 110 111 112 113 114
        raise Exception(
            'The dataset is None. Cannot implement this kind of interpretation')
    interpreter = get_normlime_interpreter(
        img,
        model,
        dataset,
        num_samples=num_samples,
        batch_size=batch_size,
        save_dir=save_dir,
        normlime_weights_file=normlime_weights_file)
S
sunyanfang01 已提交
115
    img_name = osp.splitext(osp.split(img_file)[-1])[0]
S
seven 已提交
116
    interpreter.interpret(img, save_dir=osp.join(save_dir, img_name))
S
seven 已提交
117 118


S
sunyanfang01 已提交
119
def get_lime_interpreter(img, model, num_samples=3000, batch_size=50):
S
sunyanfang01 已提交
120
    def predict_func(image):
S
sunyanfang01 已提交
121
        out = interpretation_predict(model, image)
S
sunyanfang01 已提交
122
        return out[0]
S
seven 已提交
123

S
SunAhong1993 已提交
124
    labels_name = None
S
sunyanfang01 已提交
125 126
    if hasattr(model, 'labels'):
        labels_name = model.labels
S
seven 已提交
127 128 129 130 131 132
    interpreter = Interpretation(
        'lime',
        predict_func,
        labels_name,
        num_samples=num_samples,
        batch_size=batch_size)
S
sunyanfang01 已提交
133
    return interpreter
S
sunyanfang01 已提交
134 135


S
seven 已提交
136 137 138 139 140 141 142
def get_normlime_interpreter(img,
                             model,
                             dataset,
                             num_samples=3000,
                             batch_size=50,
                             save_dir='./',
                             normlime_weights_file=None):
S
sunyanfang01 已提交
143
    def predict_func(image):
S
sunyanfang01 已提交
144
        out = interpretation_predict(model, image)
S
sunyanfang01 已提交
145
        return out[0]
S
seven 已提交
146

S
SunAhong1993 已提交
147 148 149
    labels_name = None
    if dataset is not None:
        labels_name = dataset.labels
S
sunyanfang01 已提交
150
    root_path = gen_user_home()
S
sunyanfang01 已提交
151
    root_path = osp.join(root_path, '.paddlex')
S
sunyanfang01 已提交
152 153
    pre_models_path = osp.join(root_path, "pre_models")
    if not osp.exists(pre_models_path):
S
sunyanfang01 已提交
154 155
        if not osp.exists(root_path):
            os.makedirs(root_path)
S
sunyanfang01 已提交
156
        url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz"
S
sunyanfang01 已提交
157
        pdx.utils.download_and_decompress(url, path=root_path)
S
sunyanfang01 已提交
158

S
seven 已提交
159 160 161 162 163 164 165 166 167 168 169 170 171
    if osp.exists(osp.join(save_dir, normlime_weights_file)):
        normlime_weights_file = osp.join(save_dir, normlime_weights_file)
        try:
            np.load(normlime_weights_file, allow_pickle=True).item()
        except:
            normlime_weights_file = precompute_global_classifier(
                dataset,
                predict_func,
                save_path=normlime_weights_file,
                batch_size=batch_size)
    else:
        normlime_weights_file = precompute_global_classifier(
            dataset,
S
sunyanfang01 已提交
172
            predict_func,
S
seven 已提交
173
            save_path=osp.join(save_dir, normlime_weights_file),
S
seven 已提交
174 175 176 177 178 179 180 181 182 183
            batch_size=batch_size)

    interpreter = Interpretation(
        'normlime',
        predict_func,
        labels_name,
        num_samples=num_samples,
        batch_size=batch_size,
        normlime_weights=normlime_weights_file)
    return interpreter