Commit 857300d4 authored by S seven

interpret: update normlime

Parent f001960b
@@ -20,6 +20,7 @@ import numpy as np
 from paddle.fluid.param_attr import ParamAttr
 from paddlex.interpret.as_data_reader.readers import preprocess_image


 def gen_user_home():
     if "HOME" in os.environ:
         home_path = os.environ["HOME"]
@@ -34,10 +35,20 @@ def paddle_get_fc_weights(var_name="fc_0.w_0"):

 def paddle_resize(extracted_features, outsize):
-    resized_features = fluid.layers.resize_bilinear(extracted_features, outsize)
+    resized_features = fluid.layers.resize_bilinear(extracted_features,
+                                                    outsize)
     return resized_features


+def get_precomputed_normlime_weights():
+    root_path = gen_user_home()
+    root_path = osp.join(root_path, '.paddlex')
+    h_pre_models = osp.join(root_path, "pre_models")
+    normlime_weights_file = osp.join(
+        h_pre_models, "normlime_weights_imagenet_resnet50vc.npy")
+    return np.load(normlime_weights_file, allow_pickle=True).item()
+
+
 def compute_features_for_kmeans(data_content):
     root_path = gen_user_home()
     root_path = osp.join(root_path, '.paddlex')
@@ -47,6 +58,7 @@ def compute_features_for_kmeans(data_content):
             os.makedirs(root_path)
         url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz"
         pdx.utils.download_and_decompress(url, path=root_path)

     def conv_bn_layer(input,
                       num_filters,
                       filter_size,
@@ -55,7 +67,7 @@ def compute_features_for_kmeans(data_content):
                       act=None,
                       name=None,
                       is_test=True,
-                      global_name=''):
+                      global_name='for_kmeans_'):
         conv = fluid.layers.conv2d(
             input=input,
             num_filters=num_filters,
@@ -79,14 +91,14 @@ def compute_features_for_kmeans(data_content):
             bias_attr=ParamAttr(global_name + bn_name + '_offset'),
             moving_mean_name=global_name + bn_name + '_mean',
             moving_variance_name=global_name + bn_name + '_variance',
-            use_global_stats=is_test
-        )
+            use_global_stats=is_test)

     startup_prog = fluid.default_startup_program().clone(for_test=True)
     prog = fluid.Program()
     with fluid.program_guard(prog, startup_prog):
         with fluid.unique_name.guard():
-            image_op = fluid.data(name='image', shape=[None, 3, 224, 224], dtype='float32')
+            image_op = fluid.data(
+                name='image', shape=[None, 3, 224, 224], dtype='float32')

             conv = conv_bn_layer(
                 input=image_op,
@@ -110,7 +122,8 @@ def compute_features_for_kmeans(data_content):
                 act='relu',
                 name='conv1_3')
             extracted_features = conv
-            resized_features = fluid.layers.resize_bilinear(extracted_features, image_op.shape[2:])
+            resized_features = fluid.layers.resize_bilinear(extracted_features,
                                                             image_op.shape[2:])

     gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
     place = fluid.CUDAPlace(gpu_id)
@@ -119,7 +132,10 @@ def compute_features_for_kmeans(data_content):
     exe.run(startup_prog)
     fluid.io.load_persistables(exe, h_pre_models, prog)
-    images = preprocess_image(data_content)  # transpose to [N, 3, H, W], scaled to [0.0, 1.0]
-    result = exe.run(prog, fetch_list=[resized_features], feed={'image': images})
+    images = preprocess_image(
+        data_content)  # transpose to [N, 3, H, W], scaled to [0.0, 1.0]
+    result = exe.run(prog,
+                     fetch_list=[resized_features],
+                     feed={'image': images})
     return result[0][0]
@@ -20,12 +20,10 @@ class Interpretation(object):
     """
     Base class for all interpretation algorithms.
     """
-    def __init__(self, interpretation_algorithm_name, predict_fn, label_names, **kwargs):
-        supported_algorithms = {
-            'cam': CAM,
-            'lime': LIME,
-            'normlime': NormLIME
-        }
+    def __init__(self, interpretation_algorithm_name, predict_fn, label_names,
+                 **kwargs):
+        supported_algorithms = {'cam': CAM, 'lime': LIME, 'normlime': NormLIME}

         self.algorithm_name = interpretation_algorithm_name.lower()
         assert self.algorithm_name in supported_algorithms.keys()
@@ -33,10 +31,13 @@ class Interpretation(object):
         # initialization for the interpretation algorithm.
         self.algorithm = supported_algorithms[self.algorithm_name](
-            self.predict_fn, label_names, **kwargs
-        )
+            self.predict_fn, label_names, **kwargs)

-    def interpret(self, data_, visualization=True, save_to_disk=True, save_dir='./tmp'):
+    def interpret(self,
+                  data_,
+                  visualization=True,
+                  save_to_disk=True,
+                  save_dir='./tmp'):
         """
         Args:
@@ -48,4 +49,5 @@
         Returns:
         """
-        return self.algorithm.interpret(data_, visualization, save_to_disk, save_dir)
+        return self.algorithm.interpret(data_, visualization, save_to_disk,
+                                        save_dir)
@@ -27,7 +27,6 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
The code in this file (lime_base.py) is modified from https://github.com/marcotcr/lime.
"""
 import numpy as np
 import scipy as sp
@@ -39,10 +38,8 @@ import paddlex.utils.logging as logging
 class LimeBase(object):
     """Class for learning a locally linear sparse model from perturbed data"""
-    def __init__(self,
-                 kernel_fn,
-                 verbose=False,
-                 random_state=None):
+    def __init__(self, kernel_fn, verbose=False, random_state=None):
         """Init function
         Args:
@@ -72,15 +69,14 @@ class LimeBase(object):
         """
         from sklearn.linear_model import lars_path
         x_vector = weighted_data
-        alphas, _, coefs = lars_path(x_vector,
-                                     weighted_labels,
-                                     method='lasso',
-                                     verbose=False)
+        alphas, _, coefs = lars_path(
+            x_vector, weighted_labels, method='lasso', verbose=False)
         return alphas, coefs

     def forward_selection(self, data, labels, weights, num_features):
         """Iteratively adds features to the model"""
-        clf = Ridge(alpha=0, fit_intercept=True, random_state=self.random_state)
+        clf = Ridge(
+            alpha=0, fit_intercept=True, random_state=self.random_state)
         used_features = []
         for _ in range(min(num_features, data.shape[1])):
             max_ = -100000000
@@ -88,11 +84,13 @@ class LimeBase(object):
             for feature in range(data.shape[1]):
                 if feature in used_features:
                     continue
-                clf.fit(data[:, used_features + [feature]], labels,
+                clf.fit(data[:, used_features + [feature]],
+                        labels,
                         sample_weight=weights)
-                score = clf.score(data[:, used_features + [feature]],
-                                  labels,
-                                  sample_weight=weights)
+                score = clf.score(
+                    data[:, used_features + [feature]],
+                    labels,
+                    sample_weight=weights)
                 if score > max_:
                     best = feature
                     max_ = score
@@ -108,8 +106,8 @@ class LimeBase(object):
         elif method == 'forward_selection':
             return self.forward_selection(data, labels, weights, num_features)
         elif method == 'highest_weights':
-            clf = Ridge(alpha=0.01, fit_intercept=True,
-                        random_state=self.random_state)
+            clf = Ridge(
+                alpha=0.01, fit_intercept=True, random_state=self.random_state)
             clf.fit(data, labels, sample_weight=weights)
             coef = clf.coef_
@@ -125,7 +123,8 @@ class LimeBase(object):
                     nnz_indexes = argsort_data[::-1]
                     indices = weighted_data.indices[nnz_indexes]
                     num_to_pad = num_features - sdata
-                    indices = np.concatenate((indices, np.zeros(num_to_pad, dtype=indices.dtype)))
+                    indices = np.concatenate((indices, np.zeros(
+                        num_to_pad, dtype=indices.dtype)))
                     indices_set = set(indices)
                     pad_counter = 0
                     for i in range(data.shape[1]):
@@ -135,7 +134,8 @@ class LimeBase(object):
                             if pad_counter >= num_to_pad:
                                 break
                 else:
-                    nnz_indexes = argsort_data[sdata - num_features:sdata][::-1]
+                    nnz_indexes = argsort_data[sdata - num_features:sdata][::
                                                                            -1]
                     indices = weighted_data.indices[nnz_indexes]
                 return indices
             else:
@@ -146,13 +146,13 @@ class LimeBase(object):
                                          reverse=True)
                 return np.array([x[0] for x in feature_weights[:num_features]])
         elif method == 'lasso_path':
-            weighted_data = ((data - np.average(data, axis=0, weights=weights))
-                             * np.sqrt(weights[:, np.newaxis]))
-            weighted_labels = ((labels - np.average(labels, weights=weights))
-                               * np.sqrt(weights))
+            weighted_data = ((data - np.average(
+                data, axis=0, weights=weights)) *
+                             np.sqrt(weights[:, np.newaxis]))
+            weighted_labels = ((labels - np.average(
+                labels, weights=weights)) * np.sqrt(weights))
             nonzero = range(weighted_data.shape[1])
-            _, coefs = self.generate_lars_path(weighted_data,
-                                               weighted_labels)
+            _, coefs = self.generate_lars_path(weighted_data, weighted_labels)
             for i in range(len(coefs.T) - 1, 0, -1):
                 nonzero = coefs.T[i].nonzero()[0]
                 if len(nonzero) <= num_features:
@@ -164,8 +164,8 @@ class LimeBase(object):
                 n_method = 'forward_selection'
             else:
                 n_method = 'highest_weights'
-            return self.feature_selection(data, labels, weights,
-                                          num_features, n_method)
+            return self.feature_selection(data, labels, weights, num_features,
+                                          n_method)

     def interpret_instance_with_data(self,
                                      neighborhood_data,
@@ -214,30 +214,31 @@ class LimeBase(object):
         weights = self.kernel_fn(distances)
         labels_column = neighborhood_labels[:, label]
-        used_features = self.feature_selection(neighborhood_data,
-                                               labels_column,
-                                               weights,
-                                               num_features,
-                                               feature_selection)
+        used_features = self.feature_selection(neighborhood_data,
+                                               labels_column, weights,
+                                               num_features, feature_selection)

         if model_regressor is None:
-            model_regressor = Ridge(alpha=1, fit_intercept=True,
-                                    random_state=self.random_state)
+            model_regressor = Ridge(
+                alpha=1, fit_intercept=True, random_state=self.random_state)
         easy_model = model_regressor
-        easy_model.fit(neighborhood_data[:, used_features],
-                       labels_column, sample_weight=weights)
+        easy_model.fit(neighborhood_data[:, used_features],
+                       labels_column,
+                       sample_weight=weights)
         prediction_score = easy_model.score(
             neighborhood_data[:, used_features],
-            labels_column, sample_weight=weights)
+            labels_column,
+            sample_weight=weights)

-        local_pred = easy_model.predict(neighborhood_data[0, used_features].reshape(1, -1))
+        local_pred = easy_model.predict(neighborhood_data[0, used_features]
                                         .reshape(1, -1))

         if self.verbose:
             logging.info('Intercept' + str(easy_model.intercept_))
             logging.info('Prediction_local' + str(local_pred))
             logging.info('Right:' + str(neighborhood_labels[0, label]))
-        return (easy_model.intercept_,
-                sorted(zip(used_features, easy_model.coef_),
-                       key=lambda x: np.abs(x[1]), reverse=True),
-                prediction_score, local_pred)
+        return (easy_model.intercept_, sorted(
+            zip(used_features, easy_model.coef_),
+            key=lambda x: np.abs(x[1]),
+            reverse=True), prediction_score, local_pred)


 class ImageInterpretation(object):
@@ -254,8 +255,13 @@ class ImageInterpretation(object):
         self.local_weights = {}
         self.local_pred = None

-    def get_image_and_mask(self, label, positive_only=True, negative_only=False, hide_rest=False,
-                           num_features=5, min_weight=0.):
+    def get_image_and_mask(self,
+                           label,
+                           positive_only=True,
+                           negative_only=False,
+                           hide_rest=False,
+                           num_features=5,
+                           min_weight=0.):
         """Init function.
         Args:
@@ -279,7 +285,9 @@ class ImageInterpretation(object):
         if label not in self.local_weights:
             raise KeyError('Label not in interpretation')
         if positive_only & negative_only:
-            raise ValueError("Positive_only and negative_only cannot be true at the same time.")
+            raise ValueError(
+                "Positive_only and negative_only cannot be true at the same time."
+            )
         segments = self.segments
         image = self.image
         local_weights_label = self.local_weights[label]
@@ -289,14 +297,20 @@ class ImageInterpretation(object):
         else:
             temp = self.image.copy()
         if positive_only:
-            fs = [x[0] for x in local_weights_label
-                  if x[1] > 0 and x[1] > min_weight][:num_features]
+            fs = [
+                x[0] for x in local_weights_label
+                if x[1] > 0 and x[1] > min_weight
+            ][:num_features]
         if negative_only:
-            fs = [x[0] for x in local_weights_label
-                  if x[1] < 0 and abs(x[1]) > min_weight][:num_features]
+            fs = [
+                x[0] for x in local_weights_label
+                if x[1] < 0 and abs(x[1]) > min_weight
+            ][:num_features]
         if positive_only or negative_only:
-            c = 1 if positive_only else 0
             for f in fs:
-                temp[segments == f] = image[segments == f].copy()
+                temp[segments == f] = [0, 255, 0]
+                # temp[segments == f, c] = np.max(image)
                 mask[segments == f] = 1
             return temp, mask
         else:
@@ -330,8 +344,11 @@ class ImageInterpretation(object):
         temp = np.zeros_like(image)
         weight_max = abs(local_weights_label[0][1])
-        local_weights_label = [(f, w/weight_max) for f, w in local_weights_label]
-        local_weights_label = sorted(local_weights_label, key=lambda x: x[1], reverse=True)  # negatives are at last.
+        local_weights_label = [(f, w / weight_max)
+                               for f, w in local_weights_label]
+        local_weights_label = sorted(
+            local_weights_label, key=lambda x: x[1],
+            reverse=True)  # negatives are at last.
         cmaps = cm.get_cmap('Spectral')
         colors = cmaps(np.linspace(0, 1, len(local_weights_label)))
@@ -354,8 +371,12 @@ class LimeImageInterpreter(object):
         feature that is 1 when the value is the same as the instance being
         interpreted."""

-    def __init__(self, kernel_width=.25, kernel=None, verbose=False,
-                 feature_selection='auto', random_state=None):
+    def __init__(self,
+                 kernel_width=.25,
+                 kernel=None,
+                 verbose=False,
+                 feature_selection='auto',
+                 random_state=None):
         """Init function.
         Args:
@@ -377,22 +398,27 @@ class LimeImageInterpreter(object):
         kernel_width = float(kernel_width)
         if kernel is None:

             def kernel(d, kernel_width):
-                return np.sqrt(np.exp(-(d ** 2) / kernel_width ** 2))
+                return np.sqrt(np.exp(-(d**2) / kernel_width**2))

         kernel_fn = partial(kernel, kernel_width=kernel_width)

         self.random_state = check_random_state(random_state)
         self.feature_selection = feature_selection
-        self.base = LimeBase(kernel_fn, verbose, random_state=self.random_state)
+        self.base = LimeBase(
+            kernel_fn, verbose, random_state=self.random_state)

-    def interpret_instance(self, image, classifier_fn, labels=(1,),
+    def interpret_instance(self,
+                           image,
+                           classifier_fn,
+                           labels=(1, ),
                            hide_color=None,
-                           num_features=100000, num_samples=1000,
+                           num_features=100000,
+                           num_samples=1000,
                            batch_size=10,
                            distance_metric='cosine',
-                           model_regressor=None
-                           ):
+                           model_regressor=None):
         """Generates interpretations for a prediction.

         First, we generate neighborhood data by randomly perturbing features
@@ -435,6 +461,7 @@ class LimeImageInterpreter(object):
         self.segments = segments

         fudged_image = image.copy()
+        # global_mean = np.mean(image, (0, 1))
         if hide_color is None:
             # if no hide_color, use the mean
             for x in np.unique(segments):
@@ -461,24 +488,30 @@ class LimeImageInterpreter(object):
             top = labels

-        data, labels = self.data_labels(image, fudged_image, segments,
-                                        classifier_fn, num_samples,
-                                        batch_size=batch_size)
+        data, labels = self.data_labels(
+            image,
+            fudged_image,
+            segments,
+            classifier_fn,
+            num_samples,
+            batch_size=batch_size)

         distances = sklearn.metrics.pairwise_distances(
-            data,
-            data[0].reshape(1, -1),
-            metric=distance_metric
-        ).ravel()
+            data, data[0].reshape(1, -1), metric=distance_metric).ravel()

         interpretation_image = ImageInterpretation(image, segments)
         for label in top:
             (interpretation_image.intercept[label],
              interpretation_image.local_weights[label],
-             interpretation_image.score, interpretation_image.local_pred) = self.base.interpret_instance_with_data(
-                data, labels, distances, label, num_features,
-                model_regressor=model_regressor,
-                feature_selection=self.feature_selection)
+             interpretation_image.score, interpretation_image.local_pred
+             ) = self.base.interpret_instance_with_data(
+                 data,
+                 labels,
+                 distances,
+                 label,
+                 num_features,
+                 model_regressor=model_regressor,
+                 feature_selection=self.feature_selection)

         return interpretation_image

     def data_labels(self,
...
@@ -16,6 +16,7 @@ import os
 import os.path as osp
 import numpy as np
 import glob
+import tqdm

 from paddlex.interpret.as_data_reader.readers import read_image
 import paddlex.utils.logging as logging
@@ -38,18 +39,24 @@ def combine_normlime_and_lime(lime_weights, g_weights):
     for y in pred_labels:
         normlized_lime_weights_y = lime_weights[y]
-        lime_weights_dict = {tuple_w[0]: tuple_w[1] for tuple_w in normlized_lime_weights_y}
+        lime_weights_dict = {
+            tuple_w[0]: tuple_w[1]
+            for tuple_w in normlized_lime_weights_y
+        }

         normlized_g_weight_y = g_weights[y]
-        normlime_weights_dict = {tuple_w[0]: tuple_w[1] for tuple_w in normlized_g_weight_y}
+        normlime_weights_dict = {
+            tuple_w[0]: tuple_w[1]
+            for tuple_w in normlized_g_weight_y
+        }

         combined_weights[y] = [
             (seg_k, lime_weights_dict[seg_k] * normlime_weights_dict[seg_k])
             for seg_k in lime_weights_dict.keys()
         ]
-        combined_weights[y] = sorted(combined_weights[y],
-                                     key=lambda x: np.abs(x[1]), reverse=True)
+        combined_weights[y] = sorted(
+            combined_weights[y], key=lambda x: np.abs(x[1]), reverse=True)

     return combined_weights
@@ -67,7 +74,8 @@ def centroid_using_superpixels(features, segments):
     regions = regionprops(segments + 1)
     one_list = np.zeros((len(np.unique(segments)), features.shape[2]))
     for i, r in enumerate(regions):
-        one_list[i] = features[int(r.centroid[0] + 0.5), int(r.centroid[1] + 0.5), :]
+        one_list[i] = features[int(r.centroid[0] + 0.5), int(r.centroid[1] +
                                                              0.5), :]
     return one_list
@@ -80,30 +88,39 @@ def get_feature_for_kmeans(feature_map, segments):
     return x


-def precompute_normlime_weights(list_data_, predict_fn, num_samples=3000, batch_size=50, save_dir='./tmp'):
+def precompute_normlime_weights(list_data_,
+                                predict_fn,
+                                num_samples=3000,
+                                batch_size=50,
+                                save_dir='./tmp'):
     # save lime weights and kmeans cluster labels
-    precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size, save_dir)
+    precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size,
+                            save_dir)

     # load precomputed results, compute normlime weights and save.
-    fname_list = glob.glob(os.path.join(save_dir, 'lime_weights_s{}*.npy'.format(num_samples)))
+    fname_list = glob.glob(
+        os.path.join(save_dir, 'lime_weights_s{}*.npy'.format(num_samples)))
     return compute_normlime_weights(fname_list, save_dir, num_samples)


-def save_one_lime_predict_and_kmean_labels(lime_all_weights, image_pred_labels, cluster_labels, save_path):
+def save_one_lime_predict_and_kmean_labels(lime_all_weights, image_pred_labels,
+                                           cluster_labels, save_path):
     lime_weights = {}
     for label in image_pred_labels:
         lime_weights[label] = lime_all_weights[label]
     for_normlime_weights = {
-        'lime_weights': lime_weights,  # a dict: class_label: (seg_label, weight)
+        'lime_weights':
+        lime_weights,  # a dict: class_label: (seg_label, weight)
         'cluster': cluster_labels  # a list with segments as indices.
     }
     np.save(save_path, for_normlime_weights)


-def precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size, save_dir):
+def precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size,
+                            save_dir):
     root_path = gen_user_home()
     root_path = osp.join(root_path, '.paddlex')
     h_pre_models = osp.join(root_path, "pre_models")
@@ -117,17 +134,24 @@ def precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size, sav
     for data_index, each_data_ in enumerate(list_data_):
         if isinstance(each_data_, str):
-            save_path = "lime_weights_s{}_{}.npy".format(num_samples, each_data_.split('/')[-1].split('.')[0])
+            save_path = "lime_weights_s{}_{}.npy".format(
+                num_samples, each_data_.split('/')[-1].split('.')[0])
             save_path = os.path.join(save_dir, save_path)
         else:
-            save_path = "lime_weights_s{}_{}.npy".format(num_samples, data_index)
+            save_path = "lime_weights_s{}_{}.npy".format(num_samples,
                                                          data_index)
             save_path = os.path.join(save_dir, save_path)

         if os.path.exists(save_path):
-            logging.info(save_path + ' exists, not computing this one.', use_color=True)
+            logging.info(
+                save_path + ' exists, not computing this one.', use_color=True)
             continue

-        img_file_name = each_data_ if isinstance(each_data_, str) else data_index
-        logging.info('processing '+ img_file_name + ' [{}/{}]'.format(data_index, len(list_data_)), use_color=True)
+        img_file_name = each_data_ if isinstance(each_data_,
                                                  str) else data_index
+        logging.info(
+            'processing ' + img_file_name + ' [{}/{}]'.format(data_index,
                                                               len(list_data_)),
+            use_color=True)

         image_show = read_image(each_data_)
         result = predict_fn(image_show)
@@ -156,32 +180,38 @@ def precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size, sav
         pred_label = pred_label[:top_k]

         algo = lime_base.LimeImageInterpreter()
-        interpreter = algo.interpret_instance(image_show[0], predict_fn, pred_label, 0,
-                                              num_samples=num_samples, batch_size=batch_size)
-        X = get_feature_for_kmeans(compute_features_for_kmeans(image_show).transpose((1, 2, 0)), interpreter.segments)
+        interpreter = algo.interpret_instance(
+            image_show[0],
+            predict_fn,
+            pred_label,
+            0,
+            num_samples=num_samples,
+            batch_size=batch_size)
+        X = get_feature_for_kmeans(
+            compute_features_for_kmeans(image_show).transpose((1, 2, 0)),
+            interpreter.segments)

         try:
             cluster_labels = kmeans_model.predict(X)
         except AttributeError:
             from sklearn.metrics import pairwise_distances_argmin_min
-            cluster_labels, _ = pairwise_distances_argmin_min(X, kmeans_model.cluster_centers_)
+            cluster_labels, _ = pairwise_distances_argmin_min(
+                X, kmeans_model.cluster_centers_)
         save_one_lime_predict_and_kmean_labels(
-            interpreter.local_weights, pred_label,
-            cluster_labels,
-            save_path
-        )
+            interpreter.local_weights, pred_label, cluster_labels, save_path)


 def compute_normlime_weights(a_list_lime_fnames, save_dir, lime_num_samples):
     normlime_weights_all_labels = {}
     for f in a_list_lime_fnames:
         try:
             lime_weights_and_cluster = np.load(f, allow_pickle=True).item()
             lime_weights = lime_weights_and_cluster['lime_weights']
             cluster = lime_weights_and_cluster['cluster']
         except:
-            logging.info('When loading precomputed LIME result, skipping' + str(f))
+            logging.info('When loading precomputed LIME result, skipping' +
                          str(f))
             continue
         logging.info('Loading precomputed LIME result,' + str(f))
         pred_labels = lime_weights.keys()
@@ -203,10 +233,12 @@ def compute_normlime_weights(a_list_lime_fnames, save_dir, lime_num_samples):
     for y in normlime_weights_all_labels:
         normlime_weights = normlime_weights_all_labels.get(y, {})
         for k in normlime_weights:
-            normlime_weights[k] = sum(normlime_weights[k]) / len(normlime_weights[k])
+            normlime_weights[k] = sum(normlime_weights[k]) / len(
+                normlime_weights[k])

     # check normlime
-    if len(normlime_weights_all_labels.keys()) < max(normlime_weights_all_labels.keys()) + 1:
+    if len(normlime_weights_all_labels.keys()) < max(
+            normlime_weights_all_labels.keys()) + 1:
         logging.info(
             "\n" + \
             "Warning: !!! \n" + \
@@ -218,17 +250,102 @@ def compute_normlime_weights(a_list_lime_fnames, save_dir, lime_num_samples):
         )

     n = 0
-    f_out = 'normlime_weights_s{}_samples_{}-{}.npy'.format(lime_num_samples, len(a_list_lime_fnames), n)
-    while os.path.exists(
-            os.path.join(save_dir, f_out)
-    ):
+    f_out = 'normlime_weights_s{}_samples_{}-{}.npy'.format(
+        lime_num_samples, len(a_list_lime_fnames), n)
+    while os.path.exists(os.path.join(save_dir, f_out)):
         n += 1
-        f_out = 'normlime_weights_s{}_samples_{}-{}.npy'.format(lime_num_samples, len(a_list_lime_fnames), n)
+        f_out = 'normlime_weights_s{}_samples_{}-{}.npy'.format(
+            lime_num_samples, len(a_list_lime_fnames), n)
         continue

-    np.save(
-        os.path.join(save_dir, f_out),
-        normlime_weights_all_labels
-    )
+    np.save(os.path.join(save_dir, f_out), normlime_weights_all_labels)
     return os.path.join(save_dir, f_out)


+def precompute_global_classifier(dataset,
+                                 predict_fn,
+                                 save_path,
+                                 batch_size=50,
+                                 max_num_samples=1000):
+    from sklearn.linear_model import LogisticRegression
+
+    root_path = gen_user_home()
+    root_path = osp.join(root_path, '.paddlex')
+    h_pre_models = osp.join(root_path, "pre_models")
+    if not osp.exists(h_pre_models):
+        if not osp.exists(root_path):
+            os.makedirs(root_path)
+        url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz"
+        pdx.utils.download_and_decompress(url, path=root_path)
+
+    h_pre_models_kmeans = osp.join(h_pre_models, "kmeans_model.pkl")
+    kmeans_model = load_kmeans_model(h_pre_models_kmeans)
+
+    image_list = []
+    for item in dataset.file_list:
+        image_list.append(item[0])
+
+    x_data = []
+    y_labels = []
+    for each_data_ in tqdm.tqdm(image_list):
+        x_data_i = np.zeros((len(kmeans_model.cluster_centers_)))
+        image_show = read_image(each_data_)
+        result = predict_fn(image_show)
+        result = result[0]  # only one image here.
+
+        c = compute_features_for_kmeans(image_show).transpose((1, 2, 0))
+
+        segments = np.zeros((image_show.shape[1], image_show.shape[2]),
+                            np.int32)
+        num_blocks = 10
+        height_per_i = segments.shape[0] // num_blocks + 1
+        width_per_i = segments.shape[1] // num_blocks + 1
+
+        for i in range(segments.shape[0]):
+            for j in range(segments.shape[1]):
+                segments[i,
+                         j] = i // height_per_i * num_blocks + j // width_per_i

+        # segments = quickshift(image_show[0], sigma=1)
+        X = get_feature_for_kmeans(c, segments)
+
+        try:
+            cluster_labels = kmeans_model.predict(X)
+        except AttributeError:
+            from sklearn.metrics import pairwise_distances_argmin_min
+            cluster_labels, _ = pairwise_distances_argmin_min(
+                X, kmeans_model.cluster_centers_)

+        for c in cluster_labels:
+            x_data_i[c] = 1

+        # x_data_i /= len(cluster_labels)

+        pred_y_i = np.argmax(result)
+        y_labels.append(pred_y_i)
+        x_data.append(x_data_i)

+    clf = LogisticRegression(multi_class='multinomial', max_iter=1000)
+    clf.fit(x_data, y_labels)

+    num_classes = len(np.unique(y_labels))
+    normlime_weights_all_labels = {}
+    for class_index in range(num_classes):
+        w = clf.coef_[class_index]

+        # softmax
+        w = w - np.max(w)
+        exp_w = np.exp(w * 10)
+        w = exp_w / np.sum(exp_w)

+        normlime_weights_all_labels[class_index] = {
+            i: wi
+            for i, wi in enumerate(w)
+        }

+    logging.info("Saving the computed normlime_weights in {}".format(
+        save_path))
+    np.save(save_path, normlime_weights_all_labels)
+    return save_path
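The per-class weight construction at the end of precompute_global_classifier is a temperature-scaled softmax over the LogisticRegression coefficients (temperature 10 above). A minimal standalone sketch of just that normalization step, using made-up coefficient values:

import numpy as np


def softmax_weights(coef, temperature=10.0):
    """Turn raw per-cluster coefficients into non-negative weights that sum to 1."""
    w = coef - np.max(coef)          # shift by the max for numerical stability
    exp_w = np.exp(w * temperature)  # the temperature sharpens the distribution
    return exp_w / np.sum(exp_w)


# Hypothetical coefficients of one class over five k-means clusters.
weights = softmax_weights(np.array([0.3, -0.1, 0.8, 0.0, 0.2]))
print({i: round(float(wi), 4) for i, wi in enumerate(weights)})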
@@ -13,17 +13,26 @@
 # limitations under the License.

 import numpy as np
+import cv2
+import copy


 def interpretation_predict(model, images):
-    model.arrange_transforms(
-        transforms=model.test_transforms, mode='test')
+    images = images.astype('float32')
+    model.arrange_transforms(transforms=model.test_transforms, mode='test')
+    tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
+    model.test_transforms.transforms = model.test_transforms.transforms[-2:]

     new_imgs = []
     for i in range(images.shape[0]):
-        img = images[i]
-        new_imgs.append(model.test_transforms(img)[0])
+        images[i] = cv2.cvtColor(images[i], cv2.COLOR_RGB2BGR)
+        new_imgs.append(model.test_transforms(images[i])[0])

     new_imgs = np.array(new_imgs)
-    result = model.exe.run(
-        model.test_prog,
-        feed={'image': new_imgs},
-        fetch_list=list(model.interpretation_feats.values()))
-    return result
\ No newline at end of file
+    out = model.exe.run(model.test_prog,
+                        feed={'image': new_imgs},
+                        fetch_list=list(model.interpretation_feats.values()))

+    model.test_transforms.transforms = tmp_transforms

+    return out
@@ -20,79 +20,79 @@ import numpy as np
 import paddlex as pdx
 from .interpretation_predict import interpretation_predict
 from .core.interpretation import Interpretation
-from .core.normlime_base import precompute_normlime_weights
+from .core.normlime_base import precompute_global_classifier
 from .core._session_preparation import gen_user_home


-def lime(img_file,
-         model,
-         num_samples=3000,
-         batch_size=50,
-         save_dir='./'):
+def lime(img_file, model, num_samples=3000, batch_size=50, save_dir='./'):
     """Visualize the interpretability of a model's prediction with the LIME algorithm.

     LIME stands for local, model-agnostic interpretability and can explain any model.
     Its idea is to sample randomly in the space around the input sample, feed each
     sample through the original model to obtain a new output, and thus collect a set
     of inputs and their corresponding outputs. LIME then fits this mapping with a
     simple, interpretable model (for example a linear regression) and uses the learned
     weight of each input dimension to explain the model.

     Note: LIME visualization currently only supports classification models.

     Args:
         img_file (str): path of the image to predict.
         model (paddlex.cv.models): a PaddleX model.
         num_samples (int): number of samples LIME uses to learn the linear model, 3000 by default.
         batch_size (int): batch size for prediction, 50 by default.
         save_dir (str): directory for the visualization results (saved as png files) and intermediate files.
     """
     assert model.model_type == 'classifier', \
         'Now the interpretation visualize only be supported in classifier!'
     if model.status != 'Normal':
-        raise Exception('The interpretation only can deal with the Normal model')
+        raise Exception(
+            'The interpretation only can deal with the Normal model')
     if not osp.exists(save_dir):
         os.makedirs(save_dir)
-    model.arrange_transforms(
-        transforms=model.test_transforms, mode='test')
+    model.arrange_transforms(transforms=model.test_transforms, mode='test')
     tmp_transforms = copy.deepcopy(model.test_transforms)
     tmp_transforms.transforms = tmp_transforms.transforms[:-2]
     img = tmp_transforms(img_file)[0]
     img = np.around(img).astype('uint8')
     img = np.expand_dims(img, axis=0)
     interpreter = None
-    interpreter = get_lime_interpreter(img, model, num_samples=num_samples, batch_size=batch_size)
+    interpreter = get_lime_interpreter(
+        img, model, num_samples=num_samples, batch_size=batch_size)
     img_name = osp.splitext(osp.split(img_file)[-1])[0]
     interpreter.interpret(img, save_dir=save_dir)
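The docstring above summarizes the LIME procedure: sample perturbations around the input, query the original model for each perturbation, and fit a simple weighted linear model to the resulting input-output pairs. A small self-contained sketch of that local fit on synthetic superpixel masks (illustrative only, not the PaddleX code path):

import numpy as np
from sklearn.linear_model import Ridge

rng = np.random.RandomState(0)
num_samples, num_segments = 200, 40

# 1. Randomly switch superpixels on/off; row 0 is the unperturbed image.
perturbations = rng.randint(0, 2, size=(num_samples, num_segments))
perturbations[0, :] = 1

# 2. Stand-in for the black-box classifier score of each perturbed image;
#    here only the first five segments actually matter.
scores = perturbations[:, :5].sum(axis=1) + 0.1 * rng.randn(num_samples)

# 3. Weight samples by proximity to the original image with an exponential
#    kernel of the same form as the one in lime_base.
distances = np.linalg.norm(perturbations - perturbations[0], axis=1)
kernel_width = 0.25 * np.sqrt(num_segments)
sample_weights = np.sqrt(np.exp(-(distances**2) / kernel_width**2))

# 4. Fit a weighted linear surrogate; its coefficients are the explanation.
surrogate = Ridge(alpha=1.0)
surrogate.fit(perturbations, scores, sample_weight=sample_weights)
print("most influential segments:", np.argsort(-np.abs(surrogate.coef_))[:5])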
 def normlime(img_file,
              model,
              dataset=None,
              num_samples=3000,
              batch_size=50,
-             save_dir='./'):
+             save_dir='./',
+             normlime_weights_file=None):
     """Visualize the interpretability of a model's prediction with the NormLIME algorithm.

     NormLIME uses a number of samples to derive a global explanation. It precomputes
     the LIME results of a number of test samples and then normalizes the weights of
     the same feature across them, which yields a global relation between inputs and
     outputs.

     Note 1: dataset is a dataset reader; it should not be too large (otherwise the
     computation takes a long time), but it should contain data of every class.
     Note 2: NormLIME visualization currently only supports classification models.

     Args:
         img_file (str): path of the image to predict.
         model (paddlex.cv.models): a PaddleX model.
         dataset (paddlex.datasets): dataset reader, None by default.
         num_samples (int): number of samples LIME uses to learn the linear model, 3000 by default.
         batch_size (int): batch size for prediction, 50 by default.
         save_dir (str): directory for the visualization results (saved as png files) and intermediate files.
+        normlime_weights_file (str): filename of the NormLIME initialization file; if it does not exist,
+            the weights are computed once and saved to this path, otherwise they are loaded from it.
     """
     assert model.model_type == 'classifier', \
         'Now the interpretation visualize only be supported in classifier!'
     if model.status != 'Normal':
-        raise Exception('The interpretation only can deal with the Normal model')
+        raise Exception(
+            'The interpretation only can deal with the Normal model')
     if not osp.exists(save_dir):
         os.makedirs(save_dir)
-    model.arrange_transforms(
-        transforms=model.test_transforms, mode='test')
+    model.arrange_transforms(transforms=model.test_transforms, mode='test')
     tmp_transforms = copy.deepcopy(model.test_transforms)
     tmp_transforms.transforms = tmp_transforms.transforms[:-2]
     img = tmp_transforms(img_file)[0]
@@ -100,52 +100,48 @@ def normlime(img_file,
     img = np.expand_dims(img, axis=0)
     interpreter = None
     if dataset is None:
-        raise Exception('The dataset is None. Cannot implement this kind of interpretation')
-    interpreter = get_normlime_interpreter(img, model, dataset,
-                                           num_samples=num_samples, batch_size=batch_size,
-                                           save_dir=save_dir)
+        raise Exception(
+            'The dataset is None. Cannot implement this kind of interpretation')
+    interpreter = get_normlime_interpreter(
+        img,
+        model,
+        dataset,
+        num_samples=num_samples,
+        batch_size=batch_size,
+        save_dir=save_dir,
+        normlime_weights_file=normlime_weights_file)
     img_name = osp.splitext(osp.split(img_file)[-1])[0]
     interpreter.interpret(img, save_dir=save_dir)
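As the docstring notes, the global weights come from many per-image LIME runs whose weights are collected per (class, cluster) pair and then averaged, as compute_normlime_weights in normlime_base.py does with sum/len. A toy sketch of that aggregation with made-up numbers:

from collections import defaultdict

# Per-image LIME results: class_id -> list of (cluster_id, lime_weight).
# The values below are invented; the real pipeline fills them from saved .npy files.
per_image_weights = [
    {0: [(3, 0.40), (7, 0.10)], 1: [(2, 0.25)]},
    {0: [(3, 0.20), (5, 0.05)], 1: [(2, 0.35), (7, -0.10)]},
]

accumulated = defaultdict(lambda: defaultdict(list))
for lime_weights in per_image_weights:
    for class_id, pairs in lime_weights.items():
        for cluster_id, weight in pairs:
            accumulated[class_id][cluster_id].append(weight)

# The global NormLIME weight of a cluster is its average weight over all images.
normlime_weights = {
    class_id: {c: sum(ws) / len(ws) for c, ws in clusters.items()}
    for class_id, clusters in accumulated.items()
}
print(normlime_weights)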
 def get_lime_interpreter(img, model, num_samples=3000, batch_size=50):
     def predict_func(image):
-        image = image.astype('float32')
-        for i in range(image.shape[0]):
-            image[i] = cv2.cvtColor(image[i], cv2.COLOR_RGB2BGR)
-        tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
-        model.test_transforms.transforms = model.test_transforms.transforms[-2:]
         out = interpretation_predict(model, image)
-        model.test_transforms.transforms = tmp_transforms
         return out[0]

     labels_name = None
     if hasattr(model, 'labels'):
         labels_name = model.labels
-    interpreter = Interpretation('lime',
-                                 predict_func,
-                                 labels_name,
-                                 num_samples=num_samples,
-                                 batch_size=batch_size)
+    interpreter = Interpretation(
+        'lime',
+        predict_func,
+        labels_name,
+        num_samples=num_samples,
+        batch_size=batch_size)
     return interpreter


-def get_normlime_interpreter(img, model, dataset, num_samples=3000, batch_size=50, save_dir='./'):
-    def precompute_predict_func(image):
-        image = image.astype('float32')
-        tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
-        model.test_transforms.transforms = model.test_transforms.transforms[-2:]
-        out = interpretation_predict(model, image)
-        model.test_transforms.transforms = tmp_transforms
-        return out[0]
+def get_normlime_interpreter(img,
+                             model,
+                             dataset,
+                             num_samples=3000,
+                             batch_size=50,
+                             save_dir='./',
+                             normlime_weights_file=None):
     def predict_func(image):
-        image = image.astype('float32')
-        for i in range(image.shape[0]):
-            image[i] = cv2.cvtColor(image[i], cv2.COLOR_RGB2BGR)
-        tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
-        model.test_transforms.transforms = model.test_transforms.transforms[-2:]
         out = interpretation_predict(model, image)
-        model.test_transforms.transforms = tmp_transforms
         return out[0]

     labels_name = None
     if dataset is not None:
         labels_name = dataset.labels
@@ -157,28 +153,29 @@ def get_normlime_interpreter(img, model, dataset, num_samples=3000, batch_size=5
             os.makedirs(root_path)
         url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz"
         pdx.utils.download_and_decompress(url, path=root_path)
-    npy_dir = precompute_for_normlime(precompute_predict_func,
-                                      dataset,
-                                      num_samples=num_samples,
-                                      batch_size=batch_size,
-                                      save_dir=save_dir)
-    interpreter = Interpretation('normlime',
-                                 predict_func,
-                                 labels_name,
-                                 num_samples=num_samples,
-                                 batch_size=batch_size,
-                                 normlime_weights=npy_dir)
-    return interpreter
-
-
-def precompute_for_normlime(predict_func, dataset, num_samples=3000, batch_size=50, save_dir='./'):
-    image_list = []
-    for item in dataset.file_list:
-        image_list.append(item[0])
-    return precompute_normlime_weights(
-        image_list,
-        predict_func,
-        num_samples=num_samples,
-        batch_size=batch_size,
-        save_dir=save_dir)
+    if osp.exists(osp.join(save_dir, normlime_weights_file)):
+        normlime_weights_file = osp.join(save_dir, normlime_weights_file)
+        try:
+            np.load(normlime_weights_file, allow_pickle=True).item()
+        except:
+            normlime_weights_file = precompute_global_classifier(
+                dataset,
+                predict_func,
+                save_path=normlime_weights_file,
+                batch_size=batch_size)
+    else:
+        normlime_weights_file = precompute_global_classifier(
+            dataset,
+            predict_func,
+            save_path=normlime_weights_file,
+            batch_size=batch_size)
+
+    interpreter = Interpretation(
+        'normlime',
+        predict_func,
+        labels_name,
+        num_samples=num_samples,
+        batch_size=batch_size,
+        normlime_weights=normlime_weights_file)
+    return interpreter
@@ -14,18 +14,33 @@ model_file = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg_mobilene
 pdx.utils.download_and_decompress(model_file, path='./')

 # Load the model
-model = pdx.load_model('mini_imagenet_veg_mobilenetv2')
+model_file = 'mini_imagenet_veg_mobilenetv2'
+model = pdx.load_model(model_file)

 # Define the dataset used for testing
+dataset = 'mini_imagenet_veg'
 test_dataset = pdx.datasets.ImageNet(
-    data_dir='mini_imagenet_veg',
-    file_list=osp.join('mini_imagenet_veg', 'test_list.txt'),
-    label_list=osp.join('mini_imagenet_veg', 'labels.txt'),
+    data_dir=dataset,
+    file_list=osp.join(dataset, 'test_list.txt'),
+    label_list=osp.join(dataset, 'labels.txt'),
     transforms=model.test_transforms)

-# Interpretability visualization
-pdx.interpret.normlime(
-    'mini_imagenet_veg/mushroom/n07734744_1106.JPEG',
-    model,
-    test_dataset,
-    save_dir='./')
+import numpy as np
+np.random.seed(5)
+perm = np.random.permutation(len(test_dataset.file_list))
+
+for i in range(len(test_dataset.file_list)):
+
+    # Interpretability visualization
+    pdx.interpret.normlime(
+        test_dataset.file_list[perm[i]][0],
+        model,
+        test_dataset,
+        save_dir='./',
+        normlime_weights_file='{}_{}.npy'.format(
+            dataset.split('/')[-1], model.model_name))
+
+    if i == 1:
+        # first iter will have an initialization process, followed by the interpretation.
+        # second iter will directly load the initialization process, followed by the interpretation.
+        break
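For comparison with the loop above, a single-image LIME visualization uses the lime() entry point defined earlier; the image path below is the one used by the previous version of this script:

import paddlex as pdx

model = pdx.load_model('mini_imagenet_veg_mobilenetv2')
pdx.interpret.lime(
    'mini_imagenet_veg/mushroom/n07734744_1106.JPEG',
    model,
    num_samples=3000,
    batch_size=50,
    save_dir='./')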