提交 011cff21 编写于 作者: S sunyanfang01

add lime

上级 2484756a
...@@ -27,7 +27,6 @@ from .base import BaseAPI ...@@ -27,7 +27,6 @@ from .base import BaseAPI
class BaseClassifier(BaseAPI): class BaseClassifier(BaseAPI):
"""构建分类器,并实现其训练、评估、预测和模型导出。 """构建分类器,并实现其训练、评估、预测和模型导出。
Args: Args:
model_name (str): 分类器的模型名字,取值范围为['ResNet18', model_name (str): 分类器的模型名字,取值范围为['ResNet18',
'ResNet34', 'ResNet50', 'ResNet101', 'ResNet34', 'ResNet50', 'ResNet101',
...@@ -61,10 +60,10 @@ class BaseClassifier(BaseAPI): ...@@ -61,10 +60,10 @@ class BaseClassifier(BaseAPI):
if mode != 'test': if mode != 'test':
label = fluid.data(dtype='int64', shape=[None, 1], name='label') label = fluid.data(dtype='int64', shape=[None, 1], name='label')
model = getattr(paddlex.cv.nets, str.lower(self.model_name)) model = getattr(paddlex.cv.nets, str.lower(self.model_name))
net_out = model(image, num_classes=self.num_classes) net_out, feat = model(image, num_classes=self.num_classes)
softmax_out = fluid.layers.softmax(net_out, use_cudnn=False) softmax_out = fluid.layers.softmax(net_out, use_cudnn=False)
inputs = OrderedDict([('image', image)]) inputs = OrderedDict([('image', image)])
outputs = OrderedDict([('predict', softmax_out)]) outputs = OrderedDict([('predict', softmax_out), ('net_out', feat[-1])])
if mode != 'test': if mode != 'test':
cost = fluid.layers.cross_entropy(input=softmax_out, label=label) cost = fluid.layers.cross_entropy(input=softmax_out, label=label)
avg_cost = fluid.layers.mean(cost) avg_cost = fluid.layers.mean(cost)
...@@ -115,7 +114,6 @@ class BaseClassifier(BaseAPI): ...@@ -115,7 +114,6 @@ class BaseClassifier(BaseAPI):
early_stop_patience=5, early_stop_patience=5,
resume_checkpoint=None): resume_checkpoint=None):
"""训练。 """训练。
Args: Args:
num_epochs (int): 训练迭代轮数。 num_epochs (int): 训练迭代轮数。
train_dataset (paddlex.datasets): 训练数据读取器。 train_dataset (paddlex.datasets): 训练数据读取器。
...@@ -139,7 +137,6 @@ class BaseClassifier(BaseAPI): ...@@ -139,7 +137,6 @@ class BaseClassifier(BaseAPI):
early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内 early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
连续下降或持平,则终止训练。默认值为5。 连续下降或持平,则终止训练。默认值为5。
resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。 resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
Raises: Raises:
ValueError: 模型从inference model进行加载。 ValueError: 模型从inference model进行加载。
""" """
...@@ -183,13 +180,11 @@ class BaseClassifier(BaseAPI): ...@@ -183,13 +180,11 @@ class BaseClassifier(BaseAPI):
epoch_id=None, epoch_id=None,
return_details=False): return_details=False):
"""评估。 """评估。
Args: Args:
eval_dataset (paddlex.datasets): 验证数据读取器。 eval_dataset (paddlex.datasets): 验证数据读取器。
batch_size (int): 验证数据批大小。默认为1。 batch_size (int): 验证数据批大小。默认为1。
epoch_id (int): 当前评估模型所在的训练轮数。 epoch_id (int): 当前评估模型所在的训练轮数。
return_details (bool): 是否返回详细信息。 return_details (bool): 是否返回详细信息。
Returns: Returns:
dict: 当return_details为False时,返回dict, 包含关键字:'acc1'、'acc5', dict: 当return_details为False时,返回dict, 包含关键字:'acc1'、'acc5',
分别表示最大值的accuracy、前5个最大值的accuracy。 分别表示最大值的accuracy、前5个最大值的accuracy。
...@@ -248,12 +243,10 @@ class BaseClassifier(BaseAPI): ...@@ -248,12 +243,10 @@ class BaseClassifier(BaseAPI):
def predict(self, img_file, transforms=None, topk=1): def predict(self, img_file, transforms=None, topk=1):
"""预测。 """预测。
Args: Args:
img_file (str): 预测图像路径。 img_file (str): 预测图像路径。
transforms (paddlex.cls.transforms): 数据预处理操作。 transforms (paddlex.cls.transforms): 数据预处理操作。
topk (int): 预测时前k个最大值。 topk (int): 预测时前k个最大值。
Returns: Returns:
list: 其中元素均为字典。字典的关键字为'category_id'、'category'、'score', list: 其中元素均为字典。字典的关键字为'category_id'、'category'、'score',
分别对应预测类别id、预测类别标签、预测得分。 分别对应预测类别id、预测类别标签、预测得分。
...@@ -280,6 +273,19 @@ class BaseClassifier(BaseAPI): ...@@ -280,6 +273,19 @@ class BaseClassifier(BaseAPI):
} for l in pred_label] } for l in pred_label]
return res return res
def explanation_predict(self, images):
self.arrange_transforms(
transforms=self.test_transforms, mode='test')
new_imgs = []
for i in range(images.shape[0]):
img = images[i]
new_imgs.append(self.test_transforms(img)[0])
new_imgs = np.array(new_imgs)
result = self.exe.run(
self.test_prog,
feed={'image': new_imgs},
fetch_list=list(self.test_outputs.values()))
return result[1:]
class ResNet18(BaseClassifier): class ResNet18(BaseClassifier):
def __init__(self, num_classes=1000): def __init__(self, num_classes=1000):
......
import os
def imagenet_val_files_and_labels(dataset_directory):
classes = open(os.path.join(dataset_directory, 'imagenet_lsvrc_2015_synsets.txt')).readlines()
class_to_indx = {classes[i].split('\n')[0]: i for i in range(len(classes))}
images_path = os.path.join(dataset_directory, 'val')
filenames = []
labels = []
lines = open(os.path.join(dataset_directory, 'imagenet_2012_validation_synset_labels.txt'), 'r').readlines()
for i, line in enumerate(lines):
class_name = line.split('\n')[0]
a = 'ILSVRC2012_val_%08d.JPEG' % (i + 1)
filenames.append(f'{images_path}/{a}')
labels.append(class_to_indx[class_name])
# print(filenames[-1], labels[-1])
return filenames, labels
def _find_classes(dir):
# Faster and available in Python 3.5 and above
classes = [d.name for d in os.scandir(dir) if d.is_dir()]
classes.sort()
class_to_idx = {classes[i]: i for i in range(len(classes))}
return classes, class_to_idx
\ No newline at end of file
import os
import sys; sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import cv2
import numpy as np
import six
import glob
from as_data_reader.data_path_utils import _find_classes
from PIL import Image
def resize_short(img, target_size, interpolation=None):
"""resize image
Args:
img: image data
target_size: resize short target size
interpolation: interpolation mode
Returns:
resized image data
"""
percent = float(target_size) / min(img.shape[0], img.shape[1])
resized_width = int(round(img.shape[1] * percent))
resized_height = int(round(img.shape[0] * percent))
if interpolation:
resized = cv2.resize(
img, (resized_width, resized_height), interpolation=interpolation)
else:
resized = cv2.resize(img, (resized_width, resized_height))
return resized
def crop_image(img, target_size, center=True):
"""crop image
Args:
img: images data
target_size: crop target size
center: crop mode
Returns:
img: cropped image data
"""
height, width = img.shape[:2]
size = target_size
if center:
w_start = (width - size) // 2
h_start = (height - size) // 2
else:
w_start = np.random.randint(0, width - size + 1)
h_start = np.random.randint(0, height - size + 1)
w_end = w_start + size
h_end = h_start + size
img = img[h_start:h_end, w_start:w_end, :]
return img
def preprocess_image(img, random_mirror=False):
"""
centered, scaled by 1/255.
:param img: np.array: shape: [ns, h, w, 3], color order: rgb.
:return: np.array: shape: [ns, h, w, 3]
"""
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
# transpose to [ns, 3, h, w]
img = img.astype('float32').transpose((0, 3, 1, 2)) / 255
img_mean = np.array(mean).reshape((3, 1, 1))
img_std = np.array(std).reshape((3, 1, 1))
img -= img_mean
img /= img_std
if random_mirror:
mirror = int(np.random.uniform(0, 2))
if mirror == 1:
img = img[:, :, ::-1, :]
return img
def read_image(img_path, target_size=256, crop_size=224):
"""
resize_short to 256, then center crop to 224.
:param img_path: one image path
:return: np.array: shape: [1, h, w, 3], color order: rgb.
"""
if isinstance(img_path, str):
with open(img_path, 'rb') as f:
img = Image.open(f)
img = img.convert('RGB')
img = np.array(img)
# img = cv2.imread(img_path)
img = resize_short(img, target_size, interpolation=None)
img = crop_image(img, target_size=crop_size, center=True)
# img = img[:, :, ::-1]
img = np.expand_dims(img, axis=0)
return img
elif isinstance(img_path, np.ndarray):
assert len(img_path.shape) == 4
return img_path
else:
ValueError(f"Not recognized data type {type(img_path)}.")
class ReaderConfig(object):
"""
A generic data loader where the images are arranged in this way:
root/train/dog/xxy.jpg
root/train/dog/xxz.jpg
...
root/train/cat/nsdf3.jpg
root/train/cat/asd932_.jpg
...
root/test/dog/xxx.jpg
...
root/test/cat/123.jpg
...
"""
def __init__(self, dataset_dir, is_test):
image_paths, labels, self.num_classes = self.get_dataset_info(dataset_dir, is_test)
random_per = np.random.permutation(range(len(image_paths)))
self.image_paths = image_paths[random_per]
self.labels = labels[random_per]
self.is_test = is_test
def get_reader(self):
def reader():
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
target_size = 256
crop_size = 224
for i, img_path in enumerate(self.image_paths):
if not img_path.lower().endswith(IMG_EXTENSIONS):
continue
img = cv2.imread(img_path)
if img is None:
print(img_path)
continue
img = resize_short(img, target_size, interpolation=None)
img = crop_image(img, crop_size, center=self.is_test)
img = img[:, :, ::-1]
img = np.expand_dims(img, axis=0)
img = preprocess_image(img, not self.is_test)
yield img, self.labels[i]
return reader
def get_dataset_info(self, dataset_dir, is_test=False):
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
# read
if is_test:
datasubset_dir = os.path.join(dataset_dir, 'test')
else:
datasubset_dir = os.path.join(dataset_dir, 'train')
class_names, class_to_idx = _find_classes(datasubset_dir)
# num_classes = len(class_names)
image_paths = []
labels = []
for class_name in class_names:
classes_dir = os.path.join(datasubset_dir, class_name)
for img_path in glob.glob(os.path.join(classes_dir, '*')):
if not img_path.lower().endswith(IMG_EXTENSIONS):
continue
image_paths.append(img_path)
labels.append(class_to_idx[class_name])
image_paths = np.array(image_paths)
labels = np.array(labels)
return image_paths, labels, len(class_names)
def create_reader(list_image_path, list_label=None, is_test=False):
def reader():
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
target_size = 256
crop_size = 224
for i, img_path in enumerate(list_image_path):
if not img_path.lower().endswith(IMG_EXTENSIONS):
continue
img = cv2.imread(img_path)
if img is None:
print(img_path)
continue
img = resize_short(img, target_size, interpolation=None)
img = crop_image(img, crop_size, center=is_test)
img = img[:, :, ::-1]
img_show = np.expand_dims(img, axis=0)
img = preprocess_image(img_show, not is_test)
label = 0 if list_label is None else list_label[i]
yield img_show, img, label
return reader
\ No newline at end of file
import os
import paddle.fluid as fluid
import numpy as np
def paddle_get_fc_weights(var_name="fc_0.w_0"):
fc_weights = fluid.global_scope().find_var(var_name).get_tensor()
return np.array(fc_weights)
def paddle_resize(extracted_features, outsize):
resized_features = fluid.layers.resize_bilinear(extracted_features, outsize)
return resized_features
\ No newline at end of file
from .explanation_algorithms import CAM, LIME, NormLIME
class Explanation(object):
"""
Base class for all explanation algorithms.
"""
def __init__(self, explanation_algorithm_name, predict_fn, **kwargs):
supported_algorithms = {
'cam': CAM,
'lime': LIME,
'normlime': NormLIME
}
self.algorithm_name = explanation_algorithm_name.lower()
assert self.algorithm_name in supported_algorithms.keys()
self.predict_fn = predict_fn
# initialization for the explanation algorithm.
self.explain_algorithm = supported_algorithms[self.algorithm_name](
self.predict_fn, **kwargs
)
def explain(self, data_, visualization=True, save_to_disk=True, save_dir='./tmp'):
"""
Args:
data_: data_ can be a path or numpy.ndarray.
visualization: whether to show using matplotlib.
save_to_disk: whether to save the figure in local disk.
save_dir: dir to save figure if save_to_disk is True.
Returns:
"""
return self.explain_algorithm.explain(data_, visualization, save_to_disk, save_dir)
import os
import numpy as np
import time
from . import lime_base
from ..as_data_reader.readers import read_image
from ._session_preparation import paddle_get_fc_weights
import cv2
class CAM(object):
def __init__(self, predict_fn):
"""
Args:
predict_fn: input: images_show [N, H, W, 3], RGB range(0, 255)
output: [
logits [N, num_classes],
feature map before global average pooling [N, num_channels, h_, w_]
]
"""
self.predict_fn = predict_fn
def preparation_cam(self, data_path):
image_show = read_image(data_path)
result = self.predict_fn(image_show)
logit = result[0][0]
if abs(np.sum(logit) - 1.0) > 1e-4:
# softmax
exp_result = np.exp(logit)
probability = exp_result / np.sum(exp_result)
else:
probability = logit
# only explain top 1
pred_label = np.argsort(probability)
pred_label = pred_label[-1:]
self.predicted_label = pred_label[0]
self.predicted_probability = probability[pred_label[0]]
self.image = image_show[0]
self.labels = pred_label
fc_weights = paddle_get_fc_weights()
feature_maps = result[1]
print('predicted result: ', pred_label[0], probability[pred_label[0]])
return feature_maps, fc_weights
def explain(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
feature_maps, fc_weights = self.preparation_cam(data_)
cam = get_cam(self.image, feature_maps, fc_weights, self.predicted_label)
if visualization or save_to_disk:
import matplotlib.pyplot as plt
from skimage.segmentation import mark_boundaries
l = self.labels[0]
psize = 5
nrows = 1
ncols = 2
plt.close()
f, axes = plt.subplots(nrows, ncols, figsize=(psize * ncols, psize * nrows))
for ax in axes.ravel():
ax.axis("off")
axes = axes.ravel()
axes[0].imshow(self.image)
axes[0].set_title(f"label {l}, proba: {self.predicted_probability: .3f}")
axes[1].imshow(cam)
axes[1].set_title("CAM")
if save_to_disk and save_outdir is not None:
os.makedirs(save_outdir, exist_ok=True)
save_fig(data_, save_outdir, 'cam')
if visualization:
plt.show()
return
class LIME(object):
def __init__(self, predict_fn, num_samples=3000, batch_size=50):
"""
LIME wrapper. See lime_base.py for the detailed LIME implementation.
Args:
predict_fn: from image [N, H, W, 3] to logits [N, num_classes], this is necessary for computing LIME.
num_samples: the number of samples that LIME takes for fitting.
batch_size: batch size for model inference each time.
"""
self.num_samples = num_samples
self.batch_size = batch_size
self.predict_fn = predict_fn
self.labels = None
self.image = None
self.lime_explainer = None
def preparation_lime(self, data_path):
image_show = read_image(data_path)
result = self.predict_fn(image_show)
result = result[0] # only one image here.
if abs(np.sum(result) - 1.0) > 1e-4:
# softmax
exp_result = np.exp(result)
probability = exp_result / np.sum(exp_result)
else:
probability = result
# only explain top 1
pred_label = np.argsort(probability)
pred_label = pred_label[-1:]
self.predicted_label = pred_label[0]
self.predicted_probability = probability[pred_label[0]]
self.image = image_show[0]
self.labels = pred_label
print(f'predicted result: {pred_label[0]} with probability {probability[pred_label[0]]: .3f}')
end = time.time()
algo = lime_base.LimeImageExplainer()
explainer = algo.explain_instance(self.image, self.predict_fn, self.labels, 0,
num_samples=self.num_samples, batch_size=self.batch_size)
self.lime_explainer = explainer
print('lime time: ', time.time() - end, 's.')
def explain(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
if self.lime_explainer is None:
self.preparation_lime(data_)
if visualization or save_to_disk:
import matplotlib.pyplot as plt
from skimage.segmentation import mark_boundaries
l = self.labels[0]
psize = 5
nrows = 2
weights_choices = [0.6, 0.75, 0.85]
ncols = len(weights_choices)
plt.close()
f, axes = plt.subplots(nrows, ncols, figsize=(psize * ncols, psize * nrows))
for ax in axes.ravel():
ax.axis("off")
axes = axes.ravel()
axes[0].imshow(self.image)
axes[0].set_title(f"label {l}, proba: {self.predicted_probability: .3f}")
axes[1].imshow(mark_boundaries(self.image, self.lime_explainer.segments))
axes[1].set_title("superpixel segmentation")
# LIME visualization
for i, w in enumerate(weights_choices):
num_to_show = auto_choose_num_features_to_show(self.lime_explainer, l, w)
temp, mask = self.lime_explainer.get_image_and_mask(
l, positive_only=False, hide_rest=False, num_features=num_to_show
)
axes[ncols + i].imshow(mark_boundaries(temp, mask))
axes[ncols + i].set_title(f"label {l}, first {num_to_show} superpixels")
if save_to_disk and save_outdir is not None:
os.makedirs(save_outdir, exist_ok=True)
save_fig(data_, save_outdir, 'lime', self.num_samples)
if visualization:
plt.show()
return
class NormLIME(object):
def __init__(self, predict_fn, num_samples=3000, batch_size=50,
kmeans_model_for_normlime=None, normlime_weights=None):
assert kmeans_model_for_normlime is not None, "NormLIME needs the KMeans model."
if normlime_weights is None:
raise NotImplementedError("Computing NormLIME weights is not implemented yet.")
self.num_samples = num_samples
self.batch_size = batch_size
self.kmeans_model = load_kmeans_model(kmeans_model_for_normlime)
self.normlime_weights = np.load(normlime_weights, allow_pickle=True).item()
self.predict_fn = predict_fn
self.labels = None
self.image = None
def predict_cluster_labels(self, feature_map, segments):
return self.kmeans_model.predict(get_feature_for_kmeans(feature_map, segments))
def predict_using_normlime_weights(self, pred_labels, predicted_cluster_labels):
# global weights
g_weights = {y: [] for y in pred_labels}
for y in pred_labels:
cluster_weights_y = self.normlime_weights[y]
g_weights[y] = [
# some are not in the dict, 3000 samples may be not enough.
(i, cluster_weights_y.get(k, 0.0)) for i, k in enumerate(predicted_cluster_labels)
]
g_weights[y] = sorted(g_weights[y],
key=lambda x: np.abs(x[1]), reverse=True)
return g_weights
def preparation_normlime(self, data_path):
self._lime = LIME(
lambda images: self.predict_fn(images)[0],
self.num_samples,
self.batch_size
)
self._lime.preparation_lime(data_path)
image_show = read_image(data_path)
result = self.predict_fn(image_show)
logit = result[0][0] # only one image here.
if abs(np.sum(logit) - 1.0) > 1e-4:
# softmax
exp_result = np.exp(logit)
probability = exp_result / np.sum(exp_result)
else:
probability = logit
# only explain top 1
pred_label = np.argsort(probability)
pred_label = pred_label[-1:]
self.predicted_label = pred_label[0]
self.predicted_probability = probability[pred_label[0]]
self.image = image_show[0]
self.labels = pred_label
print('predicted result: ', pred_label[0], probability[pred_label[0]])
local_feature_map = result[1][0]
cluster_labels = self.predict_cluster_labels(
local_feature_map.transpose((1, 2, 0)), self._lime.lime_explainer.segments
)
g_weights = self.predict_using_normlime_weights(self.labels, cluster_labels)
return g_weights
def explain(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
g_weights = self.preparation_normlime(data_)
lime_weights = self._lime.lime_explainer.local_exp
if visualization or save_to_disk:
import matplotlib.pyplot as plt
from skimage.segmentation import mark_boundaries
l = self.labels[0]
psize = 5
nrows = 4
weights_choices = [0.6, 0.85, 0.99]
ncols = len(weights_choices)
plt.close()
f, axes = plt.subplots(nrows, ncols, figsize=(psize * ncols, psize * nrows))
for ax in axes.ravel():
ax.axis("off")
axes = axes.ravel()
axes[0].imshow(self.image)
axes[0].set_title(f"label {l}, proba: {self.predicted_probability: .3f}")
axes[1].imshow(mark_boundaries(self.image, self._lime.lime_explainer.segments))
axes[1].set_title("superpixel segmentation")
# LIME visualization
for i, w in enumerate(weights_choices):
num_to_show = auto_choose_num_features_to_show(self._lime.lime_explainer, l, w)
temp, mask = self._lime.lime_explainer.get_image_and_mask(
l, positive_only=False, hide_rest=False, num_features=num_to_show
)
axes[ncols + i].imshow(mark_boundaries(temp, mask))
axes[ncols + i].set_title(f"label {l}, first {num_to_show} superpixels")
# NormLIME visualization
self._lime.lime_explainer.local_exp = g_weights
for i, w in enumerate(weights_choices):
num_to_show = auto_choose_num_features_to_show(self._lime.lime_explainer, l, w)
temp, mask = self._lime.lime_explainer.get_image_and_mask(
l, positive_only=False, hide_rest=False, num_features=num_to_show
)
axes[ncols * 2 + i].imshow(mark_boundaries(temp, mask))
axes[ncols * 2 + i].set_title(f"label {l}, first {num_to_show} superpixels")
# NormLIME*LIME visualization
combined_weights = combine_normlime_and_lime(lime_weights, g_weights)
self._lime.lime_explainer.local_exp = combined_weights
for i, w in enumerate(weights_choices):
num_to_show = auto_choose_num_features_to_show(self._lime.lime_explainer, l, w)
temp, mask = self._lime.lime_explainer.get_image_and_mask(
l, positive_only=False, hide_rest=False, num_features=num_to_show
)
axes[ncols * 3 + i].imshow(mark_boundaries(temp, mask))
axes[ncols * 3 + i].set_title(f"label {l}, first {num_to_show} superpixels")
self._lime.lime_explainer.local_exp = lime_weights
if save_to_disk and save_outdir is not None:
os.makedirs(save_outdir, exist_ok=True)
save_fig(data_, save_outdir, 'normlime', self.num_samples)
if visualization:
plt.show()
def load_kmeans_model(fname):
import pickle
with open(fname, 'rb') as f:
kmeans_model = pickle.load(f)
return kmeans_model
def auto_choose_num_features_to_show(lime_explainer, label, percentage_to_show):
segments = lime_explainer.segments
lime_weights = lime_explainer.local_exp[label]
num_pixels_threshold_in_a_sp = segments.shape[0] * segments.shape[1] // len(np.unique(segments)) // 8
# l1 norm with filtered weights.
used_weights = [(tuple_w[0], tuple_w[1]) for i, tuple_w in enumerate(lime_weights) if tuple_w[1] > 0]
norm = np.sum([tuple_w[1] for i, tuple_w in enumerate(used_weights)])
normalized_weights = [(tuple_w[0], tuple_w[1] / norm) for i, tuple_w in enumerate(lime_weights)]
a = 0.0
n = 0
for i, tuple_w in enumerate(normalized_weights):
if tuple_w[1] < 0:
continue
if len(np.where(segments == tuple_w[0])[0]) < num_pixels_threshold_in_a_sp:
continue
a += tuple_w[1]
if a > percentage_to_show:
n = i + 1
break
if n == 0:
return auto_choose_num_features_to_show(lime_explainer, label, percentage_to_show-0.1)
return n
def get_cam(image_show, feature_maps, fc_weights, label_index, cam_min=None, cam_max=None):
_, nc, h, w = feature_maps.shape
cam = feature_maps * fc_weights[:, label_index].reshape(1, nc, 1, 1)
cam = cam.sum((0, 1))
if cam_min is None:
cam_min = np.min(cam)
if cam_max is None:
cam_max = np.max(cam)
cam = cam - cam_min
cam = cam / cam_max
cam = np.uint8(255 * cam)
cam_img = cv2.resize(cam, image_show.shape[0:2], interpolation=cv2.INTER_LINEAR)
heatmap = cv2.applyColorMap(np.uint8(255 * cam_img), cv2.COLORMAP_JET)
heatmap = np.float32(heatmap)
cam = heatmap + np.float32(image_show)
cam = cam / np.max(cam)
return cam
def avg_using_superpixels(features, segments):
one_list = np.zeros((len(np.unique(segments)), features.shape[2]))
for x in np.unique(segments):
one_list[x] = np.mean(features[segments == x], axis=0)
return one_list
def centroid_using_superpixels(features, segments):
from skimage.measure import regionprops
regions = regionprops(segments + 1)
one_list = np.zeros((len(np.unique(segments)), features.shape[2]))
for i, r in enumerate(regions):
one_list[i] = features[int(r.centroid[0] + 0.5), int(r.centroid[1] + 0.5), :]
# print(one_list.shape)
return one_list
def get_feature_for_kmeans(feature_map, segments):
from sklearn.preprocessing import normalize
centroid_feature = centroid_using_superpixels(feature_map, segments)
avg_feature = avg_using_superpixels(feature_map, segments)
x = np.concatenate((centroid_feature, avg_feature), axis=-1)
x = normalize(x)
return x
def combine_normlime_and_lime(lime_weights, g_weights):
pred_labels = lime_weights.keys()
combined_weights = {y: [] for y in pred_labels}
for y in pred_labels:
normlized_lime_weights_y = lime_weights[y]
lime_weights_dict = {tuple_w[0]: tuple_w[1] for tuple_w in normlized_lime_weights_y}
normlized_g_weight_y = g_weights[y]
normlime_weights_dict = {tuple_w[0]: tuple_w[1] for tuple_w in normlized_g_weight_y}
combined_weights[y] = [
(seg_k, lime_weights_dict[seg_k] * normlime_weights_dict[seg_k])
for seg_k in lime_weights_dict.keys()
]
combined_weights[y] = sorted(combined_weights[y],
key=lambda x: np.abs(x[1]), reverse=True)
return combined_weights
def save_fig(data_, save_outdir, algorithm_name, num_samples=3000):
import matplotlib.pyplot as plt
if isinstance(data_, str):
if algorithm_name == 'cam':
f_out = f"{algorithm_name}_{data_.split('/')[-1]}.png"
else:
f_out = f"{algorithm_name}_{data_.split('/')[-1]}_s{num_samples}.png"
plt.savefig(
os.path.join(save_outdir, f_out)
)
else:
n = 0
if algorithm_name == 'cam':
f_out = f'cam-{n}.png'
else:
f_out = f'{algorithm_name}_s{num_samples}-{n}.png'
while os.path.exists(
os.path.join(save_outdir, f_out)
):
n += 1
if algorithm_name == 'cam':
f_out = f'cam-{n}.png'
else:
f_out = f'{algorithm_name}_s{num_samples}-{n}.png'
continue
plt.savefig(
os.path.join(
save_outdir, f_out
)
)
"""
Contains abstract functionality for learning locally linear sparse model.
"""
from __future__ import print_function
import numpy as np
import scipy as sp
import sklearn
import sklearn.preprocessing
from skimage.color import gray2rgb
from sklearn.linear_model import Ridge, lars_path
from sklearn.utils import check_random_state
import copy
from functools import partial
from skimage.segmentation import quickshift
from skimage.measure import regionprops
class LimeBase(object):
"""Class for learning a locally linear sparse model from perturbed data"""
def __init__(self,
kernel_fn,
verbose=False,
random_state=None):
"""Init function
Args:
kernel_fn: function that transforms an array of distances into an
array of proximity values (floats).
verbose: if true, print local prediction values from linear model.
random_state: an integer or numpy.RandomState that will be used to
generate random numbers. If None, the random state will be
initialized using the internal numpy seed.
"""
self.kernel_fn = kernel_fn
self.verbose = verbose
self.random_state = check_random_state(random_state)
@staticmethod
def generate_lars_path(weighted_data, weighted_labels):
"""Generates the lars path for weighted data.
Args:
weighted_data: data that has been weighted by kernel
weighted_label: labels, weighted by kernel
Returns:
(alphas, coefs), both are arrays corresponding to the
regularization parameter and coefficients, respectively
"""
x_vector = weighted_data
alphas, _, coefs = lars_path(x_vector,
weighted_labels,
method='lasso',
verbose=False)
return alphas, coefs
def forward_selection(self, data, labels, weights, num_features):
"""Iteratively adds features to the model"""
clf = Ridge(alpha=0, fit_intercept=True, random_state=self.random_state)
used_features = []
for _ in range(min(num_features, data.shape[1])):
max_ = -100000000
best = 0
for feature in range(data.shape[1]):
if feature in used_features:
continue
clf.fit(data[:, used_features + [feature]], labels,
sample_weight=weights)
score = clf.score(data[:, used_features + [feature]],
labels,
sample_weight=weights)
if score > max_:
best = feature
max_ = score
used_features.append(best)
return np.array(used_features)
def feature_selection(self, data, labels, weights, num_features, method):
"""Selects features for the model. see explain_instance_with_data to
understand the parameters."""
if method == 'none':
return np.array(range(data.shape[1]))
elif method == 'forward_selection':
return self.forward_selection(data, labels, weights, num_features)
elif method == 'highest_weights':
clf = Ridge(alpha=0.01, fit_intercept=True,
random_state=self.random_state)
clf.fit(data, labels, sample_weight=weights)
coef = clf.coef_
if sp.sparse.issparse(data):
coef = sp.sparse.csr_matrix(clf.coef_)
weighted_data = coef.multiply(data[0])
# Note: most efficient to slice the data before reversing
sdata = len(weighted_data.data)
argsort_data = np.abs(weighted_data.data).argsort()
# Edge case where data is more sparse than requested number of feature importances
# In that case, we just pad with zero-valued features
if sdata < num_features:
nnz_indexes = argsort_data[::-1]
indices = weighted_data.indices[nnz_indexes]
num_to_pad = num_features - sdata
indices = np.concatenate((indices, np.zeros(num_to_pad, dtype=indices.dtype)))
indices_set = set(indices)
pad_counter = 0
for i in range(data.shape[1]):
if i not in indices_set:
indices[pad_counter + sdata] = i
pad_counter += 1
if pad_counter >= num_to_pad:
break
else:
nnz_indexes = argsort_data[sdata - num_features:sdata][::-1]
indices = weighted_data.indices[nnz_indexes]
return indices
else:
weighted_data = coef * data[0]
feature_weights = sorted(
zip(range(data.shape[1]), weighted_data),
key=lambda x: np.abs(x[1]),
reverse=True)
return np.array([x[0] for x in feature_weights[:num_features]])
elif method == 'lasso_path':
weighted_data = ((data - np.average(data, axis=0, weights=weights))
* np.sqrt(weights[:, np.newaxis]))
weighted_labels = ((labels - np.average(labels, weights=weights))
* np.sqrt(weights))
nonzero = range(weighted_data.shape[1])
_, coefs = self.generate_lars_path(weighted_data,
weighted_labels)
for i in range(len(coefs.T) - 1, 0, -1):
nonzero = coefs.T[i].nonzero()[0]
if len(nonzero) <= num_features:
break
used_features = nonzero
return used_features
elif method == 'auto':
if num_features <= 6:
n_method = 'forward_selection'
else:
n_method = 'highest_weights'
return self.feature_selection(data, labels, weights,
num_features, n_method)
def explain_instance_with_data(self,
neighborhood_data,
neighborhood_labels,
distances,
label,
num_features,
feature_selection='auto',
model_regressor=None):
"""Takes perturbed data, labels and distances, returns explanation.
Args:
neighborhood_data: perturbed data, 2d array. first element is
assumed to be the original data point.
neighborhood_labels: corresponding perturbed labels. should have as
many columns as the number of possible labels.
distances: distances to original data point.
label: label for which we want an explanation
num_features: maximum number of features in explanation
feature_selection: how to select num_features. options are:
'forward_selection': iteratively add features to the model.
This is costly when num_features is high
'highest_weights': selects the features that have the highest
product of absolute weight * original data point when
learning with all the features
'lasso_path': chooses features based on the lasso
regularization path
'none': uses all features, ignores num_features
'auto': uses forward_selection if num_features <= 6, and
'highest_weights' otherwise.
model_regressor: sklearn regressor to use in explanation.
Defaults to Ridge regression if None. Must have
model_regressor.coef_ and 'sample_weight' as a parameter
to model_regressor.fit()
Returns:
(intercept, exp, score, local_pred):
intercept is a float.
exp is a sorted list of tuples, where each tuple (x,y) corresponds
to the feature id (x) and the local weight (y). The list is sorted
by decreasing absolute value of y.
score is the R^2 value of the returned explanation
local_pred is the prediction of the explanation model on the original instance
"""
weights = self.kernel_fn(distances)
labels_column = neighborhood_labels[:, label]
used_features = self.feature_selection(neighborhood_data,
labels_column,
weights,
num_features,
feature_selection)
if model_regressor is None:
model_regressor = Ridge(alpha=1, fit_intercept=True,
random_state=self.random_state)
easy_model = model_regressor
easy_model.fit(neighborhood_data[:, used_features],
labels_column, sample_weight=weights)
prediction_score = easy_model.score(
neighborhood_data[:, used_features],
labels_column, sample_weight=weights)
local_pred = easy_model.predict(neighborhood_data[0, used_features].reshape(1, -1))
if self.verbose:
print('Intercept', easy_model.intercept_)
print('Prediction_local', local_pred,)
print('Right:', neighborhood_labels[0, label])
return (easy_model.intercept_,
sorted(zip(used_features, easy_model.coef_),
key=lambda x: np.abs(x[1]), reverse=True),
prediction_score, local_pred)
class ImageExplanation(object):
def __init__(self, image, segments):
"""Init function.
Args:
image: 3d numpy array
segments: 2d numpy array, with the output from skimage.segmentation
"""
self.image = image
self.segments = segments
self.intercept = {}
self.local_exp = {}
self.local_pred = None
def get_image_and_mask(self, label, positive_only=True, negative_only=False, hide_rest=False,
num_features=5, min_weight=0.):
"""Init function.
Args:
label: label to explain
positive_only: if True, only take superpixels that positively contribute to
the prediction of the label.
negative_only: if True, only take superpixels that negatively contribute to
the prediction of the label. If false, and so is positive_only, then both
negativey and positively contributions will be taken.
Both can't be True at the same time
hide_rest: if True, make the non-explanation part of the return
image gray
num_features: number of superpixels to include in explanation
min_weight: minimum weight of the superpixels to include in explanation
Returns:
(image, mask), where image is a 3d numpy array and mask is a 2d
numpy array that can be used with
skimage.segmentation.mark_boundaries
"""
if label not in self.local_exp:
raise KeyError('Label not in explanation')
if positive_only & negative_only:
raise ValueError("Positive_only and negative_only cannot be true at the same time.")
segments = self.segments
image = self.image
exp = self.local_exp[label]
mask = np.zeros(segments.shape, segments.dtype)
if hide_rest:
temp = np.zeros(self.image.shape)
else:
temp = self.image.copy()
if positive_only:
fs = [x[0] for x in exp
if x[1] > 0 and x[1] > min_weight][:num_features]
if negative_only:
fs = [x[0] for x in exp
if x[1] < 0 and abs(x[1]) > min_weight][:num_features]
if positive_only or negative_only:
for f in fs:
temp[segments == f] = image[segments == f].copy()
mask[segments == f] = 1
return temp, mask
else:
for f, w in exp[:num_features]:
if np.abs(w) < min_weight:
continue
c = 0 if w < 0 else 1
mask[segments == f] = -1 if w < 0 else 1
temp[segments == f] = image[segments == f].copy()
temp[segments == f, c] = np.max(image)
return temp, mask
def get_rendered_image(self, label, min_weight=0.005):
"""
Args:
label: label to explain
min_weight:
Returns:
image, is a 3d numpy array
"""
if label not in self.local_exp:
raise KeyError('Label not in explanation')
from matplotlib import cm
segments = self.segments
image = self.image
exp = self.local_exp[label]
temp = np.zeros_like(image)
weight_max = abs(exp[0][1])
exp = [(f, w/weight_max) for f, w in exp]
exp = sorted(exp, key=lambda x: x[1], reverse=True) # negatives are at last.
cmaps = cm.get_cmap('Spectral')
# sigmoid_space = 1 / (1 + np.exp(-np.linspace(-20, 20, len(exp))))
colors = cmaps(np.linspace(0, 1, len(exp)))
colors = colors[:, :3]
for i, (f, w) in enumerate(exp):
if np.abs(w) < min_weight:
continue
temp[segments == f] = image[segments == f].copy()
temp[segments == f] = colors[i] * 255
return temp
class LimeImageExplainer(object):
"""Explains predictions on Image (i.e. matrix) data.
For numerical features, perturb them by sampling from a Normal(0,1) and
doing the inverse operation of mean-centering and scaling, according to the
means and stds in the training data. For categorical features, perturb by
sampling according to the training distribution, and making a binary
feature that is 1 when the value is the same as the instance being
explained."""
def __init__(self, kernel_width=.25, kernel=None, verbose=False,
feature_selection='auto', random_state=None):
"""Init function.
Args:
kernel_width: kernel width for the exponential kernel.
If None, defaults to sqrt(number of columns) * 0.75.
kernel: similarity kernel that takes euclidean distances and kernel
width as input and outputs weights in (0,1). If None, defaults to
an exponential kernel.
verbose: if true, print local prediction values from linear model
feature_selection: feature selection method. can be
'forward_selection', 'lasso_path', 'none' or 'auto'.
See function 'explain_instance_with_data' in lime_base.py for
details on what each of the options does.
random_state: an integer or numpy.RandomState that will be used to
generate random numbers. If None, the random state will be
initialized using the internal numpy seed.
"""
kernel_width = float(kernel_width)
if kernel is None:
def kernel(d, kernel_width):
return np.sqrt(np.exp(-(d ** 2) / kernel_width ** 2))
kernel_fn = partial(kernel, kernel_width=kernel_width)
self.random_state = check_random_state(random_state)
self.feature_selection = feature_selection
self.base = LimeBase(kernel_fn, verbose, random_state=self.random_state)
def explain_instance(self, image, classifier_fn, labels=(1,),
hide_color=None,
num_features=100000, num_samples=1000,
batch_size=10,
distance_metric='cosine',
model_regressor=None
):
"""Generates explanations for a prediction.
First, we generate neighborhood data by randomly perturbing features
from the instance (see __data_inverse). We then learn locally weighted
linear models on this neighborhood data to explain each of the classes
in an interpretable way (see lime_base.py).
Args:
image: 3 dimension RGB image. If this is only two dimensional,
we will assume it's a grayscale image and call gray2rgb.
classifier_fn: classifier prediction probability function, which
takes a numpy array and outputs prediction probabilities. For
ScikitClassifiers , this is classifier.predict_proba.
labels: iterable with labels to be explained.
hide_color: TODO
num_features: maximum number of features present in explanation
num_samples: size of the neighborhood to learn the linear model
batch_size: TODO
distance_metric: the distance metric to use for weights.
model_regressor: sklearn regressor to use in explanation. Defaults
to Ridge regression in LimeBase. Must have model_regressor.coef_
and 'sample_weight' as a parameter to model_regressor.fit()
Returns:
An ImageExplanation object (see lime_image.py) with the corresponding
explanations.
"""
if len(image.shape) == 2:
image = gray2rgb(image)
try:
segments = quickshift(image, sigma=1)
except ValueError as e:
raise e
self.segments = segments
fudged_image = image.copy()
if hide_color is None:
# if no hide_color, use the mean
for x in np.unique(segments):
mx = np.mean(image[segments == x], axis=0)
fudged_image[segments == x] = mx
elif hide_color == 'avg_from_neighbor':
from scipy.spatial.distance import cdist
n_features = np.unique(segments).shape[0]
regions = regionprops(segments + 1)
centroids = np.zeros((n_features, 2))
for i, x in enumerate(regions):
centroids[i] = np.array(x.centroid)
d = cdist(centroids, centroids, 'sqeuclidean')
for x in np.unique(segments):
# print(np.argmin(d[x]))
a = [image[segments == i] for i in np.argsort(d[x])[1:6]]
mx = np.mean(np.concatenate(a), axis=0)
fudged_image[segments == x] = mx
else:
fudged_image[:] = 0
top = labels
data, labels = self.data_labels(image, fudged_image, segments,
classifier_fn, num_samples,
batch_size=batch_size)
distances = sklearn.metrics.pairwise_distances(
data,
data[0].reshape(1, -1),
metric=distance_metric
).ravel()
ret_exp = ImageExplanation(image, segments)
for label in top:
(ret_exp.intercept[label],
ret_exp.local_exp[label],
ret_exp.score, ret_exp.local_pred) = self.base.explain_instance_with_data(
data, labels, distances, label, num_features,
model_regressor=model_regressor,
feature_selection=self.feature_selection)
return ret_exp
def data_labels(self,
image,
fudged_image,
segments,
classifier_fn,
num_samples,
batch_size=10):
"""Generates images and predictions in the neighborhood of this image.
Args:
image: 3d numpy array, the image
fudged_image: 3d numpy array, image to replace original image when
superpixel is turned off
segments: segmentation of the image
classifier_fn: function that takes a list of images and returns a
matrix of prediction probabilities
num_samples: size of the neighborhood to learn the linear model
batch_size: classifier_fn will be called on batches of this size.
Returns:
A tuple (data, labels), where:
data: dense num_samples * num_superpixels
labels: prediction probabilities matrix
"""
n_features = np.unique(segments).shape[0]
data = self.random_state.randint(0, 2, num_samples * n_features) \
.reshape((num_samples, n_features))
labels = []
data[0, :] = 1
imgs = []
for row in data:
temp = copy.deepcopy(image)
zeros = np.where(row == 0)[0]
mask = np.zeros(segments.shape).astype(bool)
for z in zeros:
mask[segments == z] = True
temp[mask] = fudged_image[mask]
imgs.append(temp)
if len(imgs) == batch_size:
preds = classifier_fn(np.array(imgs))
labels.extend(preds)
imgs = []
if len(imgs) > 0:
preds = classifier_fn(np.array(imgs))
labels.extend(preds)
return data, np.array(labels)
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import copy
import os.path as osp
import numpy as np
from .core.explanation import Explanation
def visualize(img_file,
model,
explanation_type='lime',
num_samples=3000,
batch_size=50,
save_dir='./'):
model.arrange_transforms(
transforms=model.test_transforms, mode='test')
tmp_transforms = copy.deepcopy(model.test_transforms)
tmp_transforms.transforms = tmp_transforms.transforms[:-2]
img = tmp_transforms(img_file)[0]
img = np.around(img).astype('uint8')
img = np.expand_dims(img, axis=0)
explaier = None
if explanation_type == 'lime':
explaier = get_lime_explaier(img, model, num_samples=num_samples, batch_size=batch_size)
else:
raise Exception('The {} explanantion method is not supported yet!'.format(explanation_type))
img_name = osp.splitext(osp.split(img_file)[-1])[0]
explaier.explain(img, save_dir=save_dir)
def get_lime_explaier(img, model, num_samples=3000, batch_size=50):
def predict_func(image):
image = image.astype('float32')
model.test_transforms.transforms = model.test_transforms.transforms[-2:]
out = model.explanation_predict(image)
return out[0]
explaier = Explanation('lime',
predict_func,
num_samples=num_samples,
batch_size=batch_size)
return explaier
\ No newline at end of file
...@@ -120,6 +120,7 @@ class ResNet(object): ...@@ -120,6 +120,7 @@ class ResNet(object):
self.num_classes = num_classes self.num_classes = num_classes
self.lr_mult_list = lr_mult_list self.lr_mult_list = lr_mult_list
self.curr_stage = 0 self.curr_stage = 0
self.features = []
def _conv_offset(self, def _conv_offset(self,
input, input,
...@@ -474,7 +475,9 @@ class ResNet(object): ...@@ -474,7 +475,9 @@ class ResNet(object):
size=self.num_classes, size=self.num_classes,
param_attr=fluid.param_attr.ParamAttr( param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv))) initializer=fluid.initializer.Uniform(-stdv, stdv)))
return out self.features.append(out)
# out.persistable=True
return out, self.features
return OrderedDict([('res{}_sum'.format(self.feature_maps[idx]), feat) return OrderedDict([('res{}_sum'.format(self.feature_maps[idx]), feat)
for idx, feat in enumerate(res_endpoints)]) for idx, feat in enumerate(res_endpoints)])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册