提交 1b0d4d53 编写于 作者: S sunyanfang01

rename lime

上级 5a2ad684
......@@ -28,7 +28,7 @@ from . import seg
from . import cls
from . import slim
from . import tools
from . import explanation
from . import interpret
try:
import pycocotools
......
......@@ -275,7 +275,7 @@ class BaseClassifier(BaseAPI):
} for l in pred_label]
return res
def explanation_predict(self, images):
def interpretation_predict(self, images):
self.arrange_transforms(
transforms=self.test_transforms, mode='test')
new_imgs = []
......
......@@ -12,31 +12,31 @@
#See the License for the specific language governing permissions and
#limitations under the License.
from .explanation_algorithms import CAM, LIME, NormLIME
from .interpretation_algorithms import CAM, LIME, NormLIME
from .normlime_base import precompute_normlime_weights
class Explanation(object):
class Interpretation(object):
"""
Base class for all explanation algorithms.
Base class for all interpretation algorithms.
"""
def __init__(self, explanation_algorithm_name, predict_fn, label_names, **kwargs):
def __init__(self, interpretation_algorithm_name, predict_fn, label_names, **kwargs):
supported_algorithms = {
'cam': CAM,
'lime': LIME,
'normlime': NormLIME
}
self.algorithm_name = explanation_algorithm_name.lower()
self.algorithm_name = interpretation_algorithm_name.lower()
assert self.algorithm_name in supported_algorithms.keys()
self.predict_fn = predict_fn
# initialization for the explanation algorithm.
self.explain_algorithm = supported_algorithms[self.algorithm_name](
# initialization for the interpretation algorithm.
self.algorithm = supported_algorithms[self.algorithm_name](
self.predict_fn, label_names, **kwargs
)
def explain(self, data_, visualization=True, save_to_disk=True, save_dir='./tmp'):
def interpret(self, data_, visualization=True, save_to_disk=True, save_dir='./tmp'):
"""
Args:
......@@ -48,4 +48,4 @@ class Explanation(object):
Returns:
"""
return self.explain_algorithm.explain(data_, visualization, save_to_disk, save_dir)
return self.algorithm.interpret(data_, visualization, save_to_disk, save_dir)
......@@ -46,12 +46,13 @@ class CAM(object):
logit = result[0][0]
if abs(np.sum(logit) - 1.0) > 1e-4:
# softmax
logit = logit - np.max(logit)
exp_result = np.exp(logit)
probability = exp_result / np.sum(exp_result)
else:
probability = logit
# only explain top 1
# only interpret top 1
pred_label = np.argsort(probability)
pred_label = pred_label[-1:]
......@@ -71,7 +72,7 @@ class CAM(object):
print(f'predicted result: {ln} with probability {probability[pred_label[0]]:.3f}')
return feature_maps, fc_weights
def explain(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
def interpret(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
feature_maps, fc_weights = self.preparation_cam(data_)
cam = get_cam(self.image, feature_maps, fc_weights, self.predicted_label)
......@@ -123,7 +124,7 @@ class LIME(object):
self.predict_fn = predict_fn
self.labels = None
self.image = None
self.lime_explainer = None
self.lime_interpreter = None
self.label_names = label_names
def preparation_lime(self, data_):
......@@ -134,12 +135,13 @@ class LIME(object):
if abs(np.sum(result) - 1.0) > 1e-4:
# softmax
result = result - np.max(result)
exp_result = np.exp(result)
probability = exp_result / np.sum(exp_result)
else:
probability = result
# only explain top 1
# only interpret top 1
pred_label = np.argsort(probability)
pred_label = pred_label[-1:]
......@@ -156,14 +158,14 @@ class LIME(object):
print(f'predicted result: {ln} with probability {probability[pred_label[0]]:.3f}')
end = time.time()
algo = lime_base.LimeImageExplainer()
explainer = algo.explain_instance(self.image, self.predict_fn, self.labels, 0,
num_samples=self.num_samples, batch_size=self.batch_size)
self.lime_explainer = explainer
algo = lime_base.LimeImageInterpreter()
interpreter = algo.interpret_instance(self.image, self.predict_fn, self.labels, 0,
num_samples=self.num_samples, batch_size=self.batch_size)
self.lime_interpreter = interpreter
print('lime time: ', time.time() - end, 's.')
def explain(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
if self.lime_explainer is None:
def interpret(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
if self.lime_interpreter is None:
self.preparation_lime(data_)
if visualization or save_to_disk:
......@@ -187,13 +189,13 @@ class LIME(object):
axes[0].imshow(self.image)
axes[0].set_title(f"label {ln}, proba: {self.predicted_probability: .3f}")
axes[1].imshow(mark_boundaries(self.image, self.lime_explainer.segments))
axes[1].imshow(mark_boundaries(self.image, self.lime_interpreter.segments))
axes[1].set_title("superpixel segmentation")
# LIME visualization
for i, w in enumerate(weights_choices):
num_to_show = auto_choose_num_features_to_show(self.lime_explainer, l, w)
temp, mask = self.lime_explainer.get_image_and_mask(
num_to_show = auto_choose_num_features_to_show(self.lime_interpreter, l, w)
temp, mask = self.lime_interpreter.get_image_and_mask(
l, positive_only=False, hide_rest=False, num_features=num_to_show
)
axes[ncols + i].imshow(mark_boundaries(temp, mask))
......@@ -274,20 +276,20 @@ class NormLIME(object):
print('performing NormLIME operations ...')
cluster_labels = self.predict_cluster_labels(
compute_features_for_kmeans(image_show).transpose((1, 2, 0)), self._lime.lime_explainer.segments
compute_features_for_kmeans(image_show).transpose((1, 2, 0)), self._lime.lime_interpreter.segments
)
g_weights = self.predict_using_normlime_weights(self.labels, cluster_labels)
return g_weights
def explain(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
def interpret(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
if self.normlime_weights is None:
raise ValueError("Not find the correct precomputed NormLIME result. \n"
"\t Try to call compute_normlime_weights() first or load the correct path.")
g_weights = self.preparation_normlime(data_)
lime_weights = self._lime.lime_explainer.local_exp
lime_weights = self._lime.lime_interpreter.local_weights
if visualization or save_to_disk:
import matplotlib.pyplot as plt
......@@ -312,23 +314,23 @@ class NormLIME(object):
axes[0].imshow(self.image)
axes[0].set_title(f"label {ln}, proba: {self.predicted_probability: .3f}")
axes[1].imshow(mark_boundaries(self.image, self._lime.lime_explainer.segments))
axes[1].imshow(mark_boundaries(self.image, self._lime.lime_interpreter.segments))
axes[1].set_title("superpixel segmentation")
# LIME visualization
for i, w in enumerate(weights_choices):
num_to_show = auto_choose_num_features_to_show(self._lime.lime_explainer, l, w)
num_to_show = auto_choose_num_features_to_show(self._lime.lime_interpreter, l, w)
nums_to_show.append(num_to_show)
temp, mask = self._lime.lime_explainer.get_image_and_mask(
temp, mask = self._lime.lime_interpreter.get_image_and_mask(
l, positive_only=False, hide_rest=False, num_features=num_to_show
)
axes[ncols + i].imshow(mark_boundaries(temp, mask))
axes[ncols + i].set_title(f"LIME: first {num_to_show} superpixels")
# NormLIME visualization
self._lime.lime_explainer.local_exp = g_weights
self._lime.lime_interpreter.local_weights = g_weights
for i, num_to_show in enumerate(nums_to_show):
temp, mask = self._lime.lime_explainer.get_image_and_mask(
temp, mask = self._lime.lime_interpreter.get_image_and_mask(
l, positive_only=False, hide_rest=False, num_features=num_to_show
)
axes[ncols * 2 + i].imshow(mark_boundaries(temp, mask))
......@@ -336,15 +338,15 @@ class NormLIME(object):
# NormLIME*LIME visualization
combined_weights = combine_normlime_and_lime(lime_weights, g_weights)
self._lime.lime_explainer.local_exp = combined_weights
self._lime.lime_interpreter.local_weights = combined_weights
for i, num_to_show in enumerate(nums_to_show):
temp, mask = self._lime.lime_explainer.get_image_and_mask(
temp, mask = self._lime.lime_interpreter.get_image_and_mask(
l, positive_only=False, hide_rest=False, num_features=num_to_show
)
axes[ncols * 3 + i].imshow(mark_boundaries(temp, mask))
axes[ncols * 3 + i].set_title(f"Combined: first {num_to_show} superpixels")
self._lime.lime_explainer.local_exp = lime_weights
self._lime.lime_interpreter.local_weights = lime_weights
if save_to_disk and save_outdir is not None:
os.makedirs(save_outdir, exist_ok=True)
......@@ -354,9 +356,9 @@ class NormLIME(object):
plt.show()
def auto_choose_num_features_to_show(lime_explainer, label, percentage_to_show):
segments = lime_explainer.segments
lime_weights = lime_explainer.local_exp[label]
def auto_choose_num_features_to_show(lime_interpreter, label, percentage_to_show):
segments = lime_interpreter.segments
lime_weights = lime_interpreter.local_weights[label]
num_pixels_threshold_in_a_sp = segments.shape[0] * segments.shape[1] // len(np.unique(segments)) // 8
# l1 norm with filtered weights.
......@@ -381,7 +383,7 @@ def auto_choose_num_features_to_show(lime_explainer, label, percentage_to_show):
return 5
if n == 0:
return auto_choose_num_features_to_show(lime_explainer, label, percentage_to_show-0.1)
return auto_choose_num_features_to_show(lime_interpreter, label, percentage_to_show-0.1)
return n
......
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import print_function
"""
Copyright (c) 2016, Marco Tulio Correia Ribeiro
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
"""
The code in this file (lime_base.py) is modified from https://github.com/marcotcr/lime.
"""
import numpy as np
import scipy as sp
import sklearn
......@@ -88,7 +103,7 @@ class LimeBase(object):
return np.array(used_features)
def feature_selection(self, data, labels, weights, num_features, method):
"""Selects features for the model. see explain_instance_with_data to
"""Selects features for the model. see interpret_instance_with_data to
understand the parameters."""
if method == 'none':
return np.array(range(data.shape[1]))
......@@ -154,15 +169,15 @@ class LimeBase(object):
return self.feature_selection(data, labels, weights,
num_features, n_method)
def explain_instance_with_data(self,
neighborhood_data,
neighborhood_labels,
distances,
label,
num_features,
feature_selection='auto',
model_regressor=None):
"""Takes perturbed data, labels and distances, returns explanation.
def interpret_instance_with_data(self,
neighborhood_data,
neighborhood_labels,
distances,
label,
num_features,
feature_selection='auto',
model_regressor=None):
"""Takes perturbed data, labels and distances, returns interpretation.
Args:
neighborhood_data: perturbed data, 2d array. first element is
......@@ -170,8 +185,8 @@ class LimeBase(object):
neighborhood_labels: corresponding perturbed labels. should have as
many columns as the number of possible labels.
distances: distances to original data point.
label: label for which we want an explanation
num_features: maximum number of features in explanation
label: label for which we want an interpretation
num_features: maximum number of features in interpretation
feature_selection: how to select num_features. options are:
'forward_selection': iteratively add features to the model.
This is costly when num_features is high
......@@ -183,7 +198,7 @@ class LimeBase(object):
'none': uses all features, ignores num_features
'auto': uses forward_selection if num_features <= 6, and
'highest_weights' otherwise.
model_regressor: sklearn regressor to use in explanation.
model_regressor: sklearn regressor to use in interpretation.
Defaults to Ridge regression if None. Must have
model_regressor.coef_ and 'sample_weight' as a parameter
to model_regressor.fit()
......@@ -194,8 +209,8 @@ class LimeBase(object):
exp is a sorted list of tuples, where each tuple (x,y) corresponds
to the feature id (x) and the local weight (y). The list is sorted
by decreasing absolute value of y.
score is the R^2 value of the returned explanation
local_pred is the prediction of the explanation model on the original instance
score is the R^2 value of the returned interpretation
local_pred is the prediction of the interpretation model on the original instance
"""
weights = self.kernel_fn(distances)
......@@ -227,7 +242,7 @@ class LimeBase(object):
prediction_score, local_pred)
class ImageExplanation(object):
class ImageInterpretation(object):
def __init__(self, image, segments):
"""Init function.
......@@ -238,7 +253,7 @@ class ImageExplanation(object):
self.image = image
self.segments = segments
self.intercept = {}
self.local_exp = {}
self.local_weights = {}
self.local_pred = None
def get_image_and_mask(self, label, positive_only=True, negative_only=False, hide_rest=False,
......@@ -246,40 +261,40 @@ class ImageExplanation(object):
"""Init function.
Args:
label: label to explain
label: label to interpret
positive_only: if True, only take superpixels that positively contribute to
the prediction of the label.
negative_only: if True, only take superpixels that negatively contribute to
the prediction of the label. If false, and so is positive_only, then both
negativey and positively contributions will be taken.
Both can't be True at the same time
hide_rest: if True, make the non-explanation part of the return
hide_rest: if True, make the non-interpretation part of the return
image gray
num_features: number of superpixels to include in explanation
min_weight: minimum weight of the superpixels to include in explanation
num_features: number of superpixels to include in interpretation
min_weight: minimum weight of the superpixels to include in interpretation
Returns:
(image, mask), where image is a 3d numpy array and mask is a 2d
numpy array that can be used with
skimage.segmentation.mark_boundaries
"""
if label not in self.local_exp:
raise KeyError('Label not in explanation')
if label not in self.local_weights:
raise KeyError('Label not in interpretation')
if positive_only & negative_only:
raise ValueError("Positive_only and negative_only cannot be true at the same time.")
segments = self.segments
image = self.image
exp = self.local_exp[label]
local_weights_label = self.local_weights[label]
mask = np.zeros(segments.shape, segments.dtype)
if hide_rest:
temp = np.zeros(self.image.shape)
else:
temp = self.image.copy()
if positive_only:
fs = [x[0] for x in exp
fs = [x[0] for x in local_weights_label
if x[1] > 0 and x[1] > min_weight][:num_features]
if negative_only:
fs = [x[0] for x in exp
fs = [x[0] for x in local_weights_label
if x[1] < 0 and abs(x[1]) > min_weight][:num_features]
if positive_only or negative_only:
for f in fs:
......@@ -287,7 +302,7 @@ class ImageExplanation(object):
mask[segments == f] = 1
return temp, mask
else:
for f, w in exp[:num_features]:
for f, w in local_weights_label[:num_features]:
if np.abs(w) < min_weight:
continue
c = 0 if w < 0 else 1
......@@ -300,32 +315,31 @@ class ImageExplanation(object):
"""
Args:
label: label to explain
label: label to interpret
min_weight:
Returns:
image, is a 3d numpy array
"""
if label not in self.local_exp:
raise KeyError('Label not in explanation')
if label not in self.local_weights:
raise KeyError('Label not in interpretation')
from matplotlib import cm
segments = self.segments
image = self.image
exp = self.local_exp[label]
local_weights_label = self.local_weights[label]
temp = np.zeros_like(image)
weight_max = abs(exp[0][1])
exp = [(f, w/weight_max) for f, w in exp]
exp = sorted(exp, key=lambda x: x[1], reverse=True) # negatives are at last.
weight_max = abs(local_weights_label[0][1])
local_weights_label = [(f, w/weight_max) for f, w in local_weights_label]
local_weights_label = sorted(local_weights_label, key=lambda x: x[1], reverse=True) # negatives are at last.
cmaps = cm.get_cmap('Spectral')
# sigmoid_space = 1 / (1 + np.exp(-np.linspace(-20, 20, len(exp))))
colors = cmaps(np.linspace(0, 1, len(exp)))
colors = cmaps(np.linspace(0, 1, len(local_weights_label)))
colors = colors[:, :3]
for i, (f, w) in enumerate(exp):
for i, (f, w) in enumerate(local_weights_label):
if np.abs(w) < min_weight:
continue
temp[segments == f] = image[segments == f].copy()
......@@ -333,14 +347,14 @@ class ImageExplanation(object):
return temp
class LimeImageExplainer(object):
"""Explains predictions on Image (i.e. matrix) data.
class LimeImageInterpreter(object):
"""Interpres predictions on Image (i.e. matrix) data.
For numerical features, perturb them by sampling from a Normal(0,1) and
doing the inverse operation of mean-centering and scaling, according to the
means and stds in the training data. For categorical features, perturb by
sampling according to the training distribution, and making a binary
feature that is 1 when the value is the same as the instance being
explained."""
interpreted."""
def __init__(self, kernel_width=.25, kernel=None, verbose=False,
feature_selection='auto', random_state=None):
......@@ -355,7 +369,7 @@ class LimeImageExplainer(object):
verbose: if true, print local prediction values from linear model
feature_selection: feature selection method. can be
'forward_selection', 'lasso_path', 'none' or 'auto'.
See function 'explain_instance_with_data' in lime_base.py for
See function 'einterpret_instance_with_data' in lime_base.py for
details on what each of the options does.
random_state: an integer or numpy.RandomState that will be used to
generate random numbers. If None, the random state will be
......@@ -373,18 +387,18 @@ class LimeImageExplainer(object):
self.feature_selection = feature_selection
self.base = LimeBase(kernel_fn, verbose, random_state=self.random_state)
def explain_instance(self, image, classifier_fn, labels=(1,),
hide_color=None,
num_features=100000, num_samples=1000,
batch_size=10,
distance_metric='cosine',
model_regressor=None
):
"""Generates explanations for a prediction.
def interpret_instance(self, image, classifier_fn, labels=(1,),
hide_color=None,
num_features=100000, num_samples=1000,
batch_size=10,
distance_metric='cosine',
model_regressor=None
):
"""Generates interpretations for a prediction.
First, we generate neighborhood data by randomly perturbing features
from the instance (see __data_inverse). We then learn locally weighted
linear models on this neighborhood data to explain each of the classes
linear models on this neighborhood data to interpret each of the classes
in an interpretable way (see lime_base.py).
Args:
......@@ -393,19 +407,19 @@ class LimeImageExplainer(object):
classifier_fn: classifier prediction probability function, which
takes a numpy array and outputs prediction probabilities. For
ScikitClassifiers , this is classifier.predict_proba.
labels: iterable with labels to be explained.
labels: iterable with labels to be interpreted.
hide_color: TODO
num_features: maximum number of features present in explanation
num_features: maximum number of features present in interpretation
num_samples: size of the neighborhood to learn the linear model
batch_size: TODO
distance_metric: the distance metric to use for weights.
model_regressor: sklearn regressor to use in explanation. Defaults
model_regressor: sklearn regressor to use in interpretation. Defaults
to Ridge regression in LimeBase. Must have model_regressor.coef_
and 'sample_weight' as a parameter to model_regressor.fit()
Returns:
An ImageExplanation object (see lime_image.py) with the corresponding
explanations.
An ImageIinterpretation object (see lime_image.py) with the corresponding
interpretations.
"""
if len(image.shape) == 2:
image = gray2rgb(image)
......@@ -455,15 +469,15 @@ class LimeImageExplainer(object):
metric=distance_metric
).ravel()
ret_exp = ImageExplanation(image, segments)
interpretation_image = ImageInterpretation(image, segments)
for label in top:
(ret_exp.intercept[label],
ret_exp.local_exp[label],
ret_exp.score, ret_exp.local_pred) = self.base.explain_instance_with_data(
(interpretation_image.intercept[label],
interpretation_image.local_weights[label],
interpretation_image.score, interpretation_image.local_pred) = self.base.interpret_instance_with_data(
data, labels, distances, label, num_features,
model_regressor=model_regressor,
feature_selection=self.feature_selection)
return ret_exp
return interpretation_image
def data_labels(self,
image,
......
......@@ -87,11 +87,11 @@ def precompute_normlime_weights(list_data_, predict_fn, num_samples=3000, batch_
return compute_normlime_weights(fname_list, save_dir, num_samples)
def save_one_lime_predict_and_kmean_labels(lime_exp_all_weights, image_pred_labels, cluster_labels, save_path):
def save_one_lime_predict_and_kmean_labels(lime_all_weights, image_pred_labels, cluster_labels, save_path):
lime_weights = {}
for label in image_pred_labels:
lime_weights[label] = lime_exp_all_weights[label]
lime_weights[label] = lime_all_weights[label]
for_normlime_weights = {
'lime_weights': lime_weights, # a dict: class_label: (seg_label, weight)
......@@ -145,15 +145,15 @@ def precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size, sav
pred_label = pred_label[:top_k]
algo = lime_base.LimeImageExplainer()
explainer = algo.explain_instance(image_show[0], predict_fn, pred_label, 0,
algo = lime_base.LimeImageInterpreter()
interpreter = algo.interpret_instance(image_show[0], predict_fn, pred_label, 0,
num_samples=num_samples, batch_size=batch_size)
cluster_labels = kmeans_model.predict(
get_feature_for_kmeans(compute_features_for_kmeans(image_show).transpose((1, 2, 0)), explainer.segments)
get_feature_for_kmeans(compute_features_for_kmeans(image_show).transpose((1, 2, 0)), interpreter.segments)
)
save_one_lime_predict_and_kmean_labels(
explainer.local_exp, pred_label,
interpreter.local_weights, pred_label,
cluster_labels,
save_path
)
......
......@@ -17,19 +17,19 @@ import cv2
import copy
import os.path as osp
import numpy as np
from .core.explanation import Explanation
from .core.interpretation import Interpretation
from .core.normlime_base import precompute_normlime_weights
def visualize(img_file,
model,
dataset=None,
explanation_type='lime',
algo='lime',
num_samples=3000,
batch_size=50,
save_dir='./'):
if model.status != 'Normal':
raise Exception('The explanation only can deal with the Normal model')
raise Exception('The interpretation only can deal with the Normal model')
model.arrange_transforms(
transforms=model.test_transforms, mode='test')
tmp_transforms = copy.deepcopy(model.test_transforms)
......@@ -37,48 +37,48 @@ def visualize(img_file,
img = tmp_transforms(img_file)[0]
img = np.around(img).astype('uint8')
img = np.expand_dims(img, axis=0)
explaier = None
if explanation_type == 'lime':
explaier = get_lime_explaier(img, model, dataset, num_samples=num_samples, batch_size=batch_size)
elif explanation_type == 'normlime':
interpreter = None
if algo == 'lime':
interpreter = get_lime_interpreter(img, model, dataset, num_samples=num_samples, batch_size=batch_size)
elif algo == 'normlime':
if dataset is None:
raise Exception('The dataset is None. Cannot implement this kind of explanation')
explaier = get_normlime_explaier(img, model, dataset,
raise Exception('The dataset is None. Cannot implement this kind of interpretation')
interpreter = get_normlime_interpreter(img, model, dataset,
num_samples=num_samples, batch_size=batch_size,
save_dir=save_dir)
else:
raise Exception('The {} explanantion method is not supported yet!'.format(explanation_type))
raise Exception('The {} interpretation method is not supported yet!'.format(algo))
img_name = osp.splitext(osp.split(img_file)[-1])[0]
explaier.explain(img, save_dir=save_dir)
interpreter.interpret(img, save_dir=save_dir)
def get_lime_explaier(img, model, dataset, num_samples=3000, batch_size=50):
def get_lime_interpreter(img, model, dataset, num_samples=3000, batch_size=50):
def predict_func(image):
image = image.astype('float32')
for i in range(image.shape[0]):
image[i] = cv2.cvtColor(image[i], cv2.COLOR_RGB2BGR)
tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
model.test_transforms.transforms = model.test_transforms.transforms[-2:]
out = model.explanation_predict(image)
out = model.interpretation_predict(image)
model.test_transforms.transforms = tmp_transforms
return out[0]
labels_name = None
if dataset is not None:
labels_name = dataset.labels
explaier = Explanation('lime',
interpreter = Interpretation('lime',
predict_func,
labels_name,
num_samples=num_samples,
batch_size=batch_size)
return explaier
return interpreter
def get_normlime_explaier(img, model, dataset, num_samples=3000, batch_size=50, save_dir='./'):
def get_normlime_interpreter(img, model, dataset, num_samples=3000, batch_size=50, save_dir='./'):
def precompute_predict_func(image):
image = image.astype('float32')
tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
model.test_transforms.transforms = model.test_transforms.transforms[-2:]
out = model.explanation_predict(image)
out = model.interpretation_predict(image)
model.test_transforms.transforms = tmp_transforms
return out[0]
def predict_func(image):
......@@ -87,7 +87,7 @@ def get_normlime_explaier(img, model, dataset, num_samples=3000, batch_size=50,
image[i] = cv2.cvtColor(image[i], cv2.COLOR_RGB2BGR)
tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
model.test_transforms.transforms = model.test_transforms.transforms[-2:]
out = model.explanation_predict(image)
out = model.interpretation_predict(image)
model.test_transforms.transforms = tmp_transforms
return out[0]
labels_name = None
......@@ -105,13 +105,13 @@ def get_normlime_explaier(img, model, dataset, num_samples=3000, batch_size=50,
num_samples=num_samples,
batch_size=batch_size,
save_dir=save_dir)
explaier = Explanation('normlime',
interpreter = Interpretation('normlime',
predict_func,
labels_name,
num_samples=num_samples,
batch_size=batch_size,
normlime_weights=npy_dir)
return explaier
return interpreter
def precompute_for_normlime(predict_func, dataset, num_samples=3000, batch_size=50, save_dir='./'):
......
......@@ -13,6 +13,6 @@
# limitations under the License.
from __future__ import absolute_import
from .cv.models.explanation import visualize
from .cv.models.interpret import visualize
visualize = visualize.visualize
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册