visualize.py 8.0 KB
Newer Older
F
FlyingQianMM 已提交
1
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
J
jiangjiajun 已提交
2
#
F
FlyingQianMM 已提交
3 4 5
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
J
jiangjiajun 已提交
6
#
F
FlyingQianMM 已提交
7
#     http://www.apache.org/licenses/LICENSE-2.0
J
jiangjiajun 已提交
8
#
F
FlyingQianMM 已提交
9 10 11 12 13
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
S
sunyanfang01 已提交
14 15 16 17 18 19

import os
import cv2
import copy
import os.path as osp
import numpy as np
S
sunyanfang01 已提交
20
import paddlex as pdx
S
sunyanfang01 已提交
21
from .interpretation_predict import interpretation_predict
S
sunyanfang01 已提交
22
from .core.interpretation import Interpretation
S
seven 已提交
23
from .core.normlime_base import precompute_global_classifier
S
sunyanfang01 已提交
24
from .core._session_preparation import gen_user_home
J
jiangjiajun 已提交
25
from paddlex.cv.transforms import arrange_transforms
S
seven 已提交
26 27 28 29 30


def lime(img_file, model, num_samples=3000, batch_size=50, save_dir='./'):
    """使用LIME算法将模型预测结果的可解释性可视化。

S
sunyanfang01 已提交
31 32 33
    LIME表示与模型无关的局部可解释性,可以解释任何模型。LIME的思想是以输入样本为中心,
    在其附近的空间中进行随机采样,每个采样通过原模型得到新的输出,这样得到一系列的输入
    和对应的输出,LIME用一个简单的、可解释的模型(比如线性回归模型)来拟合这个映射关系,
S
seven 已提交
34 35
    得到每个输入维度的权重,以此来解释模型。

S
sunyanfang01 已提交
36
    注意:LIME可解释性结果可视化目前只支持分类模型。
S
seven 已提交
37

S
sunyanfang01 已提交
38 39 40 41 42
    Args:
        img_file (str): 预测图像路径。
        model (paddlex.cv.models): paddlex中的模型。
        num_samples (int): LIME用于学习线性模型的采样数,默认为3000。
        batch_size (int): 预测数据batch大小,默认为50。
S
seven 已提交
43
        save_dir (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。
S
sunyanfang01 已提交
44 45 46 47
    """
    assert model.model_type == 'classifier', \
        'Now the interpretation visualize only be supported in classifier!'
    if model.status != 'Normal':
S
seven 已提交
48 49
        raise Exception(
            'The interpretation only can deal with the Normal model')
S
sunyanfang01 已提交
50 51
    if not osp.exists(save_dir):
        os.makedirs(save_dir)
J
jiangjiajun 已提交
52 53 54 55 56
    arrange_transforms(
        model.model_type,
        model.__class__.__name__,
        transforms=model.test_transforms,
        mode='test')
S
sunyanfang01 已提交
57 58 59 60 61 62
    tmp_transforms = copy.deepcopy(model.test_transforms)
    tmp_transforms.transforms = tmp_transforms.transforms[:-2]
    img = tmp_transforms(img_file)[0]
    img = np.around(img).astype('uint8')
    img = np.expand_dims(img, axis=0)
    interpreter = None
S
seven 已提交
63 64
    interpreter = get_lime_interpreter(
        img, model, num_samples=num_samples, batch_size=batch_size)
S
sunyanfang01 已提交
65
    img_name = osp.splitext(osp.split(img_file)[-1])[0]
S
seven 已提交
66
    interpreter.interpret(img, save_dir=osp.join(save_dir, img_name))
S
seven 已提交
67 68 69 70 71 72 73 74 75


def normlime(img_file,
             model,
             dataset=None,
             num_samples=3000,
             batch_size=50,
             save_dir='./',
             normlime_weights_file=None):
S
sunyanfang01 已提交
76
    """使用NormLIME算法将模型预测结果的可解释性可视化。
S
seven 已提交
77

S
sunyanfang01 已提交
78 79 80 81
    NormLIME是利用一定数量的样本来出一个全局的解释。由于NormLIME计算量较大,此处采用一种简化的方式:
    使用一定数量的测试样本(目前默认使用所有测试样本),对每个样本进行特征提取,映射到同一个特征空间;
    然后以此特征做为输入,以模型输出做为输出,使用线性回归对其进行拟合,得到一个全局的输入和输出的关系。
    之后,对一测试样本进行解释时,使用NormLIME全局的解释,来对LIME的结果进行滤波,使最终的可视化结果更加稳定。
S
seven 已提交
82

S
sunyanfang01 已提交
83 84
    注意1:dataset读取的是一个数据集,该数据集不宜过大,否则计算时间会较长,但应包含所有类别的数据。
    注意2:NormLIME可解释性结果可视化目前只支持分类模型。
S
seven 已提交
85

S
sunyanfang01 已提交
86 87 88 89
    Args:
        img_file (str): 预测图像路径。
        model (paddlex.cv.models): paddlex中的模型。
        dataset (paddlex.datasets): 数据集读取器,默认为None。
S
sunyanfang01 已提交
90
        num_samples (int): LIME用于学习线性模型的采样数,默认为3000。
S
sunyanfang01 已提交
91
        batch_size (int): 预测数据batch大小,默认为50。
S
seven 已提交
92 93
        save_dir (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。
        normlime_weights_file (str): NormLIME初始化文件名,若不存在,则计算一次,保存于该路径;若存在,则直接载入。
S
sunyanfang01 已提交
94 95 96
    """
    assert model.model_type == 'classifier', \
        'Now the interpretation visualize only be supported in classifier!'
S
SunAhong1993 已提交
97
    if model.status != 'Normal':
S
seven 已提交
98 99
        raise Exception(
            'The interpretation only can deal with the Normal model')
S
sunyanfang01 已提交
100 101
    if not osp.exists(save_dir):
        os.makedirs(save_dir)
J
jiangjiajun 已提交
102 103 104 105 106
    arrange_transforms(
        model.model_type,
        model.__class__.__name__,
        transforms=model.test_transforms,
        mode='test')
S
sunyanfang01 已提交
107 108 109 110 111
    tmp_transforms = copy.deepcopy(model.test_transforms)
    tmp_transforms.transforms = tmp_transforms.transforms[:-2]
    img = tmp_transforms(img_file)[0]
    img = np.around(img).astype('uint8')
    img = np.expand_dims(img, axis=0)
S
sunyanfang01 已提交
112
    interpreter = None
S
sunyanfang01 已提交
113
    if dataset is None:
S
seven 已提交
114 115 116 117 118 119 120 121 122 123
        raise Exception(
            'The dataset is None. Cannot implement this kind of interpretation')
    interpreter = get_normlime_interpreter(
        img,
        model,
        dataset,
        num_samples=num_samples,
        batch_size=batch_size,
        save_dir=save_dir,
        normlime_weights_file=normlime_weights_file)
S
sunyanfang01 已提交
124
    img_name = osp.splitext(osp.split(img_file)[-1])[0]
S
seven 已提交
125
    interpreter.interpret(img, save_dir=osp.join(save_dir, img_name))
S
seven 已提交
126 127


S
sunyanfang01 已提交
128
def get_lime_interpreter(img, model, num_samples=3000, batch_size=50):
S
sunyanfang01 已提交
129
    def predict_func(image):
S
sunyanfang01 已提交
130
        out = interpretation_predict(model, image)
S
sunyanfang01 已提交
131
        return out[0]
S
seven 已提交
132

S
SunAhong1993 已提交
133
    labels_name = None
S
sunyanfang01 已提交
134 135
    if hasattr(model, 'labels'):
        labels_name = model.labels
S
seven 已提交
136 137 138 139 140 141
    interpreter = Interpretation(
        'lime',
        predict_func,
        labels_name,
        num_samples=num_samples,
        batch_size=batch_size)
S
sunyanfang01 已提交
142
    return interpreter
S
sunyanfang01 已提交
143 144


S
seven 已提交
145 146 147 148 149 150 151
def get_normlime_interpreter(img,
                             model,
                             dataset,
                             num_samples=3000,
                             batch_size=50,
                             save_dir='./',
                             normlime_weights_file=None):
S
sunyanfang01 已提交
152
    def predict_func(image):
S
sunyanfang01 已提交
153
        out = interpretation_predict(model, image)
S
sunyanfang01 已提交
154
        return out[0]
S
seven 已提交
155

S
SunAhong1993 已提交
156 157 158
    labels_name = None
    if dataset is not None:
        labels_name = dataset.labels
S
sunyanfang01 已提交
159
    root_path = gen_user_home()
S
sunyanfang01 已提交
160
    root_path = osp.join(root_path, '.paddlex')
S
sunyanfang01 已提交
161 162
    pre_models_path = osp.join(root_path, "pre_models")
    if not osp.exists(pre_models_path):
S
sunyanfang01 已提交
163 164
        if not osp.exists(root_path):
            os.makedirs(root_path)
S
sunyanfang01 已提交
165
        url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz"
S
sunyanfang01 已提交
166
        pdx.utils.download_and_decompress(url, path=root_path)
S
sunyanfang01 已提交
167

S
seven 已提交
168 169 170 171 172 173 174 175 176 177 178 179 180
    if osp.exists(osp.join(save_dir, normlime_weights_file)):
        normlime_weights_file = osp.join(save_dir, normlime_weights_file)
        try:
            np.load(normlime_weights_file, allow_pickle=True).item()
        except:
            normlime_weights_file = precompute_global_classifier(
                dataset,
                predict_func,
                save_path=normlime_weights_file,
                batch_size=batch_size)
    else:
        normlime_weights_file = precompute_global_classifier(
            dataset,
S
sunyanfang01 已提交
181
            predict_func,
S
seven 已提交
182
            save_path=osp.join(save_dir, normlime_weights_file),
S
seven 已提交
183 184 185 186 187 188 189 190 191 192
            batch_size=batch_size)

    interpreter = Interpretation(
        'normlime',
        predict_func,
        labels_name,
        num_samples=num_samples,
        batch_size=batch_size,
        normlime_weights=normlime_weights_file)
    return interpreter