Commit 1b0d4d53 authored by sunyanfang01

rename lime

Parent 5a2ad684
@@ -28,7 +28,7 @@ from . import seg
 from . import cls
 from . import slim
 from . import tools
-from . import explanation
+from . import interpret
 try:
     import pycocotools
...
@@ -275,7 +275,7 @@ class BaseClassifier(BaseAPI):
             } for l in pred_label]
         return res
 
-    def explanation_predict(self, images):
+    def interpretation_predict(self, images):
         self.arrange_transforms(
             transforms=self.test_transforms, mode='test')
         new_imgs = []
...
@@ -12,31 +12,31 @@
 #See the License for the specific language governing permissions and
 #limitations under the License.
 
-from .explanation_algorithms import CAM, LIME, NormLIME
+from .interpretation_algorithms import CAM, LIME, NormLIME
 from .normlime_base import precompute_normlime_weights
 
 
-class Explanation(object):
+class Interpretation(object):
     """
-    Base class for all explanation algorithms.
+    Base class for all interpretation algorithms.
     """
-    def __init__(self, explanation_algorithm_name, predict_fn, label_names, **kwargs):
+    def __init__(self, interpretation_algorithm_name, predict_fn, label_names, **kwargs):
         supported_algorithms = {
             'cam': CAM,
             'lime': LIME,
             'normlime': NormLIME
         }
-        self.algorithm_name = explanation_algorithm_name.lower()
+        self.algorithm_name = interpretation_algorithm_name.lower()
         assert self.algorithm_name in supported_algorithms.keys()
         self.predict_fn = predict_fn
 
-        # initialization for the explanation algorithm.
-        self.explain_algorithm = supported_algorithms[self.algorithm_name](
+        # initialization for the interpretation algorithm.
+        self.algorithm = supported_algorithms[self.algorithm_name](
             self.predict_fn, label_names, **kwargs
         )
 
-    def explain(self, data_, visualization=True, save_to_disk=True, save_dir='./tmp'):
+    def interpret(self, data_, visualization=True, save_to_disk=True, save_dir='./tmp'):
         """
         Args:
@@ -48,4 +48,4 @@ class Explanation(object):
 
         Returns:
         """
-        return self.explain_algorithm.explain(data_, visualization, save_to_disk, save_dir)
+        return self.algorithm.interpret(data_, visualization, save_to_disk, save_dir)
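For orientation, a minimal usage sketch of the renamed wrapper follows. It assumes `Interpretation` has been imported from this module and that `predict_fn` maps an (N, H, W, 3) image batch to an (N, num_classes) probability array; the placeholder classifier, label names, and input image below are illustrative only, while the algorithm names and the `num_samples` / `batch_size` keywords are the ones wired up elsewhere in this commit.

import numpy as np

# Placeholder predict_fn: takes an (N, H, W, 3) image batch and returns
# an (N, num_classes) array of class probabilities.
def predict_fn(images):
    return np.full((images.shape[0], 2), 0.5)

label_names = ['class_a', 'class_b']

# 'cam', 'lime' and 'normlime' are the supported algorithm names; extra
# keyword arguments are forwarded to the chosen algorithm's constructor.
lime_interpretation = Interpretation('lime', predict_fn, label_names,
                                     num_samples=3000, batch_size=50)
# data_ is the (already preprocessed) image batch to interpret; results are
# written under save_dir when save_to_disk is True.
lime_interpretation.interpret(np.zeros((1, 224, 224, 3), dtype='uint8'),
                              visualization=False, save_to_disk=True,
                              save_dir='./tmp')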
@@ -46,12 +46,13 @@ class CAM(object):
         logit = result[0][0]
         if abs(np.sum(logit) - 1.0) > 1e-4:
             # softmax
+            logit = logit - np.max(logit)
             exp_result = np.exp(logit)
             probability = exp_result / np.sum(exp_result)
         else:
             probability = logit
 
-        # only explain top 1
+        # only interpret top 1
         pred_label = np.argsort(probability)
         pred_label = pred_label[-1:]
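The added line `logit = logit - np.max(logit)` (and the matching line in the LIME hunk below) is the standard max-subtraction trick for a numerically stable softmax: shifting all logits by a constant leaves the probabilities unchanged but keeps `np.exp` from overflowing. A self-contained sketch of the same computation:

import numpy as np

def stable_softmax(logit):
    # exp(x - c) / sum(exp(x - c)) equals exp(x) / sum(exp(x)) for any constant c,
    # so subtracting the maximum changes nothing mathematically while avoiding overflow.
    shifted = np.exp(logit - np.max(logit))
    return shifted / np.sum(shifted)

print(stable_softmax(np.array([1000.0, 999.0, 998.0])))  # ~[0.665, 0.245, 0.090]
# The naive form np.exp(logit) / np.exp(logit).sum() overflows to nan for these logits.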
@@ -71,7 +72,7 @@ class CAM(object):
         print(f'predicted result: {ln} with probability {probability[pred_label[0]]:.3f}')
         return feature_maps, fc_weights
 
-    def explain(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
+    def interpret(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
         feature_maps, fc_weights = self.preparation_cam(data_)
         cam = get_cam(self.image, feature_maps, fc_weights, self.predicted_label)
@@ -123,7 +124,7 @@ class LIME(object):
         self.predict_fn = predict_fn
         self.labels = None
         self.image = None
-        self.lime_explainer = None
+        self.lime_interpreter = None
         self.label_names = label_names
 
     def preparation_lime(self, data_):
@@ -134,12 +135,13 @@
         if abs(np.sum(result) - 1.0) > 1e-4:
             # softmax
+            result = result - np.max(result)
             exp_result = np.exp(result)
             probability = exp_result / np.sum(exp_result)
         else:
             probability = result
 
-        # only explain top 1
+        # only interpret top 1
         pred_label = np.argsort(probability)
         pred_label = pred_label[-1:]
@@ -156,14 +158,14 @@
         print(f'predicted result: {ln} with probability {probability[pred_label[0]]:.3f}')
         end = time.time()
-        algo = lime_base.LimeImageExplainer()
-        explainer = algo.explain_instance(self.image, self.predict_fn, self.labels, 0,
+        algo = lime_base.LimeImageInterpreter()
+        interpreter = algo.interpret_instance(self.image, self.predict_fn, self.labels, 0,
                                           num_samples=self.num_samples, batch_size=self.batch_size)
-        self.lime_explainer = explainer
+        self.lime_interpreter = interpreter
         print('lime time: ', time.time() - end, 's.')
 
-    def explain(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
-        if self.lime_explainer is None:
+    def interpret(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
+        if self.lime_interpreter is None:
             self.preparation_lime(data_)
 
         if visualization or save_to_disk:
@@ -187,13 +189,13 @@
             axes[0].imshow(self.image)
             axes[0].set_title(f"label {ln}, proba: {self.predicted_probability: .3f}")
 
-            axes[1].imshow(mark_boundaries(self.image, self.lime_explainer.segments))
+            axes[1].imshow(mark_boundaries(self.image, self.lime_interpreter.segments))
             axes[1].set_title("superpixel segmentation")
 
             # LIME visualization
             for i, w in enumerate(weights_choices):
-                num_to_show = auto_choose_num_features_to_show(self.lime_explainer, l, w)
-                temp, mask = self.lime_explainer.get_image_and_mask(
+                num_to_show = auto_choose_num_features_to_show(self.lime_interpreter, l, w)
+                temp, mask = self.lime_interpreter.get_image_and_mask(
                     l, positive_only=False, hide_rest=False, num_features=num_to_show
                 )
                 axes[ncols + i].imshow(mark_boundaries(temp, mask))
@@ -274,20 +276,20 @@ class NormLIME(object):
         print('performing NormLIME operations ...')
 
         cluster_labels = self.predict_cluster_labels(
-            compute_features_for_kmeans(image_show).transpose((1, 2, 0)), self._lime.lime_explainer.segments
+            compute_features_for_kmeans(image_show).transpose((1, 2, 0)), self._lime.lime_interpreter.segments
         )
 
         g_weights = self.predict_using_normlime_weights(self.labels, cluster_labels)
 
         return g_weights
 
-    def explain(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
+    def interpret(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
         if self.normlime_weights is None:
             raise ValueError("Not find the correct precomputed NormLIME result. \n"
                              "\t Try to call compute_normlime_weights() first or load the correct path.")
 
         g_weights = self.preparation_normlime(data_)
-        lime_weights = self._lime.lime_explainer.local_exp
+        lime_weights = self._lime.lime_interpreter.local_weights
 
         if visualization or save_to_disk:
             import matplotlib.pyplot as plt
@@ -312,23 +314,23 @@ class NormLIME(object):
             axes[0].imshow(self.image)
             axes[0].set_title(f"label {ln}, proba: {self.predicted_probability: .3f}")
 
-            axes[1].imshow(mark_boundaries(self.image, self._lime.lime_explainer.segments))
+            axes[1].imshow(mark_boundaries(self.image, self._lime.lime_interpreter.segments))
             axes[1].set_title("superpixel segmentation")
 
             # LIME visualization
             for i, w in enumerate(weights_choices):
-                num_to_show = auto_choose_num_features_to_show(self._lime.lime_explainer, l, w)
+                num_to_show = auto_choose_num_features_to_show(self._lime.lime_interpreter, l, w)
                 nums_to_show.append(num_to_show)
-                temp, mask = self._lime.lime_explainer.get_image_and_mask(
+                temp, mask = self._lime.lime_interpreter.get_image_and_mask(
                     l, positive_only=False, hide_rest=False, num_features=num_to_show
                 )
                 axes[ncols + i].imshow(mark_boundaries(temp, mask))
                 axes[ncols + i].set_title(f"LIME: first {num_to_show} superpixels")
 
             # NormLIME visualization
-            self._lime.lime_explainer.local_exp = g_weights
+            self._lime.lime_interpreter.local_weights = g_weights
             for i, num_to_show in enumerate(nums_to_show):
-                temp, mask = self._lime.lime_explainer.get_image_and_mask(
+                temp, mask = self._lime.lime_interpreter.get_image_and_mask(
                     l, positive_only=False, hide_rest=False, num_features=num_to_show
                 )
                 axes[ncols * 2 + i].imshow(mark_boundaries(temp, mask))
@@ -336,15 +338,15 @@ class NormLIME(object):
             # NormLIME*LIME visualization
             combined_weights = combine_normlime_and_lime(lime_weights, g_weights)
-            self._lime.lime_explainer.local_exp = combined_weights
+            self._lime.lime_interpreter.local_weights = combined_weights
             for i, num_to_show in enumerate(nums_to_show):
-                temp, mask = self._lime.lime_explainer.get_image_and_mask(
+                temp, mask = self._lime.lime_interpreter.get_image_and_mask(
                     l, positive_only=False, hide_rest=False, num_features=num_to_show
                 )
                 axes[ncols * 3 + i].imshow(mark_boundaries(temp, mask))
                 axes[ncols * 3 + i].set_title(f"Combined: first {num_to_show} superpixels")
 
-            self._lime.lime_explainer.local_exp = lime_weights
+            self._lime.lime_interpreter.local_weights = lime_weights
 
         if save_to_disk and save_outdir is not None:
             os.makedirs(save_outdir, exist_ok=True)
@@ -354,9 +356,9 @@ class NormLIME(object):
         plt.show()
 
 
-def auto_choose_num_features_to_show(lime_explainer, label, percentage_to_show):
-    segments = lime_explainer.segments
-    lime_weights = lime_explainer.local_exp[label]
+def auto_choose_num_features_to_show(lime_interpreter, label, percentage_to_show):
+    segments = lime_interpreter.segments
+    lime_weights = lime_interpreter.local_weights[label]
     num_pixels_threshold_in_a_sp = segments.shape[0] * segments.shape[1] // len(np.unique(segments)) // 8
 
     # l1 norm with filtered weights.
@@ -381,7 +383,7 @@ def auto_choose_num_features_to_show(lime_explainer, label, percentage_to_show):
         return 5
 
     if n == 0:
-        return auto_choose_num_features_to_show(lime_explainer, label, percentage_to_show-0.1)
+        return auto_choose_num_features_to_show(lime_interpreter, label, percentage_to_show-0.1)
 
     return n
...
-#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
-#
-#Licensed under the Apache License, Version 2.0 (the "License");
-#you may not use this file except in compliance with the License.
-#You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-#Unless required by applicable law or agreed to in writing, software
-#distributed under the License is distributed on an "AS IS" BASIS,
-#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-#See the License for the specific language governing permissions and
-#limitations under the License.
-from __future__ import print_function
+"""
+Copyright (c) 2016, Marco Tulio Correia Ribeiro
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+   list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+   this list of conditions and the following disclaimer in the documentation
+   and/or other materials provided with the distribution.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""
+"""
+The code in this file (lime_base.py) is modified from https://github.com/marcotcr/lime.
+"""
 import numpy as np
 import scipy as sp
 import sklearn
@@ -88,7 +103,7 @@ class LimeBase(object):
         return np.array(used_features)
 
     def feature_selection(self, data, labels, weights, num_features, method):
-        """Selects features for the model. see explain_instance_with_data to
+        """Selects features for the model. see interpret_instance_with_data to
            understand the parameters."""
         if method == 'none':
             return np.array(range(data.shape[1]))
@@ -154,15 +169,15 @@ class LimeBase(object):
             return self.feature_selection(data, labels, weights,
                                           num_features, n_method)
 
-    def explain_instance_with_data(self,
+    def interpret_instance_with_data(self,
                                    neighborhood_data,
                                    neighborhood_labels,
                                    distances,
                                    label,
                                    num_features,
                                    feature_selection='auto',
                                    model_regressor=None):
-        """Takes perturbed data, labels and distances, returns explanation.
+        """Takes perturbed data, labels and distances, returns interpretation.
 
         Args:
             neighborhood_data: perturbed data, 2d array. first element is
@@ -170,8 +185,8 @@ class LimeBase(object):
             neighborhood_labels: corresponding perturbed labels. should have as
                 many columns as the number of possible labels.
             distances: distances to original data point.
-            label: label for which we want an explanation
-            num_features: maximum number of features in explanation
+            label: label for which we want an interpretation
+            num_features: maximum number of features in interpretation
             feature_selection: how to select num_features. options are:
                 'forward_selection': iteratively add features to the model.
                     This is costly when num_features is high
@@ -183,7 +198,7 @@ class LimeBase(object):
                 'none': uses all features, ignores num_features
                 'auto': uses forward_selection if num_features <= 6, and
                     'highest_weights' otherwise.
-            model_regressor: sklearn regressor to use in explanation.
+            model_regressor: sklearn regressor to use in interpretation.
                 Defaults to Ridge regression if None. Must have
                 model_regressor.coef_ and 'sample_weight' as a parameter
                 to model_regressor.fit()
@@ -194,8 +209,8 @@ class LimeBase(object):
             exp is a sorted list of tuples, where each tuple (x,y) corresponds
             to the feature id (x) and the local weight (y). The list is sorted
             by decreasing absolute value of y.
-            score is the R^2 value of the returned explanation
-            local_pred is the prediction of the explanation model on the original instance
+            score is the R^2 value of the returned interpretation
+            local_pred is the prediction of the interpretation model on the original instance
         """
         weights = self.kernel_fn(distances)
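For readers new to LIME internals, the surrogate fit this method describes amounts to a distance-weighted Ridge regression on the perturbed binary data, whose coefficients become the per-superpixel weights in the sorted list documented above. The sketch below is a simplified illustration of that idea under assumed defaults (exponential kernel, plain Ridge regressor), not the exact body of `interpret_instance_with_data`:

import numpy as np
from sklearn.linear_model import Ridge

def fit_local_surrogate(neighborhood_data, neighborhood_labels, distances,
                        label, kernel_width=0.25):
    # Turn distances into sample weights with an exponential kernel,
    # the usual kernel shape in LIME-style interpreters.
    weights = np.sqrt(np.exp(-(distances ** 2) / kernel_width ** 2))

    # Weighted linear fit for the requested label column.
    model = Ridge(alpha=1.0, fit_intercept=True)
    model.fit(neighborhood_data, neighborhood_labels[:, label], sample_weight=weights)

    score = model.score(neighborhood_data, neighborhood_labels[:, label],
                        sample_weight=weights)
    local_pred = model.predict(neighborhood_data[0].reshape(1, -1))

    # (feature_id, weight) pairs sorted by decreasing absolute weight,
    # matching the documented return format.
    sorted_weights = sorted(enumerate(model.coef_), key=lambda x: abs(x[1]), reverse=True)
    return model.intercept_, sorted_weights, score, local_pred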
@@ -227,7 +242,7 @@ class LimeBase(object):
                 prediction_score, local_pred)
 
 
-class ImageExplanation(object):
+class ImageInterpretation(object):
     def __init__(self, image, segments):
         """Init function.
@@ -238,7 +253,7 @@ class ImageExplanation(object):
         self.image = image
         self.segments = segments
         self.intercept = {}
-        self.local_exp = {}
+        self.local_weights = {}
         self.local_pred = None
 
     def get_image_and_mask(self, label, positive_only=True, negative_only=False, hide_rest=False,
@@ -246,40 +261,40 @@ class ImageExplanation(object):
         """Init function.
 
         Args:
-            label: label to explain
+            label: label to interpret
             positive_only: if True, only take superpixels that positively contribute to
                 the prediction of the label.
             negative_only: if True, only take superpixels that negatively contribute to
                 the prediction of the label. If false, and so is positive_only, then both
                 negativey and positively contributions will be taken.
                 Both can't be True at the same time
-            hide_rest: if True, make the non-explanation part of the return
+            hide_rest: if True, make the non-interpretation part of the return
                 image gray
-            num_features: number of superpixels to include in explanation
-            min_weight: minimum weight of the superpixels to include in explanation
+            num_features: number of superpixels to include in interpretation
+            min_weight: minimum weight of the superpixels to include in interpretation
 
         Returns:
             (image, mask), where image is a 3d numpy array and mask is a 2d
             numpy array that can be used with
            skimage.segmentation.mark_boundaries
         """
-        if label not in self.local_exp:
-            raise KeyError('Label not in explanation')
+        if label not in self.local_weights:
+            raise KeyError('Label not in interpretation')
         if positive_only & negative_only:
             raise ValueError("Positive_only and negative_only cannot be true at the same time.")
         segments = self.segments
         image = self.image
-        exp = self.local_exp[label]
+        local_weights_label = self.local_weights[label]
         mask = np.zeros(segments.shape, segments.dtype)
         if hide_rest:
             temp = np.zeros(self.image.shape)
         else:
             temp = self.image.copy()
         if positive_only:
-            fs = [x[0] for x in exp
+            fs = [x[0] for x in local_weights_label
                   if x[1] > 0 and x[1] > min_weight][:num_features]
         if negative_only:
-            fs = [x[0] for x in exp
+            fs = [x[0] for x in local_weights_label
                   if x[1] < 0 and abs(x[1]) > min_weight][:num_features]
         if positive_only or negative_only:
             for f in fs:
@@ -287,7 +302,7 @@ class ImageExplanation(object):
                 mask[segments == f] = 1
             return temp, mask
         else:
-            for f, w in exp[:num_features]:
+            for f, w in local_weights_label[:num_features]:
                 if np.abs(w) < min_weight:
                     continue
                 c = 0 if w < 0 else 1
@@ -300,32 +315,31 @@ class ImageExplanation(object):
         """
         Args:
-            label: label to explain
+            label: label to interpret
             min_weight:
 
         Returns:
             image, is a 3d numpy array
         """
-        if label not in self.local_exp:
-            raise KeyError('Label not in explanation')
+        if label not in self.local_weights:
+            raise KeyError('Label not in interpretation')
 
         from matplotlib import cm
 
         segments = self.segments
         image = self.image
-        exp = self.local_exp[label]
+        local_weights_label = self.local_weights[label]
         temp = np.zeros_like(image)
 
-        weight_max = abs(exp[0][1])
-        exp = [(f, w/weight_max) for f, w in exp]
-        exp = sorted(exp, key=lambda x: x[1], reverse=True) # negatives are at last.
+        weight_max = abs(local_weights_label[0][1])
+        local_weights_label = [(f, w/weight_max) for f, w in local_weights_label]
+        local_weights_label = sorted(local_weights_label, key=lambda x: x[1], reverse=True) # negatives are at last.
 
         cmaps = cm.get_cmap('Spectral')
-        # sigmoid_space = 1 / (1 + np.exp(-np.linspace(-20, 20, len(exp))))
-        colors = cmaps(np.linspace(0, 1, len(exp)))
+        colors = cmaps(np.linspace(0, 1, len(local_weights_label)))
         colors = colors[:, :3]
 
-        for i, (f, w) in enumerate(exp):
+        for i, (f, w) in enumerate(local_weights_label):
             if np.abs(w) < min_weight:
                 continue
             temp[segments == f] = image[segments == f].copy()
@@ -333,14 +347,14 @@ class ImageExplanation(object):
         return temp
 
 
-class LimeImageExplainer(object):
-    """Explains predictions on Image (i.e. matrix) data.
+class LimeImageInterpreter(object):
+    """Interprets predictions on Image (i.e. matrix) data.
     For numerical features, perturb them by sampling from a Normal(0,1) and
     doing the inverse operation of mean-centering and scaling, according to the
     means and stds in the training data. For categorical features, perturb by
     sampling according to the training distribution, and making a binary
     feature that is 1 when the value is the same as the instance being
-    explained."""
+    interpreted."""
 
     def __init__(self, kernel_width=.25, kernel=None, verbose=False,
                  feature_selection='auto', random_state=None):
@@ -355,7 +369,7 @@ class LimeImageExplainer(object):
             verbose: if true, print local prediction values from linear model
             feature_selection: feature selection method. can be
                 'forward_selection', 'lasso_path', 'none' or 'auto'.
-                See function 'explain_instance_with_data' in lime_base.py for
+                See function 'interpret_instance_with_data' in lime_base.py for
                 details on what each of the options does.
             random_state: an integer or numpy.RandomState that will be used to
                 generate random numbers. If None, the random state will be
@@ -373,18 +387,18 @@ class LimeImageExplainer(object):
         self.feature_selection = feature_selection
         self.base = LimeBase(kernel_fn, verbose, random_state=self.random_state)
 
-    def explain_instance(self, image, classifier_fn, labels=(1,),
+    def interpret_instance(self, image, classifier_fn, labels=(1,),
                          hide_color=None,
                          num_features=100000, num_samples=1000,
                          batch_size=10,
                          distance_metric='cosine',
                          model_regressor=None
                          ):
-        """Generates explanations for a prediction.
+        """Generates interpretations for a prediction.
 
         First, we generate neighborhood data by randomly perturbing features
         from the instance (see __data_inverse). We then learn locally weighted
-        linear models on this neighborhood data to explain each of the classes
+        linear models on this neighborhood data to interpret each of the classes
         in an interpretable way (see lime_base.py).
 
         Args:
@@ -393,19 +407,19 @@ class LimeImageExplainer(object):
             classifier_fn: classifier prediction probability function, which
                 takes a numpy array and outputs prediction probabilities. For
                 ScikitClassifiers , this is classifier.predict_proba.
-            labels: iterable with labels to be explained.
+            labels: iterable with labels to be interpreted.
             hide_color: TODO
-            num_features: maximum number of features present in explanation
+            num_features: maximum number of features present in interpretation
             num_samples: size of the neighborhood to learn the linear model
             batch_size: TODO
             distance_metric: the distance metric to use for weights.
-            model_regressor: sklearn regressor to use in explanation. Defaults
+            model_regressor: sklearn regressor to use in interpretation. Defaults
                 to Ridge regression in LimeBase. Must have model_regressor.coef_
                 and 'sample_weight' as a parameter to model_regressor.fit()
 
         Returns:
-            An ImageExplanation object (see lime_image.py) with the corresponding
-            explanations.
+            An ImageInterpretation object (see lime_image.py) with the corresponding
+            interpretations.
         """
         if len(image.shape) == 2:
             image = gray2rgb(image)
@@ -455,15 +469,15 @@ class LimeImageExplainer(object):
             metric=distance_metric
         ).ravel()
 
-        ret_exp = ImageExplanation(image, segments)
+        interpretation_image = ImageInterpretation(image, segments)
         for label in top:
-            (ret_exp.intercept[label],
-             ret_exp.local_exp[label],
-             ret_exp.score, ret_exp.local_pred) = self.base.explain_instance_with_data(
+            (interpretation_image.intercept[label],
+             interpretation_image.local_weights[label],
+             interpretation_image.score, interpretation_image.local_pred) = self.base.interpret_instance_with_data(
                 data, labels, distances, label, num_features,
                 model_regressor=model_regressor,
                 feature_selection=self.feature_selection)
 
-        return ret_exp
+        return interpretation_image
 
     def data_labels(self,
                     image,
...
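Pulling the renamed pieces of lime_base.py together, a typical call sequence looks roughly like the sketch below. The classifier, the input image, and the import of `lime_base` are placeholders; `interpret_instance` and `get_image_and_mask` are used with the same argument pattern that appears in interpretation_algorithms.py above, but this is a hedged illustration, not a guaranteed-to-run test.

import numpy as np
from skimage.segmentation import mark_boundaries
# `lime_base` is assumed to be importable from the interpret core package in this commit.

def predict_fn(images):
    # Placeholder classifier: (N, H, W, 3) uint8 batch -> (N, 2) probabilities.
    return np.tile(np.array([0.7, 0.3]), (images.shape[0], 1))

image = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)

algo = lime_base.LimeImageInterpreter()
interpretation = algo.interpret_instance(image, predict_fn, labels=(0,), hide_color=0,
                                         num_samples=1000, batch_size=10)

# Overlay the most influential superpixels for label 0.
temp, mask = interpretation.get_image_and_mask(
    0, positive_only=False, hide_rest=False, num_features=5)
overlay = mark_boundaries(temp, mask)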
@@ -87,11 +87,11 @@ def precompute_normlime_weights(list_data_, predict_fn, num_samples=3000, batch_
     return compute_normlime_weights(fname_list, save_dir, num_samples)
 
 
-def save_one_lime_predict_and_kmean_labels(lime_exp_all_weights, image_pred_labels, cluster_labels, save_path):
+def save_one_lime_predict_and_kmean_labels(lime_all_weights, image_pred_labels, cluster_labels, save_path):
 
     lime_weights = {}
     for label in image_pred_labels:
-        lime_weights[label] = lime_exp_all_weights[label]
+        lime_weights[label] = lime_all_weights[label]
 
     for_normlime_weights = {
         'lime_weights': lime_weights, # a dict: class_label: (seg_label, weight)
@@ -145,15 +145,15 @@ def precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size, sav
     pred_label = pred_label[:top_k]
 
-    algo = lime_base.LimeImageExplainer()
-    explainer = algo.explain_instance(image_show[0], predict_fn, pred_label, 0,
+    algo = lime_base.LimeImageInterpreter()
+    interpreter = algo.interpret_instance(image_show[0], predict_fn, pred_label, 0,
                                       num_samples=num_samples, batch_size=batch_size)
 
     cluster_labels = kmeans_model.predict(
-        get_feature_for_kmeans(compute_features_for_kmeans(image_show).transpose((1, 2, 0)), explainer.segments)
+        get_feature_for_kmeans(compute_features_for_kmeans(image_show).transpose((1, 2, 0)), interpreter.segments)
     )
 
     save_one_lime_predict_and_kmean_labels(
-        explainer.local_exp, pred_label,
+        interpreter.local_weights, pred_label,
         cluster_labels,
         save_path
     )
...
@@ -17,19 +17,19 @@ import cv2
 import copy
 import os.path as osp
 import numpy as np
-from .core.explanation import Explanation
+from .core.interpretation import Interpretation
 from .core.normlime_base import precompute_normlime_weights
 
 
 def visualize(img_file,
               model,
               dataset=None,
-              explanation_type='lime',
+              algo='lime',
               num_samples=3000,
               batch_size=50,
               save_dir='./'):
     if model.status != 'Normal':
-        raise Exception('The explanation only can deal with the Normal model')
+        raise Exception('The interpretation only can deal with the Normal model')
     model.arrange_transforms(
         transforms=model.test_transforms, mode='test')
     tmp_transforms = copy.deepcopy(model.test_transforms)
@@ -37,48 +37,48 @@ def visualize(img_file,
     img = tmp_transforms(img_file)[0]
     img = np.around(img).astype('uint8')
     img = np.expand_dims(img, axis=0)
-    explaier = None
-    if explanation_type == 'lime':
-        explaier = get_lime_explaier(img, model, dataset, num_samples=num_samples, batch_size=batch_size)
-    elif explanation_type == 'normlime':
+    interpreter = None
+    if algo == 'lime':
+        interpreter = get_lime_interpreter(img, model, dataset, num_samples=num_samples, batch_size=batch_size)
+    elif algo == 'normlime':
         if dataset is None:
-            raise Exception('The dataset is None. Cannot implement this kind of explanation')
-        explaier = get_normlime_explaier(img, model, dataset,
+            raise Exception('The dataset is None. Cannot implement this kind of interpretation')
+        interpreter = get_normlime_interpreter(img, model, dataset,
                                          num_samples=num_samples, batch_size=batch_size,
                                          save_dir=save_dir)
     else:
-        raise Exception('The {} explanantion method is not supported yet!'.format(explanation_type))
+        raise Exception('The {} interpretation method is not supported yet!'.format(algo))
     img_name = osp.splitext(osp.split(img_file)[-1])[0]
-    explaier.explain(img, save_dir=save_dir)
+    interpreter.interpret(img, save_dir=save_dir)
 
 
-def get_lime_explaier(img, model, dataset, num_samples=3000, batch_size=50):
+def get_lime_interpreter(img, model, dataset, num_samples=3000, batch_size=50):
     def predict_func(image):
         image = image.astype('float32')
         for i in range(image.shape[0]):
             image[i] = cv2.cvtColor(image[i], cv2.COLOR_RGB2BGR)
         tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
         model.test_transforms.transforms = model.test_transforms.transforms[-2:]
-        out = model.explanation_predict(image)
+        out = model.interpretation_predict(image)
         model.test_transforms.transforms = tmp_transforms
         return out[0]
     labels_name = None
     if dataset is not None:
         labels_name = dataset.labels
-    explaier = Explanation('lime',
+    interpreter = Interpretation('lime',
                            predict_func,
                            labels_name,
                            num_samples=num_samples,
                            batch_size=batch_size)
-    return explaier
+    return interpreter
 
 
-def get_normlime_explaier(img, model, dataset, num_samples=3000, batch_size=50, save_dir='./'):
+def get_normlime_interpreter(img, model, dataset, num_samples=3000, batch_size=50, save_dir='./'):
     def precompute_predict_func(image):
         image = image.astype('float32')
         tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
         model.test_transforms.transforms = model.test_transforms.transforms[-2:]
-        out = model.explanation_predict(image)
+        out = model.interpretation_predict(image)
         model.test_transforms.transforms = tmp_transforms
         return out[0]
     def predict_func(image):
@@ -87,7 +87,7 @@ def get_normlime_explaier(img, model, dataset, num_samples=3000, batch_size=50,
             image[i] = cv2.cvtColor(image[i], cv2.COLOR_RGB2BGR)
         tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
         model.test_transforms.transforms = model.test_transforms.transforms[-2:]
-        out = model.explanation_predict(image)
+        out = model.interpretation_predict(image)
         model.test_transforms.transforms = tmp_transforms
         return out[0]
     labels_name = None
@@ -105,13 +105,13 @@ def get_normlime_explaier(img, model, dataset, num_samples=3000, batch_size=50,
                 num_samples=num_samples,
                 batch_size=batch_size,
                 save_dir=save_dir)
-    explaier = Explanation('normlime',
+    interpreter = Interpretation('normlime',
                            predict_func,
                            labels_name,
                            num_samples=num_samples,
                            batch_size=batch_size,
                            normlime_weights=npy_dir)
-    return explaier
+    return interpreter
 
 
 def precompute_for_normlime(predict_func, dataset, num_samples=3000, batch_size=50, save_dir='./'):
...
@@ -13,6 +13,6 @@
 # limitations under the License.
 
 from __future__ import absolute_import
-from .cv.models.explanation import visualize
+from .cv.models.interpret import visualize
 
 visualize = visualize.visualize
\ No newline at end of file
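With the re-export above, the user-facing entry point for this feature is the `visualize` function defined earlier in the commit. A hedged usage sketch follows; the model and dataset are left as placeholders since obtaining them is outside this diff, while the keyword arguments are the ones in the `visualize` signature shown above.

# Placeholders: how `model` and `dataset` are obtained is not part of this commit.
model = ...      # a trained classification model whose status is 'Normal'
dataset = ...    # a classification dataset exposing a `labels` list of class names

visualize('images/example.jpg',
          model,
          dataset,
          algo='lime',       # or 'normlime', which requires `dataset` and precomputes NormLIME weights
          num_samples=3000,  # number of LIME perturbation samples
          batch_size=50,     # batch size for the predict calls
          save_dir='./')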